├── README.md ├── __pycache__ ├── metrics.cpython-36.pyc └── utils.cpython-36.pyc ├── convert_albert_original_tf_checkpoint_to_pytorch.py ├── dataset ├── README.md └── THUNews │ ├── 5_100 │ ├── README.MD │ ├── dev.csv │ ├── test.csv │ └── train.csv │ ├── 5_5000 │ ├── dev.csv │ ├── test.csv │ └── train.csv │ └── cnews.vocab.txt ├── metrics.py ├── pretrained_models └── README.md ├── results └── README.md ├── run.py ├── run_classifier.sh ├── runs ├── Jan08_17-12-40_zhan │ └── events.out.tfevents.1578474760.zhan.4862.0 ├── Jan08_17-16-15_zhan │ └── events.out.tfevents.1578474975.zhan.5610.0 ├── Jan08_17-18-23_zhan │ └── events.out.tfevents.1578475103.zhan.5690.0 ├── Jan08_21-51-13_zhan │ └── events.out.tfevents.1578491473.zhan.7325.0 ├── Jan08_21-53-20_zhan │ └── events.out.tfevents.1578491600.zhan.7395.0 ├── Jan08_22-11-34_zhan │ └── events.out.tfevents.1578492694.zhan.8340.0 ├── Jan08_22-14-28_zhan │ └── events.out.tfevents.1578492868.zhan ├── Jan14_11-19-53_zhan │ └── events.out.tfevents.1578971993.zhan ├── Jan14_11-22-12_zhan │ └── events.out.tfevents.1578972132.zhan ├── Jan14_11-29-42_zhan │ └── events.out.tfevents.1578972582.zhan ├── Jan14_11-35-03_zhan │ └── events.out.tfevents.1578972903.zhan └── README.MD ├── transformers ├── __init__.py ├── __main__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── configuration_albert.cpython-36.pyc │ ├── configuration_albert.cpython-37.pyc │ ├── configuration_auto.cpython-36.pyc │ ├── configuration_auto.cpython-37.pyc │ ├── configuration_bert.cpython-36.pyc │ ├── configuration_bert.cpython-37.pyc │ ├── configuration_camembert.cpython-36.pyc │ ├── configuration_camembert.cpython-37.pyc │ ├── configuration_ctrl.cpython-36.pyc │ ├── configuration_ctrl.cpython-37.pyc │ ├── configuration_distilbert.cpython-36.pyc │ ├── configuration_distilbert.cpython-37.pyc │ ├── configuration_gpt2.cpython-36.pyc │ ├── configuration_gpt2.cpython-37.pyc │ ├── configuration_openai.cpython-36.pyc │ ├── configuration_openai.cpython-37.pyc │ ├── configuration_roberta.cpython-36.pyc │ ├── configuration_roberta.cpython-37.pyc │ ├── configuration_t5.cpython-36.pyc │ ├── configuration_t5.cpython-37.pyc │ ├── configuration_transfo_xl.cpython-36.pyc │ ├── configuration_transfo_xl.cpython-37.pyc │ ├── configuration_utils.cpython-36.pyc │ ├── configuration_utils.cpython-37.pyc │ ├── configuration_xlm.cpython-36.pyc │ ├── configuration_xlm.cpython-37.pyc │ ├── configuration_xlnet.cpython-36.pyc │ ├── configuration_xlnet.cpython-37.pyc │ ├── file_utils.cpython-36.pyc │ ├── file_utils.cpython-37.pyc │ ├── model_card.cpython-36.pyc │ ├── model_card.cpython-37.pyc │ ├── modeling_albert.cpython-36.pyc │ ├── modeling_auto.cpython-36.pyc │ ├── modeling_bert.cpython-36.pyc │ ├── modeling_camembert.cpython-36.pyc │ ├── modeling_ctrl.cpython-36.pyc │ ├── modeling_distilbert.cpython-36.pyc │ ├── modeling_encoder_decoder.cpython-36.pyc │ ├── modeling_gpt2.cpython-36.pyc │ ├── modeling_openai.cpython-36.pyc │ ├── modeling_roberta.cpython-36.pyc │ ├── modeling_t5.cpython-36.pyc │ ├── modeling_tf_pytorch_utils.cpython-36.pyc │ ├── modeling_transfo_xl.cpython-36.pyc │ ├── modeling_transfo_xl_utilities.cpython-36.pyc │ ├── modeling_utils.cpython-36.pyc │ ├── modeling_xlm.cpython-36.pyc │ ├── modeling_xlnet.cpython-36.pyc │ ├── optimization.cpython-36.pyc │ ├── tokenization_albert.cpython-36.pyc │ ├── tokenization_auto.cpython-36.pyc │ ├── tokenization_auto.cpython-37.pyc │ ├── tokenization_bert.cpython-36.pyc │ ├── tokenization_bert.cpython-37.pyc │ ├── 
tokenization_bert_japanese.cpython-36.pyc │ ├── tokenization_bert_japanese.cpython-37.pyc │ ├── tokenization_camembert.cpython-36.pyc │ ├── tokenization_camembert.cpython-37.pyc │ ├── tokenization_ctrl.cpython-36.pyc │ ├── tokenization_ctrl.cpython-37.pyc │ ├── tokenization_distilbert.cpython-36.pyc │ ├── tokenization_distilbert.cpython-37.pyc │ ├── tokenization_gpt2.cpython-36.pyc │ ├── tokenization_gpt2.cpython-37.pyc │ ├── tokenization_openai.cpython-36.pyc │ ├── tokenization_openai.cpython-37.pyc │ ├── tokenization_roberta.cpython-36.pyc │ ├── tokenization_roberta.cpython-37.pyc │ ├── tokenization_t5.cpython-36.pyc │ ├── tokenization_transfo_xl.cpython-36.pyc │ ├── tokenization_transfo_xl.cpython-37.pyc │ ├── tokenization_utils.cpython-36.pyc │ ├── tokenization_utils.cpython-37.pyc │ ├── tokenization_xlm.cpython-36.pyc │ ├── tokenization_xlm.cpython-37.pyc │ ├── tokenization_xlnet.cpython-36.pyc │ └── tokenization_xlnet.cpython-37.pyc ├── commands │ ├── __init__.py │ └── user.py ├── configuration_albert.py ├── configuration_albert_backup.py ├── configuration_auto.py ├── configuration_bert.py ├── configuration_camembert.py ├── configuration_ctrl.py ├── configuration_distilbert.py ├── configuration_gpt2.py ├── configuration_openai.py ├── configuration_roberta.py ├── configuration_t5.py ├── configuration_transfo_xl.py ├── configuration_utils.py ├── configuration_xlm.py ├── configuration_xlnet.py ├── convert_albert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_pytorch_checkpoint_to_original_tf.py ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py ├── convert_openai_original_tf_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf2.py ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py ├── convert_t5_original_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-37.pyc │ ├── metrics │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── __init__.cpython-37.pyc │ │ └── squad_metrics.py │ └── processors │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── glue.cpython-36.pyc │ │ ├── glue.cpython-37.pyc │ │ ├── squad.cpython-36.pyc │ │ ├── squad.cpython-37.pyc │ │ ├── utils.cpython-36.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── xnli.cpython-36.pyc │ │ └── xnli.cpython-37.pyc │ │ ├── glue.py │ │ ├── squad.py │ │ ├── utils.py │ │ └── xnli.py ├── file_utils.py ├── hf_api.py ├── init__.py ├── model_card.py ├── modeling_albert.py ├── modeling_albert_backup.py ├── modeling_albert_backup1.py ├── modeling_auto.py ├── modeling_bert.py ├── modeling_camembert.py ├── modeling_ctrl.py ├── modeling_distilbert.py ├── modeling_encoder_decoder.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_roberta.py ├── modeling_t5.py ├── modeling_tf_albert.py ├── modeling_tf_auto.py ├── modeling_tf_bert.py ├── modeling_tf_ctrl.py ├── modeling_tf_distilbert.py ├── modeling_tf_gpt2.py ├── modeling_tf_openai.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_roberta.py ├── modeling_tf_t5.py ├── modeling_tf_transfo_xl.py ├── modeling_tf_transfo_xl_utilities.py ├── modeling_tf_utils.py ├── modeling_tf_xlm.py ├── modeling_tf_xlnet.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── 
modeling_utils.py ├── modeling_xlm.py ├── modeling_xlnet.py ├── optimization.py ├── optimization_tf.py ├── tests │ ├── __init__.py │ ├── configuration_common_test.py │ ├── fixtures │ │ ├── empty.txt │ │ ├── input.txt │ │ ├── sample_text.txt │ │ ├── spiece.model │ │ └── test_sentencepiece.model │ ├── hf_api_test.py │ ├── model_card_test.py │ ├── modeling_albert_test.py │ ├── modeling_auto_test.py │ ├── modeling_bert_test.py │ ├── modeling_common_test.py │ ├── modeling_ctrl_test.py │ ├── modeling_distilbert_test.py │ ├── modeling_encoder_decoder_test.py │ ├── modeling_gpt2_test.py │ ├── modeling_openai_test.py │ ├── modeling_roberta_test.py │ ├── modeling_t5_test.py │ ├── modeling_tf_albert_test.py │ ├── modeling_tf_auto_test.py │ ├── modeling_tf_bert_test.py │ ├── modeling_tf_common_test.py │ ├── modeling_tf_ctrl_test.py │ ├── modeling_tf_distilbert_test.py │ ├── modeling_tf_gpt2_test.py │ ├── modeling_tf_openai_gpt_test.py │ ├── modeling_tf_roberta_test.py │ ├── modeling_tf_t5_test.py │ ├── modeling_tf_transfo_xl_test.py │ ├── modeling_tf_xlm_test.py │ ├── modeling_tf_xlnet_test.py │ ├── modeling_transfo_xl_test.py │ ├── modeling_xlm_test.py │ ├── modeling_xlnet_test.py │ ├── optimization_test.py │ ├── optimization_tf_test.py │ ├── tokenization_albert_test.py │ ├── tokenization_auto_test.py │ ├── tokenization_bert_japanese_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_ctrl_test.py │ ├── tokenization_distilbert_test.py │ ├── tokenization_gpt2_test.py │ ├── tokenization_openai_test.py │ ├── tokenization_roberta_test.py │ ├── tokenization_t5_test.py │ ├── tokenization_tests_commons.py │ ├── tokenization_transfo_xl_test.py │ ├── tokenization_utils_test.py │ ├── tokenization_xlm_test.py │ ├── tokenization_xlnet_test.py │ └── utils.py ├── tokenization_albert.py ├── tokenization_albert_backup.py ├── tokenization_auto.py ├── tokenization_bert.py ├── tokenization_bert_japanese.py ├── tokenization_camembert.py ├── tokenization_ctrl.py ├── tokenization_distilbert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_roberta.py ├── tokenization_t5.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py └── tokenization_xlnet.py └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Transformers_for_Text_Classification
2 | 
3 | # Text Classification Based on Transformers
4 | 
5 | Refactored on top of the latest [transformers](https://github.com/huggingface/transformers/releases/tag/v2.2.2) v2.2.2 code released by [huggingface](https://github.com/huggingface). To make sure the code can be reproduced later without compatibility issues, the [transformers](https://github.com/huggingface/transformers/releases/tag/v2.2.2) package is kept as a local copy and called from this repository.
6 | 
7 | 
8 | 
9 | # Highlights
10 | 
11 | - Supports attaching various feature extractors on top of the transformer model
12 | - Includes test-set prediction code
13 | - Trims the original transformers code so that it better fits text classification
14 | - Tidies up the logging output to the terminal so that it is more informative
15 | 
16 | 
17 | 
18 | # Support
19 | 
20 | **model_type:**
21 | 
22 | - [x] bert
23 | - [x] bert_cnn
24 | - [x] bert_lstm
25 | - [x] bert_gru
26 | - [x] xlnet
27 | - [ ] xlnet_cnn
28 | - [x] xlnet_lstm
29 | - [x] xlnet_gru
30 | - [ ] albert
31 | 
32 | 
33 | 
34 | # Content
35 | 
36 | - dataset: stores the datasets
37 | - pretrained_models: stores the pre-trained models
38 | - transformers: the local copy of the transformers package
39 | - results: stores training results
40 | 
41 | 
42 | 
43 | # Usage
44 | 
45 | ## 1. Using different models
46 | 
47 | **Specify the model by changing the `model_type` parameter in the shell script.**
48 | 
49 | For example, for BERT followed by a fully connected (FC) layer, set `model_type=bert`; for BERT followed by a CNN layer, set `model_type=bert_cnn`.
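For intuition, this is roughly what a `bert_lstm`-style head means: a BERT encoder whose token outputs feed an LSTM, followed by a linear classifier. The sketch below is illustrative only — the class name and hyper-parameters are made up, and it is not the modeling code this repository actually uses:

```python
import torch.nn as nn
from transformers import BertModel


class BertLstmClassifierSketch(nn.Module):
    """Hypothetical bert_lstm head: BERT encoder -> LSTM feature extractor -> linear classifier."""

    def __init__(self, pretrained_path, num_labels, lstm_hidden_size=512):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_path)
        self.lstm = nn.LSTM(self.bert.config.hidden_size, lstm_hidden_size,
                            batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * lstm_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        sequence_output = self.bert(input_ids, attention_mask=attention_mask)[0]  # (batch, seq_len, hidden)
        lstm_out, _ = self.lstm(sequence_output)   # run the LSTM over BERT's token states
        return self.classifier(lstm_out[:, -1, :])  # classify from the last time step
```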
50 | 
51 | The `Support` section of this README lists the `model_type` values supported for each pre-trained model in this project.
52 | 
53 | Finally, just run the shell script from the terminal, e.g.:
54 | 
55 | ```
56 | bash run_classifier.sh
57 | ```
58 | 
59 | **Note**: **the Chinese RoBERTa, ERNIE, and BERT_wwm pre-trained language models are all loaded with BERT's model_type.**
60 | 
61 | 
62 | 
63 | ## 2. Using a custom dataset
64 | 
65 | 1. Put your custom dataset folder, e.g. `TestData`, inside the `dataset` folder.
66 | 2. In `utils.py` in the repository root, write your own class modeled on `class THUNewsProcessor`, e.g. `class TestDataProcessor`, and add the corresponding entries to the three dicts `tasks_num_labels`, `processors`, and `output_modes`.
67 | 3. Finally, in the shell script you want to run, change TASK_NAME to your task name, e.g. `TestData`.
68 | 
69 | 
70 | 
71 | # Environment
72 | 
73 | - one 2080Ti, 12GB RAM
74 | - Python: 3.6.5
75 | - PyTorch: 1.3.1
76 | 
77 | - TensorFlow: 1.14.0 (only required for TensorBoard support, not used otherwise)
78 | - Numpy: 1.14.6
79 | 
80 | 
81 | 
82 | # Performance
83 | 
84 | Dataset: THUNews/5_5000
85 | 
86 | epoch: 1
87 | 
88 | train_steps: 5000
89 | 
90 | | model | dev set best F1 and Acc | remark |
91 | | ------------------ | -------------------------- | ----------------------------------------------- |
92 | | bert_base | 0.9308869881728941, 0.9324 | BERT + FC layer, batch_size 8, learning_rate 2e-5 |
93 | | bert_base+cnn | 0.9136314735833212, 0.9156 | BERT + CNN layer, batch_size 8, learning_rate 2e-5 |
94 | | bert_base+lstm | 0.9369254464106703, 0.9372 | BERT + LSTM layer, batch_size 8, learning_rate 2e-5 |
95 | | bert_base+gru | 0.9379539112313108, 0.938 | BERT + GRU layer, batch_size 8, learning_rate 2e-5 |
96 | | roberta_large | | RoBERTa + FC layer, batch_size 2, learning_rate 2e-5 |
97 | | xlnet_mid | 0.9530066512880131, 0.954 | XLNet + FC layer, batch_size 2, learning_rate 2e-5 |
98 | | xlnet_mid+lstm | 0.9269927348553552, 0.9304 | XLNet + LSTM layer, batch_size 2, learning_rate 2e-5 |
99 | | xlnet_mid+gru | 0.9494631023945569, 0.9508 | XLNet + GRU layer, batch_size 2, learning_rate 2e-5 |
100 | | albert_xlarge_183k | | |
101 | 
102 | 
103 | 
104 | # Download Chinese Pre-trained Models
105 | 
106 | [NLP_PEMDC](https://github.com/zhanlaoban/NLP_PEMDC)
107 | 
108 | 
109 | 
110 | 
111 | 
--------------------------------------------------------------------------------
/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/convert_albert_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert ALBERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = AlbertConfig.from_json_file(albert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = AlbertForMaskedLM(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path to the TensorFlow checkpoint path.") 52 | parser.add_argument("--albert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained ALBERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.albert_config_file, 66 | args.pytorch_dump_path) 67 | 68 | ''' 69 | python convert_albert_original_tf_checkpoint_to_pytorch.py \ 70 | --tf_checkpoint_path=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base \ 71 | --albert_config_file=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base \ 72 | --pytorch_dump_path=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base/pytorch_model.bin 73 | ''' -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | 在这里存在数据 -------------------------------------------------------------------------------- /dataset/THUNews/5_100/README.MD: -------------------------------------------------------------------------------- 1 | 共5类 2 | 每类100条数据,按train/dev/test 划分为 80/10/10 3 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | try: 8 | from scipy.stats import pearsonr, spearmanr 9 | from sklearn.metrics import matthews_corrcoef, f1_score 10 | _has_sklearn = True 11 | except (AttributeError, ImportError) as e: 12 | logger.warning("To use data.metrics please install scikit-learn. 
/metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 | import logging
4 | 
5 | logger = logging.getLogger(__name__)
6 | 
7 | try:
8 |     from scipy.stats import pearsonr, spearmanr
9 |     from sklearn.metrics import matthews_corrcoef, f1_score
10 |     _has_sklearn = True
11 | except (AttributeError, ImportError) as e:
12 |     logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
13 |     _has_sklearn = False
14 | 
15 | def is_sklearn_available():
16 |     return _has_sklearn
17 | 
18 | if _has_sklearn:
19 | 
20 |     def simple_accuracy(preds, labels):
21 |         return (preds == labels).mean()
22 | 
23 | 
24 |     def acc_and_f1(preds, labels):
25 |         acc = simple_accuracy(preds, labels)
26 |         f1 = f1_score(y_true=labels, y_pred=preds, average='macro')
27 |         return {
28 |             "acc": acc,
29 |             "f1": f1,
30 |             "acc_and_f1": (acc + f1) / 2,
31 |         }
32 | 
33 | 
34 |     def pearson_and_spearman(preds, labels):
35 |         pearson_corr = pearsonr(preds, labels)[0]
36 |         spearman_corr = spearmanr(preds, labels)[0]
37 |         return {
38 |             "pearson": pearson_corr,
39 |             "spearmanr": spearman_corr,
40 |             "corr": (pearson_corr + spearman_corr) / 2,
41 |         }
42 | 
43 | 
44 |     def compute_metrics(task_name, preds, labels):
45 |         assert len(preds) == len(labels)
46 |         if task_name == "thunews":
47 |             return {"acc": simple_accuracy(preds, labels)}
48 |         else:
49 |             raise KeyError(task_name)
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | Folder for the pre-trained model files.
2 | 
3 | Download link for the model files: [05_Transformers/README.MD](https://github.com/zhanlaoban/Text_Classification/blob/master/05_Transformers/README.MD)
4 | 
5 | 
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | Stores the evaluation and prediction result files.
--------------------------------------------------------------------------------
/run_classifier.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | 
3 | TASK_NAME="THUNews"
4 | 
5 | python run.py \
6 |   --task_name=$TASK_NAME \
7 |   --model_type=albert \
8 |   --model_name_or_path ./pretrained_models/albert_xlarge_183k \
9 |   --data_dir ./dataset/THUNews/5_5000 \
10 |   --output_dir ./results/THUNews/albert \
11 |   --do_train \
12 |   --do_eval \
13 |   --do_predict \
14 |   --do_lower_case \
15 |   --max_seq_length=512 \
16 |   --per_gpu_train_batch_size=1 \
17 |   --per_gpu_eval_batch_size=16 \
18 |   --gradient_accumulation_steps=2 \
19 |   --learning_rate=2e-5 \
20 |   --num_train_epochs=1.0 \
21 |   --logging_steps=14923 \
22 |   --save_steps=14923 \
23 |   --overwrite_output_dir \
24 |   --filter_sizes='3,4,5' \
25 |   --filter_num=256 \
26 |   --lstm_layers=1 \
27 |   --lstm_hidden_size=512 \
28 |   --lstm_dropout=0.1 \
29 |   --gru_layers=1 \
30 |   --gru_hidden_size=512 \
31 |   --gru_dropout=0.1
32 | 
33 | 
34 | 
35 | # Save a checkpoint once per epoch
36 | # Evaluate once per epoch
--------------------------------------------------------------------------------
/runs/Jan08_17-12-40_zhan/events.out.tfevents.1578474760.zhan.4862.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_17-12-40_zhan/events.out.tfevents.1578474760.zhan.4862.0
--------------------------------------------------------------------------------
/runs/Jan08_17-16-15_zhan/events.out.tfevents.1578474975.zhan.5610.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_17-16-15_zhan/events.out.tfevents.1578474975.zhan.5610.0
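A quick sanity check of the helpers in `metrics.py` above (the arrays are made-up values; note that `compute_metrics` only knows the `thunews` task unless you extend its dispatch when adding a new dataset):

```python
import numpy as np

from metrics import acc_and_f1, compute_metrics  # assumes the repository root is on sys.path

preds = np.array([0, 2, 1, 1])
labels = np.array([0, 2, 1, 0])

print(compute_metrics("thunews", preds, labels))  # {'acc': 0.75}
print(acc_and_f1(preds, labels))                  # accuracy plus macro-averaged F1 and their mean
```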
-------------------------------------------------------------------------------- /runs/Jan08_17-18-23_zhan/events.out.tfevents.1578475103.zhan.5690.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_17-18-23_zhan/events.out.tfevents.1578475103.zhan.5690.0 -------------------------------------------------------------------------------- /runs/Jan08_21-51-13_zhan/events.out.tfevents.1578491473.zhan.7325.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_21-51-13_zhan/events.out.tfevents.1578491473.zhan.7325.0 -------------------------------------------------------------------------------- /runs/Jan08_21-53-20_zhan/events.out.tfevents.1578491600.zhan.7395.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_21-53-20_zhan/events.out.tfevents.1578491600.zhan.7395.0 -------------------------------------------------------------------------------- /runs/Jan08_22-11-34_zhan/events.out.tfevents.1578492694.zhan.8340.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_22-11-34_zhan/events.out.tfevents.1578492694.zhan.8340.0 -------------------------------------------------------------------------------- /runs/Jan08_22-14-28_zhan/events.out.tfevents.1578492868.zhan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan08_22-14-28_zhan/events.out.tfevents.1578492868.zhan -------------------------------------------------------------------------------- /runs/Jan14_11-19-53_zhan/events.out.tfevents.1578971993.zhan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan14_11-19-53_zhan/events.out.tfevents.1578971993.zhan -------------------------------------------------------------------------------- /runs/Jan14_11-22-12_zhan/events.out.tfevents.1578972132.zhan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan14_11-22-12_zhan/events.out.tfevents.1578972132.zhan -------------------------------------------------------------------------------- /runs/Jan14_11-29-42_zhan/events.out.tfevents.1578972582.zhan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan14_11-29-42_zhan/events.out.tfevents.1578972582.zhan -------------------------------------------------------------------------------- /runs/Jan14_11-35-03_zhan/events.out.tfevents.1578972903.zhan: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/runs/Jan14_11-35-03_zhan/events.out.tfevents.1578972903.zhan -------------------------------------------------------------------------------- /runs/README.MD: -------------------------------------------------------------------------------- 1 | 保存模型运行信息文件 2 | -------------------------------------------------------------------------------- /transformers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_albert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_albert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_albert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_albert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_auto.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_auto.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_auto.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_auto.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_bert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_bert.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_bert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_camembert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_camembert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_camembert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_camembert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_ctrl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_ctrl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_ctrl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_ctrl.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_distilbert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_distilbert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_distilbert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_distilbert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- 
/transformers/__pycache__/configuration_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_openai.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_openai.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_roberta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_roberta.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_roberta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_roberta.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_t5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_t5.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_t5.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_t5.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_utils.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_xlm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_xlm.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_xlnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_xlnet.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/configuration_xlnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/configuration_xlnet.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/file_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/file_utils.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/model_card.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/model_card.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/model_card.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/model_card.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_albert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_albert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_auto.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_auto.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_bert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_camembert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_camembert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_ctrl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_ctrl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_distilbert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_distilbert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_encoder_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_encoder_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- 
/transformers/__pycache__/modeling_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_openai.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_roberta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_roberta.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_t5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_t5.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_tf_pytorch_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_tf_pytorch_utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/modeling_xlnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/modeling_xlnet.cpython-36.pyc 
-------------------------------------------------------------------------------- /transformers/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_albert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_albert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_auto.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_auto.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_auto.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_auto.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_bert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_bert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_bert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_bert_japanese.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_bert_japanese.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_bert_japanese.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_bert_japanese.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_camembert.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_camembert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_camembert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_camembert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_ctrl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_ctrl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_ctrl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_ctrl.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_distilbert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_distilbert.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_distilbert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_distilbert.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_openai.cpython-36.pyc -------------------------------------------------------------------------------- 
/transformers/__pycache__/tokenization_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_openai.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_roberta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_roberta.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_roberta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_roberta.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_t5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_t5.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_utils.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_xlm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_xlm.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_xlnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_xlnet.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/__pycache__/tokenization_xlnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/__pycache__/tokenization_xlnet.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | class BaseTransformersCLICommand(ABC): 5 | @staticmethod 6 | @abstractmethod 7 | def register_subcommand(parser: ArgumentParser): 8 | raise NotImplementedError() 9 | 10 | @abstractmethod 11 | def run(self): 12 | raise NotImplementedError() 13 | -------------------------------------------------------------------------------- /transformers/configuration_albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ ALBERT model configuration """ 17 | 18 | from .configuration_utils import PretrainedConfig 19 | 20 | ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 21 | 'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json", 22 | 'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json", 23 | 'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json", 24 | 'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json", 25 | 'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json", 26 | 'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json", 27 | 'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json", 28 | 'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json", 29 | } 30 | 31 | class AlbertConfig(PretrainedConfig): 32 | """Configuration for `AlbertModel`. 33 | 34 | The default settings match the configuration of model `albert_xxlarge`. 35 | """ 36 | 37 | pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 38 | 39 | def __init__(self, 40 | vocab_size=30000, 41 | embedding_size=128, 42 | hidden_size=4096, 43 | num_hidden_layers=12, 44 | num_hidden_groups=1, 45 | num_attention_heads=64, 46 | intermediate_size=16384, 47 | inner_group_num=1, 48 | hidden_act="gelu_new", 49 | hidden_dropout_prob=0, 50 | attention_probs_dropout_prob=0, 51 | max_position_embeddings=512, 52 | type_vocab_size=2, 53 | initializer_range=0.02, 54 | layer_norm_eps=1e-12, **kwargs): 55 | """Constructs AlbertConfig. 56 | 57 | Args: 58 | vocab_size: Vocabulary size of `inputs_ids` in `AlbertModel`. 59 | embedding_size: size of voc embeddings. 60 | hidden_size: Size of the encoder layers and the pooler layer. 61 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 62 | num_hidden_groups: Number of group for the hidden layers, parameters in 63 | the same group are shared. 64 | num_attention_heads: Number of attention heads for each attention layer in 65 | the Transformer encoder. 66 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 67 | layer in the Transformer encoder. 68 | inner_group_num: int, number of inner repetition of attention and ffn. 69 | down_scale_factor: float, the scale to apply 70 | hidden_act: The non-linear activation function (function or string) in the 71 | encoder and pooler. 72 | hidden_dropout_prob: The dropout probability for all fully connected 73 | layers in the embeddings, encoder, and pooler. 74 | attention_probs_dropout_prob: The dropout ratio for the attention 75 | probabilities. 76 | max_position_embeddings: The maximum sequence length that this model might 77 | ever be used with. Typically set this to something large just in case 78 | (e.g., 512 or 1024 or 2048). 79 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 80 | `AlbertModel`. 81 | initializer_range: The stdev of the truncated_normal_initializer for 82 | initializing all weight matrices. 
83 | """ 84 | super(AlbertConfig, self).__init__(**kwargs) 85 | 86 | self.vocab_size = vocab_size 87 | self.embedding_size = embedding_size 88 | self.hidden_size = hidden_size 89 | self.num_hidden_layers = num_hidden_layers 90 | self.num_hidden_groups = num_hidden_groups 91 | self.num_attention_heads = num_attention_heads 92 | self.inner_group_num = inner_group_num 93 | self.hidden_act = hidden_act 94 | self.intermediate_size = intermediate_size 95 | self.hidden_dropout_prob = hidden_dropout_prob 96 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 97 | self.max_position_embeddings = max_position_embeddings 98 | self.type_vocab_size = type_vocab_size 99 | self.initializer_range = initializer_range 100 | self.layer_norm_eps = layer_norm_eps 101 | -------------------------------------------------------------------------------- /transformers/configuration_albert_backup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ ALBERT model configuration """ 17 | 18 | from .configuration_utils import PretrainedConfig 19 | 20 | ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 21 | 'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json", 22 | 'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json", 23 | 'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json", 24 | 'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json", 25 | 'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json", 26 | 'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json", 27 | 'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json", 28 | 'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json", 29 | } 30 | 31 | class AlbertConfig(PretrainedConfig): 32 | """Configuration for `AlbertModel`. 33 | 34 | The default settings match the configuration of model `albert_xxlarge`. 35 | """ 36 | 37 | pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 38 | 39 | def __init__(self, 40 | vocab_size=30000, 41 | embedding_size=128, 42 | hidden_size=4096, 43 | num_hidden_layers=12, 44 | num_hidden_groups=1, 45 | num_attention_heads=64, 46 | intermediate_size=16384, 47 | inner_group_num=1, 48 | hidden_act="gelu_new", 49 | hidden_dropout_prob=0, 50 | attention_probs_dropout_prob=0, 51 | max_position_embeddings=512, 52 | type_vocab_size=2, 53 | initializer_range=0.02, 54 | layer_norm_eps=1e-12, **kwargs): 55 | """Constructs AlbertConfig. 
56 | 57 | Args: 58 | vocab_size: Vocabulary size of `inputs_ids` in `AlbertModel`. 59 | embedding_size: size of voc embeddings. 60 | hidden_size: Size of the encoder layers and the pooler layer. 61 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 62 | num_hidden_groups: Number of group for the hidden layers, parameters in 63 | the same group are shared. 64 | num_attention_heads: Number of attention heads for each attention layer in 65 | the Transformer encoder. 66 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 67 | layer in the Transformer encoder. 68 | inner_group_num: int, number of inner repetition of attention and ffn. 69 | down_scale_factor: float, the scale to apply 70 | hidden_act: The non-linear activation function (function or string) in the 71 | encoder and pooler. 72 | hidden_dropout_prob: The dropout probability for all fully connected 73 | layers in the embeddings, encoder, and pooler. 74 | attention_probs_dropout_prob: The dropout ratio for the attention 75 | probabilities. 76 | max_position_embeddings: The maximum sequence length that this model might 77 | ever be used with. Typically set this to something large just in case 78 | (e.g., 512 or 1024 or 2048). 79 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 80 | `AlbertModel`. 81 | initializer_range: The stdev of the truncated_normal_initializer for 82 | initializing all weight matrices. 83 | """ 84 | super(AlbertConfig, self).__init__(**kwargs) 85 | 86 | self.vocab_size = vocab_size 87 | self.embedding_size = embedding_size 88 | self.hidden_size = hidden_size 89 | self.num_hidden_layers = num_hidden_layers 90 | self.num_hidden_groups = num_hidden_groups 91 | self.num_attention_heads = num_attention_heads 92 | self.inner_group_num = inner_group_num 93 | self.hidden_act = hidden_act 94 | self.intermediate_size = intermediate_size 95 | self.hidden_dropout_prob = hidden_dropout_prob 96 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 97 | self.max_position_embeddings = max_position_embeddings 98 | self.type_vocab_size = type_vocab_size 99 | self.initializer_range = initializer_range 100 | self.layer_norm_eps = layer_norm_eps 101 | -------------------------------------------------------------------------------- /transformers/configuration_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
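# Usage sketch (illustrative only): CamembertConfig, defined below, simply attaches a
# CamemBERT-specific archive map to RobertaConfig, so it accepts the same keyword arguments
# as RobertaConfig (and, transitively, BertConfig). The override shown is an example, not a default.
#
#     from transformers.configuration_camembert import CamembertConfig
#
#     config = CamembertConfig(num_hidden_layers=6)   # any BertConfig keyword argument applies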
16 | """ CamemBERT configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_roberta import RobertaConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", 29 | } 30 | 31 | 32 | class CamembertConfig(RobertaConfig): 33 | pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 34 | -------------------------------------------------------------------------------- /transformers/configuration_ctrl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Salesforce CTRL configuration """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"} 29 | 30 | class CTRLConfig(PretrainedConfig): 31 | """Configuration class to store the configuration of a `CTRLModel`. 32 | 33 | Args: 34 | vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 35 | n_positions: Number of positional embeddings. 36 | n_ctx: Size of the causal mask (usually same as n_positions). 37 | dff: Size of the inner dimension of the FFN. 38 | n_embd: Dimensionality of the embeddings and hidden states. 39 | n_layer: Number of hidden layers in the Transformer encoder. 40 | n_head: Number of attention heads for each attention layer in 41 | the Transformer encoder. 42 | layer_norm_epsilon: epsilon to use in the layer norm layers 43 | resid_pdrop: The dropout probabilitiy for all fully connected 44 | layers in the embeddings, encoder, and pooler. 45 | attn_pdrop: The dropout ratio for the attention 46 | probabilities. 47 | embd_pdrop: The dropout ratio for the embeddings. 48 | initializer_range: The sttdev of the truncated_normal_initializer for 49 | initializing all weight matrices. 
50 | """ 51 | pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP 52 | 53 | def __init__( 54 | self, 55 | vocab_size=246534, 56 | n_positions=256, 57 | n_ctx=256, 58 | n_embd=1280, 59 | dff=8192, 60 | n_layer=48, 61 | n_head=16, 62 | resid_pdrop=0.1, 63 | embd_pdrop=0.1, 64 | attn_pdrop=0.1, 65 | layer_norm_epsilon=1e-6, 66 | initializer_range=0.02, 67 | summary_type='cls_index', 68 | summary_use_proj=True, 69 | summary_activation=None, 70 | summary_proj_to_labels=True, 71 | summary_first_dropout=0.1, 72 | **kwargs 73 | ): 74 | """Constructs CTRLConfig. 75 | 76 | Args: 77 | vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 78 | n_positions: Number of positional embeddings. 79 | n_ctx: Size of the causal mask (usually same as n_positions). 80 | dff: Size of the inner dimension of the FFN. 81 | n_embd: Dimensionality of the embeddings and hidden states. 82 | n_layer: Number of hidden layers in the Transformer encoder. 83 | n_head: Number of attention heads for each attention layer in 84 | the Transformer encoder. 85 | layer_norm_epsilon: epsilon to use in the layer norm layers 86 | resid_pdrop: The dropout probabilitiy for all fully connected 87 | layers in the embeddings, encoder, and pooler. 88 | attn_pdrop: The dropout ratio for the attention 89 | probabilities. 90 | embd_pdrop: The dropout ratio for the embeddings. 91 | initializer_range: The sttdev of the truncated_normal_initializer for 92 | initializing all weight matrices. 93 | """ 94 | super(CTRLConfig, self).__init__(**kwargs) 95 | self.vocab_size = vocab_size 96 | self.n_ctx = n_ctx 97 | self.n_positions = n_positions 98 | self.n_embd = n_embd 99 | self.n_layer = n_layer 100 | self.n_head = n_head 101 | self.dff = dff 102 | self.resid_pdrop = resid_pdrop 103 | self.embd_pdrop = embd_pdrop 104 | self.attn_pdrop = attn_pdrop 105 | self.layer_norm_epsilon = layer_norm_epsilon 106 | self.initializer_range = initializer_range 107 | 108 | self.summary_type = summary_type 109 | self.summary_use_proj = summary_use_proj 110 | self.summary_activation = summary_activation 111 | self.summary_first_dropout = summary_first_dropout 112 | self.summary_proj_to_labels = summary_proj_to_labels 113 | 114 | @property 115 | def max_position_embeddings(self): 116 | return self.n_positions 117 | 118 | @property 119 | def hidden_size(self): 120 | return self.n_embd 121 | 122 | @property 123 | def num_attention_heads(self): 124 | return self.n_head 125 | 126 | @property 127 | def num_hidden_layers(self): 128 | return self.n_layer 129 | -------------------------------------------------------------------------------- /transformers/configuration_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ DistilBERT model configuration """ 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json", 31 | 'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json", 32 | 'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json", 33 | } 34 | 35 | 36 | class DistilBertConfig(PretrainedConfig): 37 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 38 | 39 | def __init__(self, 40 | vocab_size=30522, 41 | max_position_embeddings=512, 42 | sinusoidal_pos_embds=False, 43 | n_layers=6, 44 | n_heads=12, 45 | dim=768, 46 | hidden_dim=4*768, 47 | dropout=0.1, 48 | attention_dropout=0.1, 49 | activation='gelu', 50 | initializer_range=0.02, 51 | tie_weights_=True, 52 | qa_dropout=0.1, 53 | seq_classif_dropout=0.2, 54 | **kwargs): 55 | super(DistilBertConfig, self).__init__(**kwargs) 56 | self.vocab_size = vocab_size 57 | self.max_position_embeddings = max_position_embeddings 58 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 59 | self.n_layers = n_layers 60 | self.n_heads = n_heads 61 | self.dim = dim 62 | self.hidden_dim = hidden_dim 63 | self.dropout = dropout 64 | self.attention_dropout = attention_dropout 65 | self.activation = activation 66 | self.initializer_range = initializer_range 67 | self.tie_weights_ = tie_weights_ 68 | self.qa_dropout = qa_dropout 69 | self.seq_classif_dropout = seq_classif_dropout 70 | 71 | @property 72 | def hidden_size(self): 73 | return self.dim 74 | 75 | @property 76 | def num_attention_heads(self): 77 | return self.n_heads 78 | 79 | @property 80 | def num_hidden_layers(self): 81 | return self.n_layers 82 | -------------------------------------------------------------------------------- /transformers/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ OpenAI GPT-2 configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", 32 | "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", 33 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",} 34 | 35 | class GPT2Config(PretrainedConfig): 36 | """Configuration class to store the configuration of a `GPT2Model`. 37 | 38 | Args: 39 | vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 40 | n_positions: Number of positional embeddings. 41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | layer_norm_epsilon: epsilon to use in the layer norm layers 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 52 | initializer_range: The sttdev of the truncated_normal_initializer for 53 | initializing all weight matrices. 54 | """ 55 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 56 | 57 | def __init__( 58 | self, 59 | vocab_size=50257, 60 | n_positions=1024, 61 | n_ctx=1024, 62 | n_embd=768, 63 | n_layer=12, 64 | n_head=12, 65 | resid_pdrop=0.1, 66 | embd_pdrop=0.1, 67 | attn_pdrop=0.1, 68 | layer_norm_epsilon=1e-5, 69 | initializer_range=0.02, 70 | summary_type='cls_index', 71 | summary_use_proj=True, 72 | summary_activation=None, 73 | summary_proj_to_labels=True, 74 | summary_first_dropout=0.1, 75 | **kwargs 76 | ): 77 | """Constructs GPT2Config. 78 | 79 | Args: 80 | vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 81 | n_positions: Number of positional embeddings. 82 | n_ctx: Size of the causal mask (usually same as n_positions). 83 | n_embd: Dimensionality of the embeddings and hidden states. 84 | n_layer: Number of hidden layers in the Transformer encoder. 85 | n_head: Number of attention heads for each attention layer in 86 | the Transformer encoder. 87 | layer_norm_epsilon: epsilon to use in the layer norm layers 88 | resid_pdrop: The dropout probabilitiy for all fully connected 89 | layers in the embeddings, encoder, and pooler. 90 | attn_pdrop: The dropout ratio for the attention 91 | probabilities. 92 | embd_pdrop: The dropout ratio for the embeddings. 93 | initializer_range: The sttdev of the truncated_normal_initializer for 94 | initializing all weight matrices. 
95 | """ 96 | super(GPT2Config, self).__init__(**kwargs) 97 | self.vocab_size = vocab_size 98 | self.n_ctx = n_ctx 99 | self.n_positions = n_positions 100 | self.n_embd = n_embd 101 | self.n_layer = n_layer 102 | self.n_head = n_head 103 | self.resid_pdrop = resid_pdrop 104 | self.embd_pdrop = embd_pdrop 105 | self.attn_pdrop = attn_pdrop 106 | self.layer_norm_epsilon = layer_norm_epsilon 107 | self.initializer_range = initializer_range 108 | self.summary_type = summary_type 109 | self.summary_use_proj = summary_use_proj 110 | self.summary_activation = summary_activation 111 | self.summary_first_dropout = summary_first_dropout 112 | self.summary_proj_to_labels = summary_proj_to_labels 113 | 114 | @property 115 | def max_position_embeddings(self): 116 | return self.n_positions 117 | 118 | @property 119 | def hidden_size(self): 120 | return self.n_embd 121 | 122 | @property 123 | def num_attention_heads(self): 124 | return self.n_head 125 | 126 | @property 127 | def num_hidden_layers(self): 128 | return self.n_layer 129 | -------------------------------------------------------------------------------- /transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 31 | } 32 | 33 | class OpenAIGPTConfig(PretrainedConfig): 34 | """ 35 | Configuration class to store the configuration of a `OpenAIGPTModel`. 36 | 37 | Args: 38 | vocab_size: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. 39 | n_positions: Number of positional embeddings. 40 | n_ctx: Size of the causal mask (usually same as n_positions). 41 | n_embd: Dimensionality of the embeddings and hidden states. 42 | n_layer: Number of hidden layers in the Transformer encoder. 43 | n_head: Number of attention heads for each attention layer in 44 | the Transformer encoder. 45 | afn: The non-linear activation function (function or string) in the 46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 
52 | layer_norm_epsilon: epsilon to use in the layer norm layers 53 | initializer_range: The sttdev of the truncated_normal_initializer for 54 | initializing all weight matrices. 55 | predict_special_tokens: should we predict special tokens (when the model has a LM head) 56 | """ 57 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 58 | 59 | def __init__( 60 | self, 61 | vocab_size=40478, 62 | n_positions=512, 63 | n_ctx=512, 64 | n_embd=768, 65 | n_layer=12, 66 | n_head=12, 67 | afn="gelu", 68 | resid_pdrop=0.1, 69 | embd_pdrop=0.1, 70 | attn_pdrop=0.1, 71 | layer_norm_epsilon=1e-5, 72 | initializer_range=0.02, 73 | predict_special_tokens=True, 74 | summary_type='cls_index', 75 | summary_use_proj=True, 76 | summary_activation=None, 77 | summary_proj_to_labels=True, 78 | summary_first_dropout=0.1, 79 | **kwargs 80 | ): 81 | """Constructs OpenAIGPTConfig. 82 | """ 83 | super(OpenAIGPTConfig, self).__init__(**kwargs) 84 | self.vocab_size = vocab_size 85 | self.n_ctx = n_ctx 86 | self.n_positions = n_positions 87 | self.n_embd = n_embd 88 | self.n_layer = n_layer 89 | self.n_head = n_head 90 | self.afn = afn 91 | self.resid_pdrop = resid_pdrop 92 | self.embd_pdrop = embd_pdrop 93 | self.attn_pdrop = attn_pdrop 94 | self.layer_norm_epsilon = layer_norm_epsilon 95 | self.initializer_range = initializer_range 96 | self.predict_special_tokens = predict_special_tokens 97 | self.summary_type = summary_type 98 | self.summary_use_proj = summary_use_proj 99 | self.summary_activation = summary_activation 100 | self.summary_first_dropout = summary_first_dropout 101 | self.summary_proj_to_labels = summary_proj_to_labels 102 | 103 | @property 104 | def max_position_embeddings(self): 105 | return self.n_positions 106 | 107 | @property 108 | def hidden_size(self): 109 | return self.n_embd 110 | 111 | @property 112 | def num_attention_heads(self): 113 | return self.n_head 114 | 115 | @property 116 | def num_hidden_layers(self): 117 | return self.n_layer 118 | -------------------------------------------------------------------------------- /transformers/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
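# Usage sketch (illustrative only): RobertaConfig, defined below, is BertConfig plus a RoBERTa
# archive map, so it is constructed with BERT-style keyword arguments. The values shown are
# the ones commonly used for roberta-base and are assumptions here, not defaults of this class.
#
#     from transformers.configuration_roberta import RobertaConfig
#
#     config = RobertaConfig(vocab_size=50265, max_position_embeddings=514)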
16 | """ RoBERTa configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_bert import BertConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 31 | 'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json", 32 | 'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json", 33 | 'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json", 34 | } 35 | 36 | 37 | class RobertaConfig(BertConfig): 38 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 39 | -------------------------------------------------------------------------------- /transformers/configuration_t5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2010, The T5 Authors and HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ T5 model configuration """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | import six 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | T5_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json", 31 | 't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json", 32 | 't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json", 33 | 't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json", 34 | 't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json", 35 | } 36 | 37 | 38 | class T5Config(PretrainedConfig): 39 | r""" 40 | :class:`~transformers.T5Config` is the configuration class to store the configuration of a 41 | `T5Model`. 42 | 43 | 44 | Arguments: 45 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `T5Model`. 46 | hidden_size: Size of the encoder layers and the pooler layer. 47 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 48 | num_attention_heads: Number of attention heads for each attention layer in 49 | the Transformer encoder. 50 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 51 | layer in the Transformer encoder. 
52 | hidden_act: The non-linear activation function (function or string) in the 53 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 54 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 55 | layers in the embeddings, encoder, and pooler. 56 | attention_probs_dropout_prob: The dropout ratio for the attention 57 | probabilities. 58 | max_position_embeddings: The maximum sequence length that this model might 59 | ever be used with. Typically set this to something large just in case 60 | (e.g., 512 or 1024 or 2048). 61 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 62 | `T5Model`. 63 | initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0, used for initialization testing). 64 | layer_norm_eps: The epsilon used by LayerNorm. 65 | """ 66 | pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP 67 | 68 | def __init__(self, 69 | vocab_size=32128, 70 | n_positions=512, 71 | d_model=512, 72 | d_kv=64, 73 | d_ff=2048, 74 | num_layers=6, 75 | num_heads=8, 76 | relative_attention_num_buckets=32, 77 | dropout_rate=0.1, 78 | layer_norm_epsilon=1e-6, 79 | initializer_factor=1.0, 80 | **kwargs): 81 | super(T5Config, self).__init__(**kwargs) 82 | self.vocab_size = vocab_size 83 | self.n_positions = n_positions 84 | self.d_model = d_model 85 | self.d_kv = d_kv 86 | self.d_ff = d_ff 87 | self.num_layers = num_layers 88 | self.num_heads = num_heads 89 | self.relative_attention_num_buckets = relative_attention_num_buckets 90 | self.dropout_rate = dropout_rate 91 | self.layer_norm_epsilon = layer_norm_epsilon 92 | self.initializer_factor = initializer_factor 93 | 94 | @property 95 | def max_position_embeddings(self): 96 | return self.n_positions 97 | 98 | @property 99 | def hidden_size(self): 100 | return self.d_model 101 | 102 | @property 103 | def num_attention_heads(self): 104 | return self.num_heads 105 | 106 | @property 107 | def num_hidden_layers(self): 108 | return self.num_layers 109 | -------------------------------------------------------------------------------- /transformers/convert_albert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert ALBERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = AlbertConfig.from_json_file(albert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = AlbertForMaskedLM(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path to the TensorFlow checkpoint path.") 52 | parser.add_argument("--albert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained ALBERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.albert_config_file, 66 | args.pytorch_dump_path) 67 | 68 | ''' 69 | python transformers/convert_albert_original_tf_checkpoint_to_pytorch.py \ 70 | --tf_checkpoint_path=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base\ 71 | --albert_config_file=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base\ 72 | --pytorch_dump_path=/home/zhan/zyy/Github/Transformers_for_Text_Classification/pretrained_models/ALBERT/google/albert_base/pytorch_model.bin 73 | ''' -------------------------------------------------------------------------------- /transformers/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /transformers/convert_bert_pytorch_checkpoint_to_original_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from transformers import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 
69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /transformers/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /transformers/convert_t5_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The T5 authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert T5 checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import T5Config, T5Model, load_tf_weights_in_t5 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = T5Config.from_json_file(config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = T5Model(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_t5(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained T5 model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import transformers.tokenization_transfo_xl as data_utils 27 | 28 | from transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from transformers import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to 
{}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from transformers import CONFIG_NAME, WEIGHTS_NAME 27 | from transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | state_dict = chkpt['model'] 37 | 38 | # We have the base model one level deeper than the original XLM repository 39 | two_levels_state_dict = {} 40 | for k, v in state_dict.items(): 41 | if 'pred_layer' in k: 42 | two_levels_state_dict[k] = v 43 | else: 44 | two_levels_state_dict['transformer.' 
+ k] = v 45 | 46 | config = chkpt['params'] 47 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 48 | 49 | vocab = chkpt['dico_word2id'] 50 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 51 | 52 | # Save pytorch-model 53 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 54 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 56 | 57 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 58 | torch.save(two_levels_state_dict, pytorch_weights_dump_path) 59 | 60 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 61 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 62 | f.write(json.dumps(config, indent=2) + "\n") 63 | 64 | print("Save vocab file to {}".format(pytorch_vocab_dump_path)) 65 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 66 | f.write(json.dumps(vocab, indent=2) + "\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | ## Required parameters 72 | parser.add_argument("--xlm_checkpoint_path", 73 | default = None, 74 | type = str, 75 | required = True, 76 | help = "Path to the official PyTorch dump.") 77 | parser.add_argument("--pytorch_dump_folder_path", 78 | default = None, 79 | type = str, 80 | required = True, 81 | help = "Path to the output PyTorch model.") 82 | args = parser.parse_args() 83 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 84 | -------------------------------------------------------------------------------- /transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
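A minimal sketch of driving the XLM conversion above from Python rather than the command line; the checkpoint name and output folder are placeholders, and the folder has to exist beforehand because the function writes its output files straight into it.

import os
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch

xlm_dump = "/tmp/xlm/mlm_tlm_xnli15_1024.pth"  # placeholder for an original XLM checkpoint file
out_dir = "/tmp/xlm_pytorch"
os.makedirs(out_dir, exist_ok=True)            # the converter joins paths with '/' but does not create the folder
convert_xlm_checkpoint_to_pytorch(xlm_dump, out_dir)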
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. 
\n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures 2 | from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor 4 | from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels 5 | 6 | from .metrics import is_sklearn_available 7 | if is_sklearn_available(): 8 | from .metrics import glue_compute_metrics, xnli_compute_metrics 9 | -------------------------------------------------------------------------------- /transformers/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | try: 24 | from scipy.stats import pearsonr, spearmanr 25 | from sklearn.metrics import matthews_corrcoef, f1_score 26 | _has_sklearn = True 27 | except (AttributeError, ImportError) as e: 28 | logger.warning("To use data.metrics please install scikit-learn. 
See https://scikit-learn.org/stable/index.html") 29 | _has_sklearn = False 30 | 31 | def is_sklearn_available(): 32 | return _has_sklearn 33 | 34 | if _has_sklearn: 35 | 36 | def simple_accuracy(preds, labels): 37 | return (preds == labels).mean() 38 | 39 | 40 | def acc_and_f1(preds, labels): 41 | acc = simple_accuracy(preds, labels) 42 | f1 = f1_score(y_true=labels, y_pred=preds) 43 | return { 44 | "acc": acc, 45 | "f1": f1, 46 | "acc_and_f1": (acc + f1) / 2, 47 | } 48 | 49 | 50 | def pearson_and_spearman(preds, labels): 51 | pearson_corr = pearsonr(preds, labels)[0] 52 | spearman_corr = spearmanr(preds, labels)[0] 53 | return { 54 | "pearson": pearson_corr, 55 | "spearmanr": spearman_corr, 56 | "corr": (pearson_corr + spearman_corr) / 2, 57 | } 58 | 59 | 60 | def glue_compute_metrics(task_name, preds, labels): 61 | assert len(preds) == len(labels) 62 | if task_name == "cola": 63 | return {"mcc": matthews_corrcoef(labels, preds)} 64 | elif task_name == "sst-2": 65 | return {"acc": simple_accuracy(preds, labels)} 66 | elif task_name == "mrpc": 67 | return acc_and_f1(preds, labels) 68 | elif task_name == "sts-b": 69 | return pearson_and_spearman(preds, labels) 70 | elif task_name == "qqp": 71 | return acc_and_f1(preds, labels) 72 | elif task_name == "mnli": 73 | return {"acc": simple_accuracy(preds, labels)} 74 | elif task_name == "mnli-mm": 75 | return {"acc": simple_accuracy(preds, labels)} 76 | elif task_name == "qnli": 77 | return {"acc": simple_accuracy(preds, labels)} 78 | elif task_name == "rte": 79 | return {"acc": simple_accuracy(preds, labels)} 80 | elif task_name == "wnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | else: 83 | raise KeyError(task_name) 84 | 85 | 86 | def xnli_compute_metrics(task_name, preds, labels): 87 | assert len(preds) == len(labels) 88 | if task_name == "xnli": 89 | return {"acc": simple_accuracy(preds, labels)} 90 | else: 91 | raise KeyError(task_name) 92 | -------------------------------------------------------------------------------- /transformers/data/metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import InputExample, InputFeatures, DataProcessor 2 | from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor 4 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/glue.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/glue.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/glue.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/glue.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/squad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/squad.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/squad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/squad.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/xnli.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/xnli.cpython-36.pyc -------------------------------------------------------------------------------- /transformers/data/processors/__pycache__/xnli.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/data/processors/__pycache__/xnli.cpython-37.pyc -------------------------------------------------------------------------------- /transformers/data/processors/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import copy 20 | import json 21 | 22 | class InputExample(object): 23 | """ 24 | A single training/test example for simple sequence classification. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 34 | """ 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | self.guid = guid 37 | self.text_a = text_a 38 | self.text_b = text_b 39 | self.label = label 40 | 41 | def __repr__(self): 42 | return str(self.to_json_string()) 43 | 44 | def to_dict(self): 45 | """Serializes this instance to a Python dictionary.""" 46 | output = copy.deepcopy(self.__dict__) 47 | return output 48 | 49 | def to_json_string(self): 50 | """Serializes this instance to a JSON string.""" 51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 52 | 53 | 54 | class InputFeatures(object): 55 | """ 56 | A single set of features of data. 57 | 58 | Args: 59 | input_ids: Indices of input sequence tokens in the vocabulary. 60 | attention_mask: Mask to avoid performing attention on padding token indices. 61 | Mask values selected in ``[0, 1]``: 62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
64 | label: Label corresponding to the input 65 | """ 66 | 67 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 68 | self.input_ids = input_ids 69 | self.attention_mask = attention_mask 70 | self.token_type_ids = token_type_ids 71 | self.label = label 72 | 73 | def __repr__(self): 74 | return str(self.to_json_string()) 75 | 76 | def to_dict(self): 77 | """Serializes this instance to a Python dictionary.""" 78 | output = copy.deepcopy(self.__dict__) 79 | return output 80 | 81 | def to_json_string(self): 82 | """Serializes this instance to a JSON string.""" 83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 84 | 85 | 86 | class DataProcessor(object): 87 | """Base class for data converters for sequence classification data sets.""" 88 | 89 | def get_example_from_tensor_dict(self, tensor_dict): 90 | """Gets an example from a dict with tensorflow tensors 91 | 92 | Args: 93 | tensor_dict: Keys and values should match the corresponding Glue 94 | tensorflow_dataset examples. 95 | """ 96 | raise NotImplementedError() 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_labels(self): 107 | """Gets the list of labels for this data set.""" 108 | raise NotImplementedError() 109 | 110 | def tfds_map(self, example): 111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 112 | This method converts examples to the correct format.""" 113 | if len(self.get_labels()) > 1: 114 | example.label = self.get_labels()[int(example.label)] 115 | return example 116 | 117 | @classmethod 118 | def _read_tsv(cls, input_file, quotechar=None): 119 | """Reads a tab separated value file.""" 120 | with open(input_file, "r", encoding="utf-8-sig") as f: 121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 122 | lines = [] 123 | for line in reader: 124 | if sys.version_info[0] == 2: 125 | line = list(unicode(cell, 'utf-8') for cell in line) 126 | lines.append(line) 127 | return lines 128 | -------------------------------------------------------------------------------- /transformers/data/processors/xnli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
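The InputExample and InputFeatures containers above are plain value objects; a short sketch (all field values are made up) of how an example is built and inspected:

from transformers.data.processors.utils import InputExample

example = InputExample(guid="train-1",
                       text_a="The cat sat on the mat.",
                       text_b="A cat is resting.",
                       label="entailment")
print(example)              # __repr__ delegates to to_json_string(), so this prints indented JSON
record = example.to_dict()  # plain dict copy, useful if you serialise examples yourself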
16 | """ XNLI utils (dataset loading and evaluation) """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import logging 21 | import os 22 | 23 | from .utils import DataProcessor, InputExample 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | class XnliProcessor(DataProcessor): 28 | """Processor for the XNLI dataset. 29 | Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207""" 30 | 31 | def __init__(self, language, train_language = None): 32 | self.language = language 33 | self.train_language = train_language 34 | 35 | def get_train_examples(self, data_dir): 36 | """See base class.""" 37 | lg = self.language if self.train_language is None else self.train_language 38 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg))) 39 | examples = [] 40 | for (i, line) in enumerate(lines): 41 | if i == 0: 42 | continue 43 | guid = "%s-%s" % ('train', i) 44 | text_a = line[0] 45 | text_b = line[1] 46 | label = "contradiction" if line[2] == "contradictory" else line[2] 47 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 48 | examples.append( 49 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 50 | return examples 51 | 52 | def get_test_examples(self, data_dir): 53 | """See base class.""" 54 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv")) 55 | examples = [] 56 | for (i, line) in enumerate(lines): 57 | if i == 0: 58 | continue 59 | language = line[0] 60 | if language != self.language: 61 | continue 62 | guid = "%s-%s" % ('test', i) 63 | text_a = line[6] 64 | text_b = line[7] 65 | label = line[1] 66 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) 67 | examples.append( 68 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 69 | return examples 70 | 71 | def get_labels(self): 72 | """See base class.""" 73 | return ["contradiction", "entailment", "neutral"] 74 | 75 | xnli_processors = { 76 | "xnli": XnliProcessor, 77 | } 78 | 79 | xnli_output_modes = { 80 | "xnli": "classification", 81 | } 82 | 83 | xnli_tasks_num_labels = { 84 | "xnli": 3, 85 | } 86 | -------------------------------------------------------------------------------- /transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/tests/__init__.py -------------------------------------------------------------------------------- /transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import json 21 | import tempfile 22 | 23 | import unittest 24 | from .tokenization_tests_commons import TemporaryDirectory 25 | 26 | 27 | class ConfigTester(object): 28 | def __init__(self, parent, config_class=None, **kwargs): 29 | self.parent = parent 30 | self.config_class = config_class 31 | self.inputs_dict = kwargs 32 | 33 | def create_and_test_config_common_properties(self): 34 | config = self.config_class(**self.inputs_dict) 35 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 36 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 37 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 38 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 39 | 40 | def create_and_test_config_to_json_string(self): 41 | config = self.config_class(**self.inputs_dict) 42 | obj = json.loads(config.to_json_string()) 43 | for key, value in self.inputs_dict.items(): 44 | self.parent.assertEqual(obj[key], value) 45 | 46 | def create_and_test_config_to_json_file(self): 47 | config_first = self.config_class(**self.inputs_dict) 48 | 49 | with TemporaryDirectory() as tmpdirname: 50 | json_file_path = os.path.join(tmpdirname, "config.json") 51 | config_first.to_json_file(json_file_path) 52 | config_second = self.config_class.from_json_file(json_file_path) 53 | 54 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 55 | 56 | def create_and_test_config_from_and_save_pretrained(self): 57 | config_first = self.config_class(**self.inputs_dict) 58 | 59 | with TemporaryDirectory() as tmpdirname: 60 | config_first.save_pretrained(tmpdirname) 61 | config_second = self.config_class.from_pretrained(tmpdirname) 62 | 63 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 64 | 65 | def run_common_tests(self): 66 | self.create_and_test_config_common_properties() 67 | self.create_and_test_config_to_json_string() 68 | self.create_and_test_config_to_json_file() 69 | self.create_and_test_config_from_and_save_pretrained() 70 | 71 | if __name__ == "__main__": 72 | unittest.main() -------------------------------------------------------------------------------- /transformers/tests/fixtures/empty.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/tests/fixtures/empty.txt -------------------------------------------------------------------------------- /transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 
4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/tests/fixtures/spiece.model -------------------------------------------------------------------------------- /transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhanlaoban/Transformers_for_Text_Classification/5e12b21616b29e445e11fe307948e5c55084bb0e/transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /transformers/tests/hf_api_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function 16 | 17 | import os 18 | import time 19 | import unittest 20 | 21 | import requests 22 | import six 23 | 24 | from transformers.hf_api import HfApi, HfFolder, HTTPError, PresignedUrl, S3Obj 25 | 26 | USER = "__DUMMY_TRANSFORMERS_USER__" 27 | PASS = "__DUMMY_TRANSFORMERS_PASS__" 28 | FILES = [ 29 | ( 30 | "Test-{}.txt".format(int(time.time())), 31 | os.path.join( 32 | os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt" 33 | ) 34 | ), 35 | ( 36 | "yoyo {}.txt".format(int(time.time())), # space is intentional 37 | os.path.join( 38 | os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt" 39 | ) 40 | ), 41 | ] 42 | 43 | 44 | 45 | class HfApiCommonTest(unittest.TestCase): 46 | _api = HfApi(endpoint="https://moon-staging.huggingface.co") 47 | 48 | 49 | class HfApiLoginTest(HfApiCommonTest): 50 | def test_login_invalid(self): 51 | with self.assertRaises(HTTPError): 52 | self._api.login(username=USER, password="fake") 53 | 54 | def test_login_valid(self): 55 | token = self._api.login(username=USER, password=PASS) 56 | self.assertIsInstance(token, six.string_types) 57 | 58 | 59 | class HfApiEndpointsTest(HfApiCommonTest): 60 | @classmethod 61 | def setUpClass(cls): 62 | """ 63 | Share this valid token in all tests below. 64 | """ 65 | cls._token = cls._api.login(username=USER, password=PASS) 66 | 67 | def test_whoami(self): 68 | user = self._api.whoami(token=self._token) 69 | self.assertEqual(user, USER) 70 | 71 | def test_presign(self): 72 | for FILE_KEY, FILE_PATH in FILES: 73 | urls = self._api.presign(token=self._token, filename=FILE_KEY) 74 | self.assertIsInstance(urls, PresignedUrl) 75 | self.assertEqual(urls.type, "text/plain") 76 | 77 | def test_presign_and_upload(self): 78 | for FILE_KEY, FILE_PATH in FILES: 79 | access_url = self._api.presign_and_upload( 80 | token=self._token, filename=FILE_KEY, filepath=FILE_PATH 81 | ) 82 | self.assertIsInstance(access_url, six.string_types) 83 | with open(FILE_PATH, 'r') as f: 84 | body = f.read() 85 | r = requests.get(access_url) 86 | self.assertEqual(r.text, body) 87 | 88 | def test_list_objs(self): 89 | objs = self._api.list_objs(token=self._token) 90 | self.assertIsInstance(objs, list) 91 | if len(objs) > 0: 92 | o = objs[-1] 93 | self.assertIsInstance(o, S3Obj) 94 | 95 | 96 | 97 | class HfFolderTest(unittest.TestCase): 98 | def test_token_workflow(self): 99 | """ 100 | Test the whole token save/get/delete workflow, 101 | with the desired behavior with respect to non-existent tokens. 102 | """ 103 | token = "token-{}".format(int(time.time())) 104 | HfFolder.save_token(token) 105 | self.assertEqual( 106 | HfFolder.get_token(), 107 | token 108 | ) 109 | HfFolder.delete_token() 110 | HfFolder.delete_token() 111 | # ^^ not an error, we test that the 112 | # second call does not fail. 113 | self.assertEqual( 114 | HfFolder.get_token(), 115 | None 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /transformers/tests/model_card_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
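The token round trip exercised by test_token_workflow above doubles as a usage sketch; the token string here is a placeholder, and HfFolder simply persists it to a local file.

from transformers.hf_api import HfFolder

HfFolder.save_token("my-api-token")            # placeholder value
assert HfFolder.get_token() == "my-api-token"
HfFolder.delete_token()                        # deleting a missing token later is a no-op, as the test checks
assert HfFolder.get_token() is None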
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | 21 | from transformers.model_card import ModelCard 22 | from .tokenization_tests_commons import TemporaryDirectory 23 | 24 | class ModelCardTester(unittest.TestCase): 25 | 26 | def setUp(self): 27 | self.inputs_dict = {'model_details': { 28 | 'Organization': 'testing', 29 | 'Model date': 'today', 30 | 'Model version': 'v2.1, Developed by Test Corp in 2019.', 31 | 'Architecture': 'Convolutional Neural Network.', 32 | }, 33 | 'metrics': 'BLEU and ROUGE-1', 34 | 'evaluation_data':{ 35 | 'Datasets':{ 36 | 'BLEU': 'My-great-dataset-v1', 37 | 'ROUGE-1': 'My-short-dataset-v2.1', 38 | }, 39 | 'Preprocessing': 'See details on https://arxiv.org/pdf/1810.03993.pdf' 40 | }, 41 | 'training_data':{ 42 | 'Dataset': 'English Wikipedia dump dated 2018-12-01', 43 | 'Preprocessing': 'Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf' 44 | }, 45 | 'quantitative_analyses': { 46 | 'BLEU': 55.1, 47 | 'ROUGE-1': 76, 48 | }, 49 | } 50 | 51 | def test_model_card_common_properties(self): 52 | model_card = ModelCard.from_dict(self.inputs_dict) 53 | self.assertTrue(hasattr(model_card, 'model_details')) 54 | self.assertTrue(hasattr(model_card, 'intended_use')) 55 | self.assertTrue(hasattr(model_card, 'factors')) 56 | self.assertTrue(hasattr(model_card, 'metrics')) 57 | self.assertTrue(hasattr(model_card, 'evaluation_data')) 58 | self.assertTrue(hasattr(model_card, 'training_data')) 59 | self.assertTrue(hasattr(model_card, 'quantitative_analyses')) 60 | self.assertTrue(hasattr(model_card, 'ethical_considerations')) 61 | self.assertTrue(hasattr(model_card, 'caveats_and_recommendations')) 62 | 63 | def test_model_card_to_json_string(self): 64 | model_card = ModelCard.from_dict(self.inputs_dict) 65 | obj = json.loads(model_card.to_json_string()) 66 | for key, value in self.inputs_dict.items(): 67 | self.assertEqual(obj[key], value) 68 | 69 | def test_model_card_to_json_file(self): 70 | model_card_first = ModelCard.from_dict(self.inputs_dict) 71 | 72 | with TemporaryDirectory() as tmpdirname: 73 | filename = os.path.join(tmpdirname, u"model_card.json") 74 | model_card_first.to_json_file(filename) 75 | model_card_second = ModelCard.from_json_file(filename) 76 | 77 | self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) 78 | 79 | def test_model_card_from_and_save_pretrained(self): 80 | model_card_first = ModelCard.from_dict(self.inputs_dict) 81 | 82 | with TemporaryDirectory() as tmpdirname: 83 | model_card_first.save_pretrained(tmpdirname) 84 | model_card_second = ModelCard.from_pretrained(tmpdirname) 85 | 86 | self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) 87 | 88 | if __name__ == "__main__": 89 | unittest.main() 90 | -------------------------------------------------------------------------------- /transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 
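A condensed sketch of the ModelCard round trip that the tests above assert, using made-up card content and a throwaway directory:

import tempfile
from transformers.model_card import ModelCard

card = ModelCard.from_dict({"model_details": {"Organization": "testing"}, "metrics": "accuracy"})
print(card.to_json_string())                  # JSON view of the card
tmpdir = tempfile.mkdtemp()
card.save_pretrained(tmpdir)                  # write the card to disk and ...
reloaded = ModelCard.from_pretrained(tmpdir)  # ... load it back unchanged
assert reloaded.to_dict() == card.to_dict()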
2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import is_torch_available 24 | 25 | from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER 26 | 27 | if is_torch_available(): 28 | from transformers import (AutoConfig, BertConfig, 29 | AutoModel, BertModel, 30 | AutoModelWithLMHead, BertForMaskedLM, 31 | AutoModelForSequenceClassification, BertForSequenceClassification, 32 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 33 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 34 | 35 | from .modeling_common_test import (CommonTestCases, ids_tensor) 36 | from .configuration_common_test import ConfigTester 37 | 38 | 39 | @require_torch 40 | class AutoModelTest(unittest.TestCase): 41 | @slow 42 | def test_model_from_pretrained(self): 43 | logging.basicConfig(level=logging.INFO) 44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 45 | config = AutoConfig.from_pretrained(model_name) 46 | self.assertIsNotNone(config) 47 | self.assertIsInstance(config, BertConfig) 48 | 49 | model = AutoModel.from_pretrained(model_name) 50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 51 | self.assertIsNotNone(model) 52 | self.assertIsInstance(model, BertModel) 53 | for value in loading_info.values(): 54 | self.assertEqual(len(value), 0) 55 | 56 | @slow 57 | def test_lmhead_model_from_pretrained(self): 58 | logging.basicConfig(level=logging.INFO) 59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 60 | config = AutoConfig.from_pretrained(model_name) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = AutoModelWithLMHead.from_pretrained(model_name) 65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 66 | self.assertIsNotNone(model) 67 | self.assertIsInstance(model, BertForMaskedLM) 68 | 69 | @slow 70 | def test_sequence_classification_model_from_pretrained(self): 71 | logging.basicConfig(level=logging.INFO) 72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 73 | config = AutoConfig.from_pretrained(model_name) 74 | self.assertIsNotNone(config) 75 | self.assertIsInstance(config, BertConfig) 76 | 77 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 79 | self.assertIsNotNone(model) 80 | self.assertIsInstance(model, BertForSequenceClassification) 81 | 82 | @slow 83 | def test_question_answering_model_from_pretrained(self): 84 | logging.basicConfig(level=logging.INFO) 85 | for model_name in 
list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 86 | config = AutoConfig.from_pretrained(model_name) 87 | self.assertIsNotNone(config) 88 | self.assertIsInstance(config, BertConfig) 89 | 90 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 92 | self.assertIsNotNone(model) 93 | self.assertIsInstance(model, BertForQuestionAnswering) 94 | 95 | def test_from_pretrained_identifier(self): 96 | logging.basicConfig(level=logging.INFO) 97 | model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) 98 | self.assertIsInstance(model, BertForMaskedLM) 99 | 100 | 101 | if __name__ == "__main__": 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /transformers/tests/modeling_encoder_decoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Hugging Face Inc. Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | import unittest 18 | 19 | from transformers import is_torch_available 20 | from .utils import require_torch, slow 21 | 22 | if is_torch_available(): 23 | from transformers import BertModel, BertForMaskedLM, Model2Model 24 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 25 | 26 | 27 | @require_torch 28 | class EncoderDecoderModelTest(unittest.TestCase): 29 | @slow 30 | def test_model2model_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | model = Model2Model.from_pretrained(model_name) 34 | self.assertIsInstance(model.encoder, BertModel) 35 | self.assertIsInstance(model.decoder, BertForMaskedLM) 36 | self.assertEqual(model.decoder.config.is_decoder, True) 37 | self.assertEqual(model.encoder.config.is_decoder, False) 38 | 39 | def test_model2model_from_pretrained_not_bert(self): 40 | logging.basicConfig(level=logging.INFO) 41 | with self.assertRaises(ValueError): 42 | _ = Model2Model.from_pretrained('roberta') 43 | 44 | with self.assertRaises(ValueError): 45 | _ = Model2Model.from_pretrained('distilbert') 46 | 47 | with self.assertRaises(ValueError): 48 | _ = Model2Model.from_pretrained('does-not-exist') 49 | 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
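The encoder-decoder test above also reads as a usage sketch for Model2Model; note that it downloads the bert-base-uncased weights for both sides, and that non-BERT names are rejected with a ValueError, exactly as the test asserts.

from transformers import Model2Model

model = Model2Model.from_pretrained("bert-base-uncased")
assert model.decoder.config.is_decoder        # the decoder copy is flagged as a decoder
assert not model.encoder.config.is_decoder    # while the encoder keeps the default setting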
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import is_tf_available 24 | 25 | from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER 26 | 27 | if is_tf_available(): 28 | from transformers import (AutoConfig, BertConfig, 29 | TFAutoModel, TFBertModel, 30 | TFAutoModelWithLMHead, TFBertForMaskedLM, 31 | TFAutoModelForSequenceClassification, TFBertForSequenceClassification, 32 | TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering) 33 | from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP 34 | 35 | from .modeling_common_test import (CommonTestCases, ids_tensor) 36 | from .configuration_common_test import ConfigTester 37 | 38 | 39 | @require_tf 40 | class TFAutoModelTest(unittest.TestCase): 41 | @slow 42 | def test_model_from_pretrained(self): 43 | import h5py 44 | self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) 45 | 46 | logging.basicConfig(level=logging.INFO) 47 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 48 | for model_name in ['bert-base-uncased']: 49 | config = AutoConfig.from_pretrained(model_name, force_download=True) 50 | self.assertIsNotNone(config) 51 | self.assertIsInstance(config, BertConfig) 52 | 53 | model = TFAutoModel.from_pretrained(model_name, force_download=True) 54 | self.assertIsNotNone(model) 55 | self.assertIsInstance(model, TFBertModel) 56 | 57 | @slow 58 | def test_lmhead_model_from_pretrained(self): 59 | logging.basicConfig(level=logging.INFO) 60 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 61 | for model_name in ['bert-base-uncased']: 62 | config = AutoConfig.from_pretrained(model_name, force_download=True) 63 | self.assertIsNotNone(config) 64 | self.assertIsInstance(config, BertConfig) 65 | 66 | model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) 67 | self.assertIsNotNone(model) 68 | self.assertIsInstance(model, TFBertForMaskedLM) 69 | 70 | @slow 71 | def test_sequence_classification_model_from_pretrained(self): 72 | logging.basicConfig(level=logging.INFO) 73 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 74 | for model_name in ['bert-base-uncased']: 75 | config = AutoConfig.from_pretrained(model_name, force_download=True) 76 | self.assertIsNotNone(config) 77 | self.assertIsInstance(config, BertConfig) 78 | 79 | model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) 80 | self.assertIsNotNone(model) 81 | self.assertIsInstance(model, TFBertForSequenceClassification) 82 | 83 | @slow 84 | def test_question_answering_model_from_pretrained(self): 85 | logging.basicConfig(level=logging.INFO) 86 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 87 | for model_name in ['bert-base-uncased']: 88 | config = AutoConfig.from_pretrained(model_name, force_download=True) 89 | self.assertIsNotNone(config) 90 | self.assertIsInstance(config, 
BertConfig) 91 | 92 | model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) 93 | self.assertIsNotNone(model) 94 | self.assertIsInstance(model, TFBertForQuestionAnswering) 95 | 96 | def test_from_pretrained_identifier(self): 97 | logging.basicConfig(level=logging.INFO) 98 | model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True) 99 | self.assertIsInstance(model, TFBertForMaskedLM) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /transformers/tests/optimization_tf_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import unittest 6 | 7 | from transformers import is_tf_available 8 | 9 | from .utils import require_tf 10 | 11 | if is_tf_available(): 12 | import tensorflow as tf 13 | from tensorflow.python.eager import context 14 | from tensorflow.python.framework import ops 15 | from transformers import (create_optimizer, GradientAccumulator) 16 | 17 | 18 | @require_tf 19 | class OptimizationFTest(unittest.TestCase): 20 | def assertListAlmostEqual(self, list1, list2, tol): 21 | self.assertEqual(len(list1), len(list2)) 22 | for a, b in zip(list1, list2): 23 | self.assertAlmostEqual(a, b, delta=tol) 24 | 25 | def testGradientAccumulator(self): 26 | accumulator = GradientAccumulator() 27 | accumulator([tf.constant([1.0, 2.0])]) 28 | accumulator([tf.constant([-2.0, 1.0])]) 29 | accumulator([tf.constant([-1.0, 2.0])]) 30 | with self.assertRaises(ValueError): 31 | accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])]) 32 | self.assertEqual(accumulator.step, 3) 33 | self.assertEqual(len(accumulator.gradients), 1) 34 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2) 35 | accumulator.reset() 36 | self.assertEqual(accumulator.step, 0) 37 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2) 38 | 39 | def testGradientAccumulatorDistributionStrategy(self): 40 | context._context = None 41 | ops.enable_eager_execution_internal() 42 | physical_devices = tf.config.experimental.list_physical_devices("CPU") 43 | tf.config.experimental.set_virtual_device_configuration( 44 | physical_devices[0], 45 | [tf.config.experimental.VirtualDeviceConfiguration(), 46 | tf.config.experimental.VirtualDeviceConfiguration()]) 47 | 48 | devices = tf.config.experimental.list_logical_devices(device_type="CPU") 49 | strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices]) 50 | 51 | with strategy.scope(): 52 | accumulator = GradientAccumulator() 53 | variable = tf.Variable([4.0, 3.0]) 54 | optimizer = create_optimizer(5e-5, 10, 5) 55 | gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False) 56 | 57 | def accumulate_on_replica(gradient): 58 | accumulator([gradient]) 59 | 60 | def apply_on_replica(): 61 | optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0) 62 | 63 | @tf.function 64 | def accumulate(grad1, grad2): 65 | with strategy.scope(): 66 | gradient_placeholder.values[0].assign(grad1) 67 | gradient_placeholder.values[1].assign(grad2) 68 | strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,)) 69 | 70 | @tf.function 71 | def apply_grad(): 72 | with strategy.scope(): 73 | 
strategy.experimental_run_v2(apply_on_replica) 74 | 75 | accumulate([1.0, 2.0], [-1.0, 1.0]) 76 | accumulate([3.0, -1.0], [-1.0, -1.0]) 77 | accumulate([-2.0, 2.0], [3.0, -2.0]) 78 | self.assertEqual(accumulator.step, 3) 79 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2) 80 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2) 81 | apply_grad() 82 | self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2) 83 | accumulator.reset() 84 | self.assertEqual(accumulator.step, 0) 85 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) 86 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() -------------------------------------------------------------------------------- /transformers/tests/tokenization_albert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Hugging Face inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/spiece.model') 26 | 27 | class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = AlbertTokenizer 30 | 31 | def setUp(self): 32 | super(AlbertTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self, **kwargs): 39 | return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"this is a test" 43 | output_text = u"this is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁this', u'▁is', u'▁a', u'▁test']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [48, 25, 21, 1289]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [u'▁i', u'▁was', u'▁born', u'▁in', u'▁9', u'2000', u',', u'▁and', u'▁this', u'▁is', u'▁fal', u's', u'é', u'.']) 58 | ids = tokenizer.convert_tokens_to_ids(tokens) 59 | self.assertListEqual(ids, [31, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9]) 60 | 61 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 62 | self.assertListEqual(back_tokens, ['▁i', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fal', 's', '', '.']) 63 | 64 | def test_sequence_builders(self): 65 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB) 66 | 67 | text = tokenizer.encode("sequence builders") 68 | text_2 = tokenizer.encode("multi-sequence build") 69 | 70 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 71 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 72 | 73 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 74 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [tokenizer.sep_token_id] 75 | 76 | 77 | if __name__ == '__main__': 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import logging 22 | 23 | from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer 24 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 25 | 26 | from .utils import slow, SMALL_MODEL_IDENTIFIER 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | @slow 31 | def test_tokenizer_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 34 | tokenizer = AutoTokenizer.from_pretrained(model_name) 35 | self.assertIsNotNone(tokenizer) 36 | self.assertIsInstance(tokenizer, BertTokenizer) 37 | self.assertGreater(len(tokenizer), 0) 38 | 39 | for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 40 | tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | self.assertIsNotNone(tokenizer) 42 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 43 | self.assertGreater(len(tokenizer), 0) 44 | 45 | def test_tokenizer_from_pretrained_identifier(self): 46 | logging.basicConfig(level=logging.INFO) 47 | tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) 48 | self.assertIsInstance(tokenizer, BertTokenizer) 49 | self.assertEqual(len(tokenizer), 12) 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | from .utils import slow 29 | 30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 31 | 32 | tokenizer_class = BertTokenizer 33 | 34 | def setUp(self): 35 | super(BertTokenizationTest, self).setUp() 36 | 37 | vocab_tokens = [ 38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 39 | "##ing", ",", "low", "lowest", 40 | ] 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 47 | 48 | def get_input_output_texts(self): 49 | input_text = u"UNwant\u00E9d,running" 50 | output_text = u"unwanted, running" 51 | return input_text, output_text 52 | 53 | def test_full_tokenizer(self): 54 | tokenizer = self.tokenizer_class(self.vocab_file) 55 | 56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 59 | 60 | def test_chinese(self): 61 | tokenizer = BasicTokenizer() 62 | 63 | self.assertListEqual( 64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 65 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 66 | 67 | def test_basic_tokenizer_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=True) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 72 | ["hello", "!", "how", "are", "you", "?"]) 73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 74 | 75 | def test_basic_tokenizer_no_lower(self): 76 | tokenizer = BasicTokenizer(do_lower_case=False) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 80 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 81 | 82 | def test_wordpiece_tokenizer(self): 83 | vocab_tokens = [ 84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 85 | "##ing" 86 | ] 87 | 88 | vocab = {} 89 | for (i, token) in enumerate(vocab_tokens): 90 | vocab[token] = i 91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 92 | 93 | self.assertListEqual(tokenizer.tokenize(""), []) 94 | 95 | self.assertListEqual( 96 | tokenizer.tokenize("unwanted running"), 97 | ["un", "##want", "##ed", "runn", "##ing"]) 98 | 99 | self.assertListEqual( 100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 101 | 102 | def test_is_whitespace(self): 103 | self.assertTrue(_is_whitespace(u" ")) 104 | self.assertTrue(_is_whitespace(u"\t")) 105 | self.assertTrue(_is_whitespace(u"\r")) 106 | self.assertTrue(_is_whitespace(u"\n")) 107 | self.assertTrue(_is_whitespace(u"\u00A0")) 108 | 109 | self.assertFalse(_is_whitespace(u"A")) 110 | self.assertFalse(_is_whitespace(u"-")) 111 | 112 | def test_is_control(self): 113 | self.assertTrue(_is_control(u"\u0005")) 114 | 115 | self.assertFalse(_is_control(u"A")) 116 | self.assertFalse(_is_control(u" ")) 117 | self.assertFalse(_is_control(u"\t")) 118 | self.assertFalse(_is_control(u"\r")) 119 | 120 | def test_is_punctuation(self): 121 | self.assertTrue(_is_punctuation(u"-")) 122 | self.assertTrue(_is_punctuation(u"$")) 123 | self.assertTrue(_is_punctuation(u"`")) 124 | self.assertTrue(_is_punctuation(u".")) 125 | 126 | self.assertFalse(_is_punctuation(u"A")) 127 | self.assertFalse(_is_punctuation(u" ")) 128 | 129 | @slow 130 | def test_sequence_builders(self): 131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 132 | 133 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 135 | 136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 138 | 139 | assert encoded_sentence == [101] + text + [102] 140 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 141 | 142 | 143 | if __name__ == '__main__': 144 | unittest.main() 145 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import absolute_import, division, print_function, unicode_literals
15 | 
16 | import os
17 | import unittest
18 | import json
19 | from io import open
20 | 
21 | from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 | 
27 |     tokenizer_class = CTRLTokenizer
28 | 
29 |     def setUp(self):
30 |         super(CTRLTokenizationTest, self).setUp()
31 | 
32 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 |         vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>']
34 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
35 |         merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', '']
36 |         self.special_tokens_map = {"unk_token": "<unk>"}
37 | 
38 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
39 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
40 |         with open(self.vocab_file, "w", encoding="utf-8") as fp:
41 |             fp.write(json.dumps(vocab_tokens) + "\n")
42 |         with open(self.merges_file, "w", encoding="utf-8") as fp:
43 |             fp.write("\n".join(merges))
44 | 
45 |     def get_tokenizer(self, **kwargs):
46 |         kwargs.update(self.special_tokens_map)
47 |         return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
48 | 
49 |     def get_input_output_texts(self):
50 |         input_text = u"adapt react readapt apt"
51 |         output_text = u"adapt react readapt apt"
52 |         return input_text, output_text
53 | 
54 |     def test_full_tokenizer(self):
55 |         tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
56 |         text = "adapt react readapt apt"
57 |         bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split()
58 |         tokens = tokenizer.tokenize(text)
59 |         self.assertListEqual(tokens, bpe_tokens)
60 | 
61 |         input_tokens = tokens + [tokenizer.unk_token]
62 | 
63 |         input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
64 |         self.assertListEqual(
65 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     unittest.main()
70 | 
-------------------------------------------------------------------------------- /transformers/tests/tokenization_distilbert_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from transformers.tokenization_distilbert import (DistilBertTokenizer) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .tokenization_bert_test import BertTokenizationTest 25 | from .utils import slow 26 | 27 | class DistilBertTokenizationTest(BertTokenizationTest): 28 | 29 | tokenizer_class = DistilBertTokenizer 30 | 31 | def get_tokenizer(self, **kwargs): 32 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 33 | 34 | @slow 35 | def test_sequence_builders(self): 36 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 37 | 38 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 39 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 40 | 41 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 42 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 43 | 44 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 45 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ 46 | text_2 + [tokenizer.sep_token_id] 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | from io import open 21 | 22 | from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = GPT2Tokenizer 29 | 30 | def setUp(self): 31 | super(GPT2TokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "\u0120", "\u0120l", "\u0120n",
36 |                  "\u0120lo", "\u0120low", "er",
37 |                  "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
40 |         self.special_tokens_map = {"unk_token": "<unk>"}
41 | 
42 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
43 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
44 |         with open(self.vocab_file, "w", encoding="utf-8") as fp:
45 |             fp.write(json.dumps(vocab_tokens) + "\n")
46 |         with open(self.merges_file, "w", encoding="utf-8") as fp:
47 |             fp.write("\n".join(merges))
48 | 
49 |     def get_tokenizer(self, **kwargs):
50 |         kwargs.update(self.special_tokens_map)
51 |         return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
52 | 
53 |     def get_input_output_texts(self):
54 |         input_text = u"lower newer"
55 |         output_text = u"lower newer"
56 |         return input_text, output_text
57 | 
58 |     def test_full_tokenizer(self):
59 |         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
60 |         text = "lower newer"
61 |         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
62 |         tokens = tokenizer.tokenize(text, add_prefix_space=True)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + [tokenizer.unk_token]
66 |         input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 | if __name__ == '__main__':
71 |     unittest.main()
72 | 
-------------------------------------------------------------------------------- /transformers/tests/tokenization_openai_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | 
21 | from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | 
26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 | 
28 |     tokenizer_class = OpenAIGPTTokenizer
29 | 
30 |     def setUp(self):
31 |         super(OpenAIGPTTokenizationTest, self).setUp()
32 | 
33 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "w</w>", "r</w>", "t</w>",
36 |                  "lo", "low", "er</w>",
37 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
40 | 
41 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 |         with open(self.vocab_file, "w") as fp:
44 |             fp.write(json.dumps(vocab_tokens))
45 |         with open(self.merges_file, "w") as fp:
46 |             fp.write("\n".join(merges))
47 | 
48 |     def get_tokenizer(self, **kwargs):
49 |         return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 | 
51 |     def get_input_output_texts(self):
52 |         input_text = u"lower newer"
53 |         output_text = u"lower newer"
54 |         return input_text, output_text
55 | 
56 | 
57 |     def test_full_tokenizer(self):
58 |         tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
59 | 
60 |         text = "lower"
61 |         bpe_tokens = ["low", "er</w>"]
62 |         tokens = tokenizer.tokenize(text)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + ["<unk>"]
66 |         input_bpe_tokens = [14, 15, 20]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     unittest.main()
73 | 
-------------------------------------------------------------------------------- /transformers/tests/tokenization_roberta_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import json
19 | import unittest
20 | from io import open
21 | 
22 | from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
23 | from .tokenization_tests_commons import CommonTestCases
24 | from .utils import slow
25 | 
26 | 
27 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
28 |     tokenizer_class = RobertaTokenizer
29 | 
30 |     def setUp(self):
31 |         super(RobertaTokenizationTest, self).setUp()
32 | 
33 |         # Adapted from Sennrich et al.
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": ""} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | def roberta_dict_integration_testing(self): 71 | tokenizer = self.get_tokenizer() 72 | 73 | self.assertListEqual( 74 | tokenizer.encode('Hello world!', add_special_tokens=False), 75 | [0, 31414, 232, 328, 2] 76 | ) 77 | self.assertListEqual( 78 | tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False), 79 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 80 | ) 81 | 82 | @slow 83 | def test_sequence_builders(self): 84 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 85 | 86 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 87 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 88 | 89 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 90 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 91 | 92 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 93 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 94 | 95 | assert encoded_sentence == encoded_text_from_decode 96 | assert encoded_pair == encoded_pair_from_decode 97 | 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_t5_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google T5 Authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from transformers.tokenization_t5 import (T5Tokenizer) 21 | from transformers.tokenization_xlnet import SPIECE_UNDERLINE 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 26 | 'fixtures/test_sentencepiece.model') 27 | 28 | class T5TokenizationTest(CommonTestCases.CommonTokenizerTester): 29 | 30 | tokenizer_class = T5Tokenizer 31 | 32 | def setUp(self): 33 | super(T5TokenizationTest, self).setUp() 34 | 35 | # We have a SentencePiece fixture for testing 36 | tokenizer = T5Tokenizer(SAMPLE_VOCAB) 37 | tokenizer.save_pretrained(self.tmpdirname) 38 | 39 | def get_tokenizer(self, **kwargs): 40 | return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 41 | 42 | def get_input_output_texts(self): 43 | input_text = u"This is a test" 44 | output_text = u"This is a test" 45 | return input_text, output_text 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = T5Tokenizer(SAMPLE_VOCAB) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 61 | ids = tokenizer.convert_tokens_to_ids(tokens) 62 | self.assertListEqual( 63 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 64 | 602, 347, 347, 347, 3, 12, 66, 65 | 46, 72, 80, 6, 0, 4]) 66 | 67 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 69 | u'or', u'n', SPIECE_UNDERLINE + u'in', 70 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 73 | u'', u'.']) 74 | 75 | 76 | if __name__ == '__main__': 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | from io import open
20 | 
21 | from transformers import is_torch_available
22 | 
23 | if is_torch_available():
24 |     import torch
25 |     from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
26 | 
27 | from .tokenization_tests_commons import CommonTestCases
28 | from .utils import require_torch
29 | 
30 | 
31 | @require_torch
32 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
33 | 
34 |     tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
35 | 
36 |     def setUp(self):
37 |         super(TransfoXLTokenizationTest, self).setUp()
38 | 
39 |         vocab_tokens = [
40 |             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
41 |             "running", ",", "low", "l",
42 |         ]
43 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
44 |         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
45 |             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
46 | 
47 |     def get_tokenizer(self, **kwargs):
48 |         kwargs['lower_case'] = True
49 |         return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 | 
51 |     def get_input_output_texts(self):
52 |         input_text = u"<unk> UNwanted , running"
53 |         output_text = u"<unk> unwanted, running"
54 |         return input_text, output_text
55 | 
56 |     def test_full_tokenizer(self):
57 |         tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
58 | 
59 |         tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
60 |         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
61 | 
62 |         self.assertListEqual(
63 |             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
64 | 
65 |     def test_full_tokenizer_lower(self):
66 |         tokenizer = TransfoXLTokenizer(lower_case=True)
67 | 
68 |         self.assertListEqual(
69 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
70 |             ["hello", "!", "how", "are", "you", "?"])
71 | 
72 |     def test_full_tokenizer_no_lower(self):
73 |         tokenizer = TransfoXLTokenizer(lower_case=False)
74 | 
75 |         self.assertListEqual(
76 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
77 |             ["HeLLo", "!", "how", "Are", "yoU", "?"])
78 | 
79 | 
80 | if __name__ == '__main__':
81 |     unittest.main()
82 | 
-------------------------------------------------------------------------------- /transformers/tests/tokenization_utils_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 HuggingFace Inc..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from transformers import PreTrainedTokenizer 23 | from transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | from .utils import slow 26 | 27 | class TokenizerUtilsTest(unittest.TestCase): 28 | 29 | def check_tokenizer_from_pretrained(self, tokenizer_class): 30 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 31 | for model_name in s3_models[:1]: 32 | tokenizer = tokenizer_class.from_pretrained(model_name) 33 | self.assertIsNotNone(tokenizer) 34 | self.assertIsInstance(tokenizer, tokenizer_class) 35 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 36 | 37 | for special_tok in tokenizer.all_special_tokens: 38 | if six.PY2: 39 | self.assertIsInstance(special_tok, unicode) 40 | else: 41 | self.assertIsInstance(special_tok, str) 42 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 43 | self.assertIsInstance(special_tok_id, int) 44 | 45 | @slow 46 | def test_pretrained_tokenizers(self): 47 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .utils import slow 25 | 26 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = XLMTokenizer 29 | 30 | def setUp(self): 31 | super(XLMTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "w</w>", "r</w>", "t</w>",
36 |                  "lo", "low", "er</w>",
37 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
40 | 
41 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 |         with open(self.vocab_file, "w") as fp:
44 |             fp.write(json.dumps(vocab_tokens))
45 |         with open(self.merges_file, "w") as fp:
46 |             fp.write("\n".join(merges))
47 | 
48 |     def get_tokenizer(self, **kwargs):
49 |         return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 | 
51 |     def get_input_output_texts(self):
52 |         input_text = u"lower newer"
53 |         output_text = u"lower newer"
54 |         return input_text, output_text
55 | 
56 |     def test_full_tokenizer(self):
57 |         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
58 |         tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
59 | 
60 |         text = "lower"
61 |         bpe_tokens = ["low", "er</w>"]
62 |         tokens = tokenizer.tokenize(text)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + ["<unk>"]
66 |         input_bpe_tokens = [14, 15, 20]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 |     @slow
71 |     def test_sequence_builders(self):
72 |         tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
73 | 
74 |         text = tokenizer.encode("sequence builders", add_special_tokens=False)
75 |         text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
76 | 
77 |         encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
78 |         encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
79 | 
80 |         assert encoded_sentence == [1] + text + [1]
81 |         assert encoded_pair == [1] + text + [1] + text_2 + [1]
82 | 
83 | if __name__ == '__main__':
84 |     unittest.main()
85 | 
-------------------------------------------------------------------------------- /transformers/tests/tokenization_xlnet_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | from .utils import slow 24 | 25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 26 | 'fixtures/test_sentencepiece.model') 27 | 28 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 29 | 30 | tokenizer_class = XLNetTokenizer 31 | 32 | def setUp(self): 33 | super(XLNetTokenizationTest, self).setUp() 34 | 35 | # We have a SentencePiece fixture for testing 36 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 37 | tokenizer.save_pretrained(self.tmpdirname) 38 | 39 | def get_tokenizer(self, **kwargs): 40 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 41 | 42 | def get_input_output_texts(self): 43 | input_text = u"This is a test" 44 | output_text = u"This is a test" 45 | return input_text, output_text 46 | 47 | 48 | def test_full_tokenizer(self): 49 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 50 | 51 | tokens = tokenizer.tokenize(u'This is a test') 52 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 53 | 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 56 | 57 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 58 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 59 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 60 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 61 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 62 | ids = tokenizer.convert_tokens_to_ids(tokens) 63 | self.assertListEqual( 64 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 65 | 602, 347, 347, 347, 3, 12, 66, 66 | 46, 72, 80, 6, 0, 4]) 67 | 68 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 69 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 70 | u'or', u'n', SPIECE_UNDERLINE + u'in', 71 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 72 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 73 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 74 | u'', u'.']) 75 | 76 | def test_tokenizer_lower(self): 77 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 78 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 79 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 80 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 81 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 82 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 83 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 84 | 85 | def test_tokenizer_no_lower(self): 86 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 87 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 88 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 89 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 90 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + 
u'this', 91 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 92 | 93 | @slow 94 | def test_sequence_builders(self): 95 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 96 | 97 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 98 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 99 | 100 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 101 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 102 | 103 | assert encoded_sentence == text + [4, 3] 104 | assert encoded_pair == text + [4] + text_2 + [4, 3] 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /transformers/tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from distutils.util import strtobool 5 | 6 | from transformers.file_utils import _tf_available, _torch_available 7 | 8 | 9 | SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" 10 | 11 | 12 | def parse_flag_from_env(key, default=False): 13 | try: 14 | value = os.environ[key] 15 | except KeyError: 16 | # KEY isn't set, default to `default`. 17 | _value = default 18 | else: 19 | # KEY is set, convert it to True or False. 20 | try: 21 | _value = strtobool(value) 22 | except ValueError: 23 | # More values are supported, but let's keep the message simple. 24 | raise ValueError("If set, {} must be yes or no.".format(key)) 25 | return _value 26 | 27 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) 28 | _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) 29 | 30 | 31 | def slow(test_case): 32 | """ 33 | Decorator marking a test as slow. 34 | 35 | Slow tests are skipped by default. Set the RUN_SLOW environment variable 36 | to a truthy value to run them. 37 | 38 | """ 39 | if not _run_slow_tests: 40 | test_case = unittest.skip("test is slow")(test_case) 41 | return test_case 42 | 43 | 44 | def custom_tokenizers(test_case): 45 | """ 46 | Decorator marking a test for a custom tokenizer. 47 | 48 | Custom tokenizers require additional dependencies, and are skipped 49 | by default. Set the RUN_CUSTOM_TOKENIZERS environment variable 50 | to a truthy value to run them. 51 | """ 52 | if not _run_custom_tokenizers: 53 | test_case = unittest.skip("test of custom tokenizers")(test_case) 54 | return test_case 55 | 56 | 57 | def require_torch(test_case): 58 | """ 59 | Decorator marking a test that requires PyTorch. 60 | 61 | These tests are skipped when PyTorch isn't installed. 62 | 63 | """ 64 | if not _torch_available: 65 | test_case = unittest.skip("test requires PyTorch")(test_case) 66 | return test_case 67 | 68 | 69 | def require_tf(test_case): 70 | """ 71 | Decorator marking a test that requires TensorFlow. 72 | 73 | These tests are skipped when TensorFlow isn't installed. 74 | 75 | """ 76 | if not _tf_available: 77 | test_case = unittest.skip("test requires TensorFlow")(test_case) 78 | return test_case 79 | 80 | 81 | if _torch_available: 82 | # Set the USE_CUDA environment variable to select a GPU. 
83 | torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu" 84 | else: 85 | torch_device = None 86 | -------------------------------------------------------------------------------- /transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt", 37 | 'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 38 | } 39 | } 40 | 41 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 42 | 'distilbert-base-uncased': 512, 43 | 'distilbert-base-uncased-distilled-squad': 512, 44 | 'distilbert-base-german-cased': 512, 45 | 'distilbert-base-multilingual-cased': 512, 46 | } 47 | 48 | 49 | class DistilBertTokenizer(BertTokenizer): 50 | r""" 51 | Constructs a DistilBertTokenizer. 52 | :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 53 | 54 | Args: 55 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 56 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 57 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 58 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 59 | minimum of this value (if specified) and the underlying BERT model's sequence length. 60 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 61 | do_wordpiece_only=False 62 | """ 63 | 64 | vocab_files_names = VOCAB_FILES_NAMES 65 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 66 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 67 | --------------------------------------------------------------------------------
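
A minimal usage sketch of the DistilBertTokenizer defined above, mirroring the sequence-builder assertions in tokenization_distilbert_test.py earlier in this listing; it assumes the 'distilbert-base-uncased' vocabulary referenced in PRETRAINED_VOCAB_FILES_MAP can be downloaded and is illustrative only, not a file of this repository:

from transformers import DistilBertTokenizer

# 'distilbert-base-uncased' resolves to the bert-base-uncased vocab via PRETRAINED_VOCAB_FILES_MAP above.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

ids_a = tokenizer.encode("sequence builders", add_special_tokens=False)
ids_b = tokenizer.encode("multi-sequence build", add_special_tokens=False)

# Single sequence: [CLS] A [SEP]; pair: [CLS] A [SEP] B [SEP], the contract the test file asserts.
single = tokenizer.build_inputs_with_special_tokens(ids_a)
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)

assert single == [tokenizer.cls_token_id] + ids_a + [tokenizer.sep_token_id]
assert pair == [tokenizer.cls_token_id] + ids_a + [tokenizer.sep_token_id] + ids_b + [tokenizer.sep_token_id]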