├── .idea
│   ├── TSBERT.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   └── vcs.xml
├── README.md
├── examples
│   ├── DATASET
│   │   ├── ch
│   │   │   ├── cut_result_zys_reference_v2.txt
│   │   │   └── document_zys_2.json
│   │   └── en
│   │       ├── document_en_dataset_2.json
│   │       └── en_dataset_reference_2.txt
│   ├── data
│   │   ├── alime
│   │   │   └── put_E-commerce_data_here.txt
│   │   ├── douban
│   │   │   └── put_douban_data_here.txt
│   │   └── ubuntu
│   │       └── put_ubuntu_data_here.txt
│   ├── data_process.py
│   ├── run_TSbert_v3.py
│   ├── segmentation_BERTCLS.py
│   ├── utils_TSbert_v3.py
│   └── utils_segmentation.py
└── pytorch_transformers
    ├── __init__.py
    ├── __main__.py
    ├── configuration_auto.py
    ├── configuration_bert.py
    ├── configuration_distilbert.py
    ├── configuration_gpt2.py
    ├── configuration_openai.py
    ├── configuration_roberta.py
    ├── configuration_transfo_xl.py
    ├── configuration_utils.py
    ├── configuration_xlm.py
    ├── configuration_xlnet.py
    ├── convert_gpt2_checkpoint_to_pytorch.py
    ├── convert_openai_checkpoint_to_pytorch.py
    ├── convert_pytorch_checkpoint_to_tf.py
    ├── convert_roberta_checkpoint_to_pytorch.py
    ├── convert_tf_checkpoint_to_pytorch.py
    ├── convert_transfo_xl_checkpoint_to_pytorch.py
    ├── convert_xlm_checkpoint_to_pytorch.py
    ├── convert_xlnet_checkpoint_to_pytorch.py
    ├── file_utils.py
    ├── modeling_TSbert.py
    ├── modeling_TSbert_v3.py
    ├── modeling_auto.py
    ├── modeling_bert.py
    ├── modeling_distilbert.py
    ├── modeling_gpt2.py
    ├── modeling_openai.py
    ├── modeling_roberta.py
    ├── modeling_transfo_xl.py
    ├── modeling_transfo_xl_utilities.py
    ├── modeling_utils.py
    ├── modeling_xlm.py
    ├── modeling_xlnet.py
    ├── optimization.py
    ├── optimization_bert.py
    ├── tests
    │   ├── __init__.py
    │   ├── configuration_common_test.py
    │   ├── conftest.py
    │   ├── fixtures
    │   │   ├── input.txt
    │   │   ├── sample_text.txt
    │   │   └── test_sentencepiece.model
    │   ├── modeling_auto_test.py
    │   ├── modeling_bert_test.py
    │   ├── modeling_common_test.py
    │   ├── modeling_distilbert_test.py
    │   ├── modeling_gpt2_test.py
    │   ├── modeling_openai_test.py
    │   ├── modeling_roberta_test.py
    │   ├── modeling_transfo_xl_test.py
    │   ├── modeling_xlm_test.py
    │   ├── modeling_xlnet_test.py
    │   ├── optimization_test.py
    │   ├── tokenization_auto_test.py
    │   ├── tokenization_bert_test.py
    │   ├── tokenization_dilbert_test.py
    │   ├── tokenization_gpt2_test.py
    │   ├── tokenization_openai_test.py
    │   ├── tokenization_roberta_test.py
    │   ├── tokenization_tests_commons.py
    │   ├── tokenization_transfo_xl_test.py
    │   ├── tokenization_utils_test.py
    │   ├── tokenization_xlm_test.py
    │   └── tokenization_xlnet_test.py
    ├── tokenization_auto.py
    ├── tokenization_bert.py
    ├── tokenization_distilbert.py
    ├── tokenization_gpt2.py
    ├── tokenization_openai.py
    ├── tokenization_roberta.py
    ├── tokenization_transfo_xl.py
    ├── tokenization_utils.py
    ├── tokenization_xlm.py
    └── tokenization_xlnet.py
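The tree above shows the repository layout. The README and the example scripts that follow (examples/segmentation_BERTCLS.py, examples/data_process.py) build a pipeline that first predicts topic cut points for each dialogue context and then joins consecutive utterances into topic segments. The snippet below is a minimal sketch of that second step, mirroring `generate_output_file_seg()` in `examples/data_process.py`; the helper name `apply_cut_list` and the sample utterances are illustrative and not part of the repository, and the sketch assumes each value in the cut list is the index of the last utterance of a segment.

```python
# Minimal sketch (not part of the repository): apply a predicted cut list to one
# dialogue context, mirroring generate_output_file_seg() in examples/data_process.py.
# Assumption: each value in cut_points is the index of the last utterance of a segment.
def apply_cut_list(utterances, cut_points):
    segments, current = [], []
    pos = 0  # position in cut_points
    for i, utt in enumerate(utterances):
        current.append(utt)
        if pos < len(cut_points) and i == cut_points[pos]:
            segments.append(" ".join(current))  # utterances within a segment are space-joined
            current = []
            pos += 1
    if current:
        # data_process.py expects the last cut point to close the final segment and would
        # drop a trailing remainder; we keep it here for illustration.
        segments.append(" ".join(current))
    return segments


# Segments are written tab-separated on one line, as data_process.py does.
print("\t".join(apply_cut_list(["hi", "any discount today", "yes 10 percent off", "ok thanks"], [1, 3])))
# -> "hi any discount today<TAB>yes 10 percent off ok thanks"
```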
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Dataset
Please download each dataset into the corresponding directory under `data`.

E-commerce
https://drive.google.com/file/d/154J-neBo20ABtSmJDvm7DK0eTuieAuvw/view?usp=sharing

Ubuntu
https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntudata.zip?dl=0

Douban
https://www.dropbox.com/s/90t0qtji9ow20ca/DoubanConversaionCorpus.zip?dl=0&file_subpath=%2FDoubanConversaionCorpus

Our own dataset for segmentation is in the `DATASET` directory.

## Source Code
* prepare data

  Generate `cutlist.txt`:

      python segmentation_BERTCLS.py --datapath=data/xxx/xxx.txt

  Gather the segmented data into `data/xxx/xxxseg.txt`: set `interval = 2` for `train.txt` and `interval = 10` for `test.txt`, set the corresponding data file and dataset in `data_process.py`, then run:

      python data_process.py

* train

      python run_TSbert_v3.py --task=alime --do_train --train_batch_size=20 --learning_rate=2e-5

  The cached input data will be saved in `data/alime/input_cache_v3`, the model will be saved in `data/alime/model_save_v3`, and the training log will also be written to `log.txt`.

* eval

      python run_TSbert_v3.py --task=xxx

  You can also load our trained model for testing: https://drive.google.com/drive/folders/1_sRSmwlaAK_TPaVYYNhao81rXUW92z98?usp=sharing

### Environment
We use the pre-trained BERT (PyTorch version) from https://github.com/huggingface/transformers.

torch>=1.0.0

Packages: tqdm, boto3, requests, regex, sacremoses, openpyxl, numpy, sentencepiece

### Reference

If you use this code, please cite our paper:
```
@article{xu2020topic,
  title={Topic-aware multi-turn dialogue modeling},
  author={Xu, Yi and Zhao, Hai and Zhang, Zhuosheng},
  journal={arXiv preprint arXiv:2009.12539},
  year={2020}
}
```
--------------------------------------------------------------------------------
/examples/data/alime/put_E-commerce_data_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/alime/put_E-commerce_data_here.txt
--------------------------------------------------------------------------------
/examples/data/douban/put_douban_data_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/douban/put_douban_data_here.txt
--------------------------------------------------------------------------------
/examples/data/ubuntu/put_ubuntu_data_here.txt:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/ubuntu/put_ubuntu_data_here.txt -------------------------------------------------------------------------------- /examples/data_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | input_file = "data/alime/train.txt" 3 | output_file_seg = "data/alime/train_seg.txt" 4 | cut_list_file = "data/alime/cutlist_train.json" 5 | 6 | final_outputfile="data/alime/trainseg.txt" 7 | interval = 2 8 | 9 | def generate_output_file_seg(): 10 | contexts=[] 11 | num=0 12 | with open(input_file,'r',encoding='utf-8') as rf: 13 | for line in rf: 14 | line=line.strip() 15 | if(line and num%interval==0): 16 | contexts.append(line.split("\t")[1:-1]) 17 | num += 1 18 | print(num) 19 | with open(cut_list_file, 'r', encoding='utf-8') as rf: 20 | cutlist = json.loads(rf.read()) 21 | 22 | print(len(contexts)) 23 | print(len(cutlist)) 24 | 25 | with open(output_file_seg,'w',encoding='utf-8') as wf: 26 | for context, cutl in zip(contexts, cutlist): 27 | seg_list = [] 28 | seg="" 29 | index_1 = 0 30 | index_2 = 0 31 | for index, utt in enumerate(context): 32 | if(seg): 33 | seg=seg+" "+utt 34 | else: 35 | seg=utt 36 | if (index_1 == cutl[index_2]): 37 | index_2 += 1 38 | seg_list.append(seg) 39 | seg="" 40 | index_1 += 1 41 | # a="\t".join(seg_list)+'\n' 42 | wf.write("\t".join(seg_list)+'\n') 43 | 44 | 45 | def write_segfile(inputfile_utt, inputfile_seg, outputfile, interval): 46 | lab_res_list=[] 47 | with open(inputfile_utt, 'r', encoding='utf-8') as rf: 48 | for line in rf: 49 | line=line.strip() 50 | if(line): 51 | lab=line.split('\t')[0] 52 | res=line.split('\t')[-1] 53 | lab_res_list.append([lab,None,res]) 54 | print(len(lab_res_list)) 55 | seg_list = [] 56 | with open(inputfile_seg, 'r', encoding='utf-8') as rf: 57 | for line in rf: 58 | line = line.strip() 59 | if(line): 60 | seg_list.append(line) 61 | print(len(seg_list)) 62 | i = 0 63 | for seg in seg_list: 64 | for _ in range(interval): 65 | lab_res_list[i][1] = seg 66 | i += 1 67 | print(i) 68 | 69 | with open(outputfile,'w',encoding='utf-8') as wf: 70 | for lab_seg_res in lab_res_list: 71 | wf.write('\t'.join(lab_seg_res)+'\n') 72 | 73 | 74 | generate_output_file_seg() 75 | write_segfile(input_file, output_file_seg, final_outputfile, interval) -------------------------------------------------------------------------------- /examples/segmentation_BERTCLS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import json 4 | import argparse 5 | import numpy as np 6 | import os 7 | 8 | from pytorch_transformers import BertConfig, BertModel, BertTokenizer 9 | from utils_segmentation import convert_examples_to_features, read_expamples_2 10 | WINDOW_SIZE = 2 11 | SEGMENT_JUMP_STEP = 2 12 | SIMILARITY_THRESHOLD = 0.6 13 | MAX_SEGMENT_ROUND = 6 14 | MAX_SEQ_LENGTH = 50 15 | MODEL_CLASSES = { 16 | 'bert': (BertConfig, BertModel, BertTokenizer), 17 | } 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | def similarity(A, B): 20 | return np.dot(A, B) / (np.linalg.norm(A) * np.linalg.norm(B)) 21 | 22 | 23 | def generate_vectors_2(model,examples,tokenizer,device): 24 | features = convert_examples_to_features(examples, MAX_SEQ_LENGTH, tokenizer, 25 | cls_token=tokenizer.cls_token, 26 | cls_token_segment_id=0, 27 | sep_token=tokenizer.sep_token, 28 | ) 29 | 30 
| all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device) 31 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long).to(device) 32 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long).to(device) 33 | 34 | vectors = model(input_ids=all_input_ids, attention_mask=all_input_mask, token_type_ids=all_segment_ids)[1] 35 | return vectors.cpu().detach().numpy() 36 | 37 | def segmentation(documents): 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert'] 40 | config = config_class.from_pretrained(args.berttype) 41 | tokenizer = tokenizer_class.from_pretrained(args.berttype,do_lower_case=True) 42 | model = model_class.from_pretrained(args.berttype ,config=config).to(device)#, from_tf=bool('.ckpt' in args.model_name_or_path), 43 | model.eval() 44 | 45 | all_cut_list = [] 46 | pbar = tqdm(total=len(documents)) 47 | for document_o in documents: 48 | if(len(document_o)%2): 49 | document=document_o[1:] 50 | else: 51 | document = document_o 52 | cut_index=0 53 | cut_list = [] 54 | while(cut_index0): 62 | index=WINDOW_SIZE 63 | while(index>0): 64 | left_sent+=document[cut_index-index] 65 | index-=1 66 | 67 | else: 68 | temp_index=0 69 | while(temp_index right_value else right_value 104 | if(not left_sent and not right_sent):#防止前后都没有参考窗口,即len(document)<=MAX_SEGMENT_ROUND 105 | 106 | larger_value=SIMILARITY_THRESHOLD #如果中间截断的情况的最小相似性都大于0.8则这段通话不进行切分,中间截断的情况只有小于 107 | #这个阈值才会截断, 108 | if(larger_value 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: 5 | print( 6 | "Should be used as one of: \n" 7 | ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" 8 | ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" 9 | ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" 10 | ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" 11 | ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" 12 | ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") 13 | else: 14 | if sys.argv[1] == "bert": 15 | try: 16 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 17 | except ImportError: 18 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 19 | "In that case, it requires TensorFlow to be installed. 
Please see " 20 | "https://www.tensorflow.org/install/ for installation instructions.") 21 | raise 22 | 23 | if len(sys.argv) != 5: 24 | # pylint: disable=line-too-long 25 | print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 26 | else: 27 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 28 | TF_CONFIG = sys.argv.pop() 29 | TF_CHECKPOINT = sys.argv.pop() 30 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 31 | elif sys.argv[1] == "gpt": 32 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 33 | if len(sys.argv) < 4 or len(sys.argv) > 5: 34 | # pylint: disable=line-too-long 35 | print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") 36 | else: 37 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 38 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 39 | if len(sys.argv) == 5: 40 | OPENAI_GPT_CONFIG = sys.argv[4] 41 | else: 42 | OPENAI_GPT_CONFIG = "" 43 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 44 | OPENAI_GPT_CONFIG, 45 | PYTORCH_DUMP_OUTPUT) 46 | elif sys.argv[1] == "transfo_xl": 47 | try: 48 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 49 | except ImportError: 50 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 51 | "In that case, it requires TensorFlow to be installed. Please see " 52 | "https://www.tensorflow.org/install/ for installation instructions.") 53 | raise 54 | if len(sys.argv) < 4 or len(sys.argv) > 5: 55 | # pylint: disable=line-too-long 56 | print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 57 | else: 58 | if 'ckpt' in sys.argv[2].lower(): 59 | TF_CHECKPOINT = sys.argv[2] 60 | TF_DATASET_FILE = "" 61 | else: 62 | TF_DATASET_FILE = sys.argv[2] 63 | TF_CHECKPOINT = "" 64 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 65 | if len(sys.argv) == 5: 66 | TF_CONFIG = sys.argv[4] 67 | else: 68 | TF_CONFIG = "" 69 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 70 | elif sys.argv[1] == "gpt2": 71 | try: 72 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 73 | except ImportError: 74 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 75 | "In that case, it requires TensorFlow to be installed. Please see " 76 | "https://www.tensorflow.org/install/ for installation instructions.") 77 | raise 78 | 79 | if len(sys.argv) < 4 or len(sys.argv) > 5: 80 | # pylint: disable=line-too-long 81 | print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 82 | else: 83 | TF_CHECKPOINT = sys.argv[2] 84 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 85 | if len(sys.argv) == 5: 86 | TF_CONFIG = sys.argv[4] 87 | else: 88 | TF_CONFIG = "" 89 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 90 | elif sys.argv[1] == "xlnet": 91 | try: 92 | from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch 93 | except ImportError: 94 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 95 | "In that case, it requires TensorFlow to be installed. 
Please see " 96 | "https://www.tensorflow.org/install/ for installation instructions.") 97 | raise 98 | 99 | if len(sys.argv) < 5 or len(sys.argv) > 6: 100 | # pylint: disable=line-too-long 101 | print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") 102 | else: 103 | TF_CHECKPOINT = sys.argv[2] 104 | TF_CONFIG = sys.argv[3] 105 | PYTORCH_DUMP_OUTPUT = sys.argv[4] 106 | if len(sys.argv) == 6: 107 | FINETUNING_TASK = sys.argv[5] 108 | else: 109 | FINETUNING_TASK = None 110 | 111 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, 112 | TF_CONFIG, 113 | PYTORCH_DUMP_OUTPUT, 114 | FINETUNING_TASK) 115 | elif sys.argv[1] == "xlm": 116 | from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch 117 | 118 | if len(sys.argv) != 4: 119 | # pylint: disable=line-too-long 120 | print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") 121 | else: 122 | XLM_CHECKPOINT_PATH = sys.argv[2] 123 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 124 | 125 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Model class. """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import logging 20 | 21 | from .configuration_bert import BertConfig 22 | from .configuration_openai import OpenAIGPTConfig 23 | from .configuration_gpt2 import GPT2Config 24 | from .configuration_transfo_xl import TransfoXLConfig 25 | from .configuration_xlnet import XLNetConfig 26 | from .configuration_xlm import XLMConfig 27 | from .configuration_roberta import RobertaConfig 28 | from .configuration_distilbert import DistilBertConfig 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class AutoConfig(object): 34 | r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class 35 | that will be instantiated as one of the configuration classes of the library 36 | when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` 37 | class method. 38 | 39 | The `from_pretrained()` method take care of returning the correct model class instance 40 | using pattern matching on the `pretrained_model_name_or_path` string. 
41 | 42 | The base model class to instantiate is selected as the first pattern matching 43 | in the `pretrained_model_name_or_path` string (in the following order): 44 | - contains `distilbert`: DistilBertConfig (DistilBERT model) 45 | - contains `bert`: BertConfig (Bert model) 46 | - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) 47 | - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) 48 | - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) 49 | - contains `xlnet`: XLNetConfig (XLNet model) 50 | - contains `xlm`: XLMConfig (XLM model) 51 | - contains `roberta`: RobertaConfig (RoBERTa model) 52 | 53 | This class cannot be instantiated using `__init__()` (throw an error). 54 | """ 55 | def __init__(self): 56 | raise EnvironmentError("AutoConfig is designed to be instantiated " 57 | "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.") 58 | 59 | @classmethod 60 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 61 | r""" Instantiate a one of the configuration classes of the library 62 | from a pre-trained model configuration. 63 | 64 | The configuration class to instantiate is selected as the first pattern matching 65 | in the `pretrained_model_name_or_path` string (in the following order): 66 | - contains `distilbert`: DistilBertConfig (DistilBERT model) 67 | - contains `bert`: BertConfig (Bert model) 68 | - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) 69 | - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) 70 | - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) 71 | - contains `xlnet`: XLNetConfig (XLNet model) 72 | - contains `xlm`: XLMConfig (XLM model) 73 | - contains `roberta`: RobertaConfig (RoBERTa model) 74 | 75 | Params: 76 | pretrained_model_name_or_path: either: 77 | 78 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 79 | - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 80 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 81 | 82 | cache_dir: (`optional`) string: 83 | Path to a directory in which a downloaded pre-trained model 84 | configuration should be cached if the standard cache should not be used. 85 | 86 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 87 | 88 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 89 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 90 | 91 | force_download: (`optional`) boolean, default False: 92 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists. 93 | 94 | proxies: (`optional`) dict, default None: 95 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 96 | The proxies are used on each request. 97 | 98 | return_unused_kwargs: (`optional`) bool: 99 | 100 | - If False, then this function returns just the final configuration object. 
101 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. 102 | 103 | Examples:: 104 | 105 | config = AutoConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 106 | config = AutoConfig.from_pretrained('./test/bert_saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 107 | config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json') 108 | config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 109 | assert config.output_attention == True 110 | config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, 111 | foo=False, return_unused_kwargs=True) 112 | assert config.output_attention == True 113 | assert unused_kwargs == {'foo': False} 114 | 115 | """ 116 | if 'distilbert' in pretrained_model_name_or_path: 117 | return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 118 | elif 'roberta' in pretrained_model_name_or_path: 119 | return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 120 | elif 'bert' in pretrained_model_name_or_path: 121 | return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 122 | elif 'openai-gpt' in pretrained_model_name_or_path: 123 | return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 124 | elif 'gpt2' in pretrained_model_name_or_path: 125 | return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) 126 | elif 'transfo-xl' in pretrained_model_name_or_path: 127 | return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 128 | elif 'xlnet' in pretrained_model_name_or_path: 129 | return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 130 | elif 'xlm' in pretrained_model_name_or_path: 131 | return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 132 | 133 | raise ValueError("Unrecognized model identifier in {}. Should contains one of " 134 | "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " 135 | "'xlm', 'roberta'".format(pretrained_model_name_or_path)) 136 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | } 44 | 45 | 46 | class BertConfig(PretrainedConfig): 47 | r""" 48 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a 49 | `BertModel`. 50 | 51 | 52 | Arguments: 53 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 54 | hidden_size: Size of the encoder layers and the pooler layer. 55 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 56 | num_attention_heads: Number of attention heads for each attention layer in 57 | the Transformer encoder. 58 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 59 | layer in the Transformer encoder. 60 | hidden_act: The non-linear activation function (function or string) in the 61 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 62 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 63 | layers in the embeddings, encoder, and pooler. 64 | attention_probs_dropout_prob: The dropout ratio for the attention 65 | probabilities. 66 | max_position_embeddings: The maximum sequence length that this model might 67 | ever be used with. Typically set this to something large just in case 68 | (e.g., 512 or 1024 or 2048). 69 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 70 | `BertModel`. 71 | initializer_range: The sttdev of the truncated_normal_initializer for 72 | initializing all weight matrices. 
73 | layer_norm_eps: The epsilon used by LayerNorm. 74 | """ 75 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 76 | 77 | def __init__(self, 78 | vocab_size_or_config_json_file=30522, 79 | hidden_size=768, 80 | num_hidden_layers=12, 81 | num_attention_heads=12, 82 | intermediate_size=3072, 83 | hidden_act="gelu", 84 | hidden_dropout_prob=0.1, 85 | attention_probs_dropout_prob=0.1, 86 | max_position_embeddings=512, 87 | type_vocab_size=2, 88 | initializer_range=0.02, 89 | layer_norm_eps=1e-12, 90 | **kwargs): 91 | super(BertConfig, self).__init__(**kwargs) 92 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 93 | and isinstance(vocab_size_or_config_json_file, unicode)): 94 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 95 | json_config = json.loads(reader.read()) 96 | for key, value in json_config.items(): 97 | self.__dict__[key] = value 98 | elif isinstance(vocab_size_or_config_json_file, int): 99 | self.vocab_size = vocab_size_or_config_json_file 100 | self.hidden_size = hidden_size 101 | self.num_hidden_layers = num_hidden_layers 102 | self.num_attention_heads = num_attention_heads 103 | self.hidden_act = hidden_act 104 | self.intermediate_size = intermediate_size 105 | self.hidden_dropout_prob = hidden_dropout_prob 106 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 107 | self.max_position_embeddings = max_position_embeddings 108 | self.type_vocab_size = type_vocab_size 109 | self.initializer_range = initializer_range 110 | self.layer_norm_eps = layer_norm_eps 111 | else: 112 | raise ValueError("First argument must be either a vocabulary size (int)" 113 | " or the path to a pretrained model config file (str)") 114 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ DistilBERT model configuration """ 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" 31 | } 32 | 33 | 34 | class DistilBertConfig(PretrainedConfig): 35 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | 37 | def __init__(self, 38 | vocab_size_or_config_json_file=30522, 39 | max_position_embeddings=512, 40 | sinusoidal_pos_embds=True, 41 | n_layers=6, 42 | n_heads=12, 43 | dim=768, 44 | hidden_dim=4*768, 45 | dropout=0.1, 46 | attention_dropout=0.1, 47 | activation='gelu', 48 | initializer_range=0.02, 49 | tie_weights_=True, 50 | qa_dropout=0.1, 51 | seq_classif_dropout=0.2, 52 | **kwargs): 53 | super(DistilBertConfig, self).__init__(**kwargs) 54 | 55 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 56 | and isinstance(vocab_size_or_config_json_file, unicode)): 57 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 58 | json_config = json.loads(reader.read()) 59 | for key, value in json_config.items(): 60 | self.__dict__[key] = value 61 | elif isinstance(vocab_size_or_config_json_file, int): 62 | self.vocab_size = vocab_size_or_config_json_file 63 | self.max_position_embeddings = max_position_embeddings 64 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 65 | self.n_layers = n_layers 66 | self.n_heads = n_heads 67 | self.dim = dim 68 | self.hidden_dim = hidden_dim 69 | self.dropout = dropout 70 | self.attention_dropout = attention_dropout 71 | self.activation = activation 72 | self.initializer_range = initializer_range 73 | self.tie_weights_ = tie_weights_ 74 | self.qa_dropout = qa_dropout 75 | self.seq_classif_dropout = seq_classif_dropout 76 | else: 77 | raise ValueError("First argument must be either a vocabulary size (int)" 78 | " or the path to a pretrained model config file (str)") 79 | @property 80 | def hidden_size(self): 81 | return self.dim 82 | 83 | @property 84 | def num_attention_heads(self): 85 | return self.n_heads 86 | 87 | @property 88 | def num_hidden_layers(self): 89 | return self.n_layers 90 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT-2 configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"} 32 | 33 | class GPT2Config(PretrainedConfig): 34 | """Configuration class to store the configuration of a `GPT2Model`. 35 | 36 | Args: 37 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 38 | n_positions: Number of positional embeddings. 39 | n_ctx: Size of the causal mask (usually same as n_positions). 40 | n_embd: Dimensionality of the embeddings and hidden states. 41 | n_layer: Number of hidden layers in the Transformer encoder. 42 | n_head: Number of attention heads for each attention layer in 43 | the Transformer encoder. 44 | layer_norm_epsilon: epsilon to use in the layer norm layers 45 | resid_pdrop: The dropout probabilitiy for all fully connected 46 | layers in the embeddings, encoder, and pooler. 47 | attn_pdrop: The dropout ratio for the attention 48 | probabilities. 49 | embd_pdrop: The dropout ratio for the embeddings. 50 | initializer_range: The sttdev of the truncated_normal_initializer for 51 | initializing all weight matrices. 52 | """ 53 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 54 | 55 | def __init__( 56 | self, 57 | vocab_size_or_config_json_file=50257, 58 | n_positions=1024, 59 | n_ctx=1024, 60 | n_embd=768, 61 | n_layer=12, 62 | n_head=12, 63 | resid_pdrop=0.1, 64 | embd_pdrop=0.1, 65 | attn_pdrop=0.1, 66 | layer_norm_epsilon=1e-5, 67 | initializer_range=0.02, 68 | 69 | num_labels=1, 70 | summary_type='cls_index', 71 | summary_use_proj=True, 72 | summary_activation=None, 73 | summary_proj_to_labels=True, 74 | summary_first_dropout=0.1, 75 | **kwargs 76 | ): 77 | """Constructs GPT2Config. 78 | 79 | Args: 80 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 81 | n_positions: Number of positional embeddings. 82 | n_ctx: Size of the causal mask (usually same as n_positions). 83 | n_embd: Dimensionality of the embeddings and hidden states. 84 | n_layer: Number of hidden layers in the Transformer encoder. 85 | n_head: Number of attention heads for each attention layer in 86 | the Transformer encoder. 87 | layer_norm_epsilon: epsilon to use in the layer norm layers 88 | resid_pdrop: The dropout probabilitiy for all fully connected 89 | layers in the embeddings, encoder, and pooler. 90 | attn_pdrop: The dropout ratio for the attention 91 | probabilities. 92 | embd_pdrop: The dropout ratio for the embeddings. 93 | initializer_range: The sttdev of the truncated_normal_initializer for 94 | initializing all weight matrices. 
95 | """ 96 | super(GPT2Config, self).__init__(**kwargs) 97 | 98 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 99 | and isinstance(vocab_size_or_config_json_file, unicode)): 100 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 101 | json_config = json.loads(reader.read()) 102 | for key, value in json_config.items(): 103 | self.__dict__[key] = value 104 | elif isinstance(vocab_size_or_config_json_file, int): 105 | self.vocab_size = vocab_size_or_config_json_file 106 | self.n_ctx = n_ctx 107 | self.n_positions = n_positions 108 | self.n_embd = n_embd 109 | self.n_layer = n_layer 110 | self.n_head = n_head 111 | self.resid_pdrop = resid_pdrop 112 | self.embd_pdrop = embd_pdrop 113 | self.attn_pdrop = attn_pdrop 114 | self.layer_norm_epsilon = layer_norm_epsilon 115 | self.initializer_range = initializer_range 116 | 117 | self.num_labels = num_labels 118 | self.summary_type = summary_type 119 | self.summary_use_proj = summary_use_proj 120 | self.summary_activation = summary_activation 121 | self.summary_first_dropout = summary_first_dropout 122 | self.summary_proj_to_labels = summary_proj_to_labels 123 | else: 124 | raise ValueError( 125 | "First argument must be either a vocabulary size (int)" 126 | "or the path to a pretrained model config file (str)" 127 | ) 128 | 129 | @property 130 | def max_position_embeddings(self): 131 | return self.n_positions 132 | 133 | @property 134 | def hidden_size(self): 135 | return self.n_embd 136 | 137 | @property 138 | def num_attention_heads(self): 139 | return self.n_head 140 | 141 | @property 142 | def num_hidden_layers(self): 143 | return self.n_layer 144 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 31 | } 32 | 33 | class OpenAIGPTConfig(PretrainedConfig): 34 | """ 35 | Configuration class to store the configuration of a `OpenAIGPTModel`. 36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. 39 | n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...) 40 | n_positions: Number of positional embeddings. 
41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | afn: The non-linear activation function (function or string) in the 47 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 48 | resid_pdrop: The dropout probabilitiy for all fully connected 49 | layers in the embeddings, encoder, and pooler. 50 | attn_pdrop: The dropout ratio for the attention 51 | probabilities. 52 | embd_pdrop: The dropout ratio for the embeddings. 53 | layer_norm_epsilon: epsilon to use in the layer norm layers 54 | initializer_range: The sttdev of the truncated_normal_initializer for 55 | initializing all weight matrices. 56 | predict_special_tokens: should we predict special tokens (when the model has a LM head) 57 | """ 58 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 59 | 60 | def __init__( 61 | self, 62 | vocab_size_or_config_json_file=40478, 63 | n_positions=512, 64 | n_ctx=512, 65 | n_embd=768, 66 | n_layer=12, 67 | n_head=12, 68 | afn="gelu", 69 | resid_pdrop=0.1, 70 | embd_pdrop=0.1, 71 | attn_pdrop=0.1, 72 | layer_norm_epsilon=1e-5, 73 | initializer_range=0.02, 74 | predict_special_tokens=True, 75 | 76 | num_labels=1, 77 | summary_type='cls_index', 78 | summary_use_proj=True, 79 | summary_activation=None, 80 | summary_proj_to_labels=True, 81 | summary_first_dropout=0.1, 82 | **kwargs 83 | ): 84 | """Constructs OpenAIGPTConfig. 85 | """ 86 | super(OpenAIGPTConfig, self).__init__(**kwargs) 87 | 88 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 89 | and isinstance(vocab_size_or_config_json_file, unicode)): 90 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 91 | json_config = json.loads(reader.read()) 92 | for key, value in json_config.items(): 93 | self.__dict__[key] = value 94 | elif isinstance(vocab_size_or_config_json_file, int): 95 | self.vocab_size = vocab_size_or_config_json_file 96 | self.n_ctx = n_ctx 97 | self.n_positions = n_positions 98 | self.n_embd = n_embd 99 | self.n_layer = n_layer 100 | self.n_head = n_head 101 | self.afn = afn 102 | self.resid_pdrop = resid_pdrop 103 | self.embd_pdrop = embd_pdrop 104 | self.attn_pdrop = attn_pdrop 105 | self.layer_norm_epsilon = layer_norm_epsilon 106 | self.initializer_range = initializer_range 107 | self.predict_special_tokens = predict_special_tokens 108 | 109 | self.num_labels = num_labels 110 | self.summary_type = summary_type 111 | self.summary_use_proj = summary_use_proj 112 | self.summary_activation = summary_activation 113 | self.summary_first_dropout = summary_first_dropout 114 | self.summary_proj_to_labels = summary_proj_to_labels 115 | else: 116 | raise ValueError( 117 | "First argument must be either a vocabulary size (int)" 118 | "or the path to a pretrained model config file (str)" 119 | ) 120 | 121 | @property 122 | def max_position_embeddings(self): 123 | return self.n_positions 124 | 125 | @property 126 | def hidden_size(self): 127 | return self.n_embd 128 | 129 | @property 130 | def num_attention_heads(self): 131 | return self.n_head 132 | 133 | @property 134 | def num_hidden_layers(self): 135 | return self.n_layer 136 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_roberta.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_bert import BertConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 31 | } 32 | 33 | 34 | class RobertaConfig(BertConfig): 35 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Transformer XL configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", 31 | } 32 | 33 | class TransfoXLConfig(PretrainedConfig): 34 | """Configuration class to store the configuration of a `TransfoXLModel`. 35 | 36 | Args: 37 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. 38 | cutoffs: cutoffs for the adaptive softmax 39 | d_model: Dimensionality of the model's hidden states. 
40 | d_embed: Dimensionality of the embeddings 41 | d_head: Dimensionality of the model's heads. 42 | div_val: divident value for adapative input and softmax 43 | pre_lnorm: apply LayerNorm to the input instead of the output 44 | d_inner: Inner dimension in FF 45 | n_layer: Number of hidden layers in the Transformer encoder. 46 | n_head: Number of attention heads for each attention layer in 47 | the Transformer encoder. 48 | tgt_len: number of tokens to predict 49 | ext_len: length of the extended context 50 | mem_len: length of the retained previous heads 51 | same_length: use the same attn length for all tokens 52 | proj_share_all_but_first: True to share all but first projs, False not to share. 53 | attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 54 | clamp_len: use the same pos embeddings after clamp_len 55 | sample_softmax: number of samples in sampled softmax 56 | adaptive: use adaptive softmax 57 | tie_weight: tie the word embedding and softmax weights 58 | dropout: The dropout probabilitiy for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | dropatt: The dropout ratio for the attention probabilities. 61 | untie_r: untie relative position biases 62 | embd_pdrop: The dropout ratio for the embeddings. 63 | init: parameter initializer to use 64 | init_range: parameters initialized by U(-init_range, init_range). 65 | proj_init_std: parameters initialized by N(0, init_std) 66 | init_std: parameters initialized by N(0, init_std) 67 | """ 68 | pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP 69 | 70 | def __init__(self, 71 | vocab_size_or_config_json_file=267735, 72 | cutoffs=[20000, 40000, 200000], 73 | d_model=1024, 74 | d_embed=1024, 75 | n_head=16, 76 | d_head=64, 77 | d_inner=4096, 78 | div_val=4, 79 | pre_lnorm=False, 80 | n_layer=18, 81 | tgt_len=128, 82 | ext_len=0, 83 | mem_len=1600, 84 | clamp_len=1000, 85 | same_length=True, 86 | proj_share_all_but_first=True, 87 | attn_type=0, 88 | sample_softmax=-1, 89 | adaptive=True, 90 | tie_weight=True, 91 | dropout=0.1, 92 | dropatt=0.0, 93 | untie_r=True, 94 | init="normal", 95 | init_range=0.01, 96 | proj_init_std=0.01, 97 | init_std=0.02, 98 | **kwargs): 99 | """Constructs TransfoXLConfig. 
100 | """ 101 | super(TransfoXLConfig, self).__init__(**kwargs) 102 | 103 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 104 | and isinstance(vocab_size_or_config_json_file, unicode)): 105 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 106 | json_config = json.loads(reader.read()) 107 | for key, value in json_config.items(): 108 | self.__dict__[key] = value 109 | elif isinstance(vocab_size_or_config_json_file, int): 110 | self.n_token = vocab_size_or_config_json_file 111 | self.cutoffs = [] 112 | self.cutoffs.extend(cutoffs) 113 | self.tie_weight = tie_weight 114 | if proj_share_all_but_first: 115 | self.tie_projs = [False] + [True] * len(self.cutoffs) 116 | else: 117 | self.tie_projs = [False] + [False] * len(self.cutoffs) 118 | self.d_model = d_model 119 | self.d_embed = d_embed 120 | self.d_head = d_head 121 | self.d_inner = d_inner 122 | self.div_val = div_val 123 | self.pre_lnorm = pre_lnorm 124 | self.n_layer = n_layer 125 | self.n_head = n_head 126 | self.tgt_len = tgt_len 127 | self.ext_len = ext_len 128 | self.mem_len = mem_len 129 | self.same_length = same_length 130 | self.attn_type = attn_type 131 | self.clamp_len = clamp_len 132 | self.sample_softmax = sample_softmax 133 | self.adaptive = adaptive 134 | self.dropout = dropout 135 | self.dropatt = dropatt 136 | self.untie_r = untie_r 137 | self.init = init 138 | self.init_range = init_range 139 | self.proj_init_std = proj_init_std 140 | self.init_std = init_std 141 | else: 142 | raise ValueError("First argument must be either a vocabulary size (int)" 143 | " or the path to a pretrained model config file (str)") 144 | 145 | @property 146 | def max_position_embeddings(self): 147 | return self.tgt_len + self.ext_len + self.mem_len 148 | 149 | @property 150 | def vocab_size(self): 151 | return self.n_token 152 | 153 | @vocab_size.setter 154 | def vocab_size(self, value): 155 | self.n_token = value 156 | 157 | @property 158 | def hidden_size(self): 159 | return self.d_model 160 | 161 | @property 162 | def num_attention_heads(self): 163 | return self.n_head 164 | 165 | @property 166 | def num_hidden_layers(self): 167 | return self.n_layer 168 | -------------------------------------------------------------------------------- /pytorch_transformers/configuration_xlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ XLM configuration """ 16 | from __future__ import absolute_import, division, print_function, unicode_literals 17 | 18 | import json 19 | import logging 20 | import sys 21 | from io import open 22 | 23 | from .configuration_utils import PretrainedConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", 29 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json", 30 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json", 31 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json", 32 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json", 33 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", 34 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", 35 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", 36 | 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", 37 | 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", 38 | } 39 | 40 | 41 | class XLMConfig(PretrainedConfig): 42 | """Configuration class to store the configuration of a `XLMModel`. 43 | 44 | Args: 45 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`. 46 | d_model: Size of the encoder layers and the pooler layer. 47 | n_layer: Number of hidden layers in the Transformer encoder. 48 | n_head: Number of attention heads for each attention layer in 49 | the Transformer encoder. 50 | d_inner: The size of the "intermediate" (i.e., feed-forward) 51 | layer in the Transformer encoder. 52 | ff_activation: The non-linear activation function (function or string) in the 53 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 54 | untie_r: untie relative position biases 55 | attn_type: 'bi' for XLM, 'uni' for Transformer-XL 56 | 57 | dropout: The dropout probabilitiy for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | dropatt: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | initializer_range: The sttdev of the truncated_normal_initializer for 65 | initializing all weight matrices. 66 | layer_norm_eps: The epsilon used by LayerNorm. 67 | 68 | dropout: float, dropout rate. 69 | dropatt: float, dropout rate on attention probabilities. 70 | init: str, the initialization scheme, either "normal" or "uniform". 71 | init_range: float, initialize the parameters with a uniform distribution 72 | in [-init_range, init_range]. Only effective when init="uniform". 73 | init_std: float, initialize the parameters with a normal distribution 74 | with mean 0 and stddev init_std. Only effective when init="normal". 75 | mem_len: int, the number of tokens to cache. 76 | reuse_len: int, the number of tokens in the currect batch to be cached 77 | and reused in the future. 78 | bi_data: bool, whether to use bidirectional input pipeline. 
79 | Usually set to True during pretraining and False during finetuning. 80 | clamp_len: int, clamp all relative distances larger than clamp_len. 81 | -1 means no clamping. 82 | same_length: bool, whether to use the same attention length for each token. 83 | """ 84 | pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP 85 | 86 | def __init__(self, 87 | vocab_size_or_config_json_file=30145, 88 | emb_dim=2048, 89 | n_layers=12, 90 | n_heads=16, 91 | dropout=0.1, 92 | attention_dropout=0.1, 93 | gelu_activation=True, 94 | sinusoidal_embeddings=False, 95 | causal=False, 96 | asm=False, 97 | n_langs=1, 98 | use_lang_emb=True, 99 | max_position_embeddings=512, 100 | embed_init_std=2048 ** -0.5, 101 | layer_norm_eps=1e-12, 102 | init_std=0.02, 103 | bos_index=0, 104 | eos_index=1, 105 | pad_index=2, 106 | unk_index=3, 107 | mask_index=5, 108 | is_encoder=True, 109 | 110 | finetuning_task=None, 111 | num_labels=2, 112 | summary_type='first', 113 | summary_use_proj=True, 114 | summary_activation=None, 115 | summary_proj_to_labels=True, 116 | summary_first_dropout=0.1, 117 | start_n_top=5, 118 | end_n_top=5, 119 | **kwargs): 120 | """Constructs XLMConfig. 121 | """ 122 | super(XLMConfig, self).__init__(**kwargs) 123 | 124 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 125 | and isinstance(vocab_size_or_config_json_file, unicode)): 126 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 127 | json_config = json.loads(reader.read()) 128 | for key, value in json_config.items(): 129 | self.__dict__[key] = value 130 | elif isinstance(vocab_size_or_config_json_file, int): 131 | self.n_words = vocab_size_or_config_json_file 132 | self.emb_dim = emb_dim 133 | self.n_layers = n_layers 134 | self.n_heads = n_heads 135 | self.dropout = dropout 136 | self.attention_dropout = attention_dropout 137 | self.gelu_activation = gelu_activation 138 | self.sinusoidal_embeddings = sinusoidal_embeddings 139 | self.causal = causal 140 | self.asm = asm 141 | self.n_langs = n_langs 142 | self.use_lang_emb = use_lang_emb 143 | self.layer_norm_eps = layer_norm_eps 144 | self.bos_index = bos_index 145 | self.eos_index = eos_index 146 | self.pad_index = pad_index 147 | self.unk_index = unk_index 148 | self.mask_index = mask_index 149 | self.is_encoder = is_encoder 150 | self.max_position_embeddings = max_position_embeddings 151 | self.embed_init_std = embed_init_std 152 | self.init_std = init_std 153 | self.finetuning_task = finetuning_task 154 | self.num_labels = num_labels 155 | self.summary_type = summary_type 156 | self.summary_use_proj = summary_use_proj 157 | self.summary_activation = summary_activation 158 | self.summary_proj_to_labels = summary_proj_to_labels 159 | self.summary_first_dropout = summary_first_dropout 160 | self.start_n_top = start_n_top 161 | self.end_n_top = end_n_top 162 | else: 163 | raise ValueError("First argument must be either a vocabulary size (int)" 164 | " or the path to a pretrained model config file (str)") 165 | 166 | @property 167 | def vocab_size(self): 168 | return self.n_words 169 | 170 | @vocab_size.setter 171 | def vocab_size(self, value): 172 | self.n_words = value 173 | 174 | @property 175 | def hidden_size(self): 176 | return self.emb_dim 177 | 178 | @property 179 | def num_attention_heads(self): 180 | return self.n_heads 181 | 182 | @property 183 | def num_hidden_layers(self): 184 | return self.n_layers 185 | -------------------------------------------------------------------------------- 
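Note on the configuration classes above: each one keeps its model-specific attribute names (for XLM these are `n_words`, `emb_dim`, `n_layers` and `n_heads`) and exposes the shared `vocab_size` / `hidden_size` / `num_attention_heads` / `num_hidden_layers` interface through the properties defined at the end of the file. A minimal usage sketch for `XLMConfig` (the numeric values are arbitrary; it only assumes this repository's `pytorch_transformers` package is importable):

```python
from pytorch_transformers.configuration_xlm import XLMConfig

# An int first argument is treated as the vocabulary size;
# a string would be read as the path to a JSON config file.
config = XLMConfig(vocab_size_or_config_json_file=30145, emb_dim=1024, n_layers=6, n_heads=8)

# The generic properties alias the XLM-specific attributes.
assert config.vocab_size == config.n_words == 30145
assert config.hidden_size == config.emb_dim == 1024
assert config.num_hidden_layers == config.n_layers == 6
assert config.num_attention_heads == config.n_heads == 8

# vocab_size also has a setter, so assigning through the alias updates n_words.
config.vocab_size = 32000
assert config.n_words == 32000
```

The Transformer-XL configuration above and the XLNet configuration below follow the same pattern with their own attribute names (`n_token`, `d_model`, `n_layer`, `n_head`).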
/pytorch_transformers/configuration_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ XLNet configuration """ 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", 30 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", 31 | } 32 | 33 | 34 | class XLNetConfig(PretrainedConfig): 35 | """Configuration class to store the configuration of an ``XLNetModel``. 36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``. 39 | d_model: Size of the encoder layers and the pooler layer. 40 | n_layer: Number of hidden layers in the Transformer encoder. 41 | n_head: Number of attention heads for each attention layer in 42 | the Transformer encoder. 43 | d_inner: The size of the "intermediate" (i.e., feed-forward) 44 | layer in the Transformer encoder. 45 | ff_activation: The non-linear activation function (function or string) in the 46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 47 | untie_r: untie relative position biases 48 | attn_type: 'bi' for XLNet, 'uni' for Transformer-XL 49 | 50 | dropout: The dropout probability for all fully connected 51 | layers in the embeddings, encoder, and pooler. 52 | dropatt: The dropout ratio for the attention 53 | probabilities. 54 | initializer_range: The stddev of the truncated_normal_initializer for 55 | initializing all weight matrices. 56 | layer_norm_eps: The epsilon used by LayerNorm. 57 | 58 | dropout: float, dropout rate. 59 | dropatt: float, dropout rate on attention probabilities. 60 | init: str, the initialization scheme, either "normal" or "uniform". 61 | init_range: float, initialize the parameters with a uniform distribution 62 | in [-init_range, init_range]. Only effective when init="uniform". 63 | init_std: float, initialize the parameters with a normal distribution 64 | with mean 0 and stddev init_std. Only effective when init="normal". 65 | mem_len: int, the number of tokens to cache. 66 | reuse_len: int, the number of tokens in the current batch to be cached 67 | and reused in the future. 68 | bi_data: bool, whether to use bidirectional input pipeline. 69 | Usually set to True during pretraining and False during finetuning. 70 | clamp_len: int, clamp all relative distances larger than clamp_len.
71 | -1 means no clamping. 72 | same_length: bool, whether to use the same attention length for each token. 73 | finetuning_task: name of the glue task on which the model was fine-tuned if any 74 | """ 75 | pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP 76 | 77 | def __init__(self, 78 | vocab_size_or_config_json_file=32000, 79 | d_model=1024, 80 | n_layer=24, 81 | n_head=16, 82 | d_inner=4096, 83 | ff_activation="gelu", 84 | untie_r=True, 85 | attn_type="bi", 86 | 87 | initializer_range=0.02, 88 | layer_norm_eps=1e-12, 89 | 90 | dropout=0.1, 91 | mem_len=None, 92 | reuse_len=None, 93 | bi_data=False, 94 | clamp_len=-1, 95 | same_length=False, 96 | 97 | finetuning_task=None, 98 | num_labels=2, 99 | summary_type='last', 100 | summary_use_proj=True, 101 | summary_activation='tanh', 102 | summary_last_dropout=0.1, 103 | start_n_top=5, 104 | end_n_top=5, 105 | **kwargs): 106 | """Constructs XLNetConfig. 107 | """ 108 | super(XLNetConfig, self).__init__(**kwargs) 109 | 110 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 111 | and isinstance(vocab_size_or_config_json_file, unicode)): 112 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 113 | json_config = json.loads(reader.read()) 114 | for key, value in json_config.items(): 115 | self.__dict__[key] = value 116 | elif isinstance(vocab_size_or_config_json_file, int): 117 | self.n_token = vocab_size_or_config_json_file 118 | self.d_model = d_model 119 | self.n_layer = n_layer 120 | self.n_head = n_head 121 | assert d_model % n_head == 0 122 | self.d_head = d_model // n_head 123 | self.ff_activation = ff_activation 124 | self.d_inner = d_inner 125 | self.untie_r = untie_r 126 | self.attn_type = attn_type 127 | 128 | self.initializer_range = initializer_range 129 | self.layer_norm_eps = layer_norm_eps 130 | 131 | self.dropout = dropout 132 | self.mem_len = mem_len 133 | self.reuse_len = reuse_len 134 | self.bi_data = bi_data 135 | self.clamp_len = clamp_len 136 | self.same_length = same_length 137 | 138 | self.finetuning_task = finetuning_task 139 | self.num_labels = num_labels 140 | self.summary_type = summary_type 141 | self.summary_use_proj = summary_use_proj 142 | self.summary_activation = summary_activation 143 | self.summary_last_dropout = summary_last_dropout 144 | self.start_n_top = start_n_top 145 | self.end_n_top = end_n_top 146 | else: 147 | raise ValueError("First argument must be either a vocabulary size (int)" 148 | " or the path to a pretrained model config file (str)") 149 | 150 | @property 151 | def max_position_embeddings(self): 152 | return -1 153 | 154 | @property 155 | def vocab_size(self): 156 | return self.n_token 157 | 158 | @vocab_size.setter 159 | def vocab_size(self, value): 160 | self.n_token = value 161 | 162 | @property 163 | def hidden_size(self): 164 | return self.d_model 165 | 166 | @property 167 | def num_attention_heads(self): 168 | return self.n_head 169 | 170 | @property 171 | def num_hidden_layers(self): 172 | return self.n_layer 173 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_pytorch_checkpoint_to_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from pytorch_transformers import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert RoBERTa checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import logging 21 | import numpy as np 22 | import torch 23 | 24 | from fairseq.models.roberta import RobertaModel as FairseqRobertaModel 25 | from fairseq.modules import TransformerSentenceEncoderLayer 26 | from pytorch_transformers import (BertConfig, BertEncoder, 27 | BertIntermediate, BertLayer, 28 | BertModel, BertOutput, 29 | BertSelfAttention, 30 | BertSelfOutput) 31 | from pytorch_transformers import (RobertaEmbeddings, 32 | RobertaForMaskedLM, 33 | RobertaForSequenceClassification, 34 | RobertaModel) 35 | 36 | logging.basicConfig(level=logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | SAMPLE_TEXT = 'Hello world! cécé herlolip' 40 | 41 | 42 | def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head): 43 | """ 44 | Copy/paste/tweak roberta's weights to our BERT structure. 
45 | """ 46 | roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path) 47 | roberta.eval() # disable dropout 48 | config = BertConfig( 49 | vocab_size_or_config_json_file=50265, 50 | hidden_size=roberta.args.encoder_embed_dim, 51 | num_hidden_layers=roberta.args.encoder_layers, 52 | num_attention_heads=roberta.args.encoder_attention_heads, 53 | intermediate_size=roberta.args.encoder_ffn_embed_dim, 54 | max_position_embeddings=514, 55 | type_vocab_size=1, 56 | layer_norm_eps=1e-5, # PyTorch default used in fairseq 57 | ) 58 | if classification_head: 59 | config.num_labels = roberta.args.num_classes 60 | print("Our BERT config:", config) 61 | 62 | model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config) 63 | model.eval() 64 | 65 | # Now let's copy all the weights. 66 | # Embeddings 67 | roberta_sent_encoder = roberta.model.decoder.sentence_encoder 68 | model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight 69 | model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight 70 | model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them. 71 | model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight 72 | model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias 73 | 74 | for i in range(config.num_hidden_layers): 75 | # Encoder: start of layer 76 | layer: BertLayer = model.roberta.encoder.layer[i] 77 | roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i] 78 | 79 | ### self attention 80 | self_attn: BertSelfAttention = layer.attention.self 81 | assert( 82 | roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size)) 83 | ) 84 | # we use three distinct linear layers so we split the source layer here. 
85 | self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :] 86 | self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size] 87 | self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :] 88 | self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size] 89 | self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :] 90 | self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:] 91 | 92 | ### self-attention output 93 | self_output: BertSelfOutput = layer.attention.output 94 | assert( 95 | self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape 96 | ) 97 | self_output.dense.weight = roberta_layer.self_attn.out_proj.weight 98 | self_output.dense.bias = roberta_layer.self_attn.out_proj.bias 99 | self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight 100 | self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias 101 | 102 | ### intermediate 103 | intermediate: BertIntermediate = layer.intermediate 104 | assert( 105 | intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape 106 | ) 107 | intermediate.dense.weight = roberta_layer.fc1.weight 108 | intermediate.dense.bias = roberta_layer.fc1.bias 109 | 110 | ### output 111 | bert_output: BertOutput = layer.output 112 | assert( 113 | bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape 114 | ) 115 | bert_output.dense.weight = roberta_layer.fc2.weight 116 | bert_output.dense.bias = roberta_layer.fc2.bias 117 | bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight 118 | bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias 119 | #### end of layer 120 | 121 | if classification_head: 122 | model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight 123 | model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias 124 | model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight 125 | model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias 126 | else: 127 | # LM Head 128 | model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight 129 | model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias 130 | model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight 131 | model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias 132 | model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight 133 | model.lm_head.bias = roberta.model.decoder.lm_head.bias 134 | 135 | # Let's check that we get the same results. 
136 | input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 137 | 138 | our_output = model(input_ids)[0] 139 | if classification_head: 140 | their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids)) 141 | else: 142 | their_output = roberta.model(input_ids)[0] 143 | print(our_output.shape, their_output.shape) 144 | max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() 145 | print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7 146 | success = torch.allclose(our_output, their_output, atol=1e-3) 147 | print( 148 | "Do both models output the same tensors?", 149 | "🔥" if success else "💩" 150 | ) 151 | if not success: 152 | raise Exception("Something went wRoNg") 153 | 154 | print(f"Saving model to {pytorch_dump_folder_path}") 155 | model.save_pretrained(pytorch_dump_folder_path) 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | ## Required parameters 161 | parser.add_argument("--roberta_checkpoint_path", 162 | default = None, 163 | type = str, 164 | required = True, 165 | help = "Path the official PyTorch dump.") 166 | parser.add_argument("--pytorch_dump_folder_path", 167 | default = None, 168 | type = str, 169 | required = True, 170 | help = "Path to the output PyTorch model.") 171 | parser.add_argument("--classification_head", 172 | action = "store_true", 173 | help = "Whether to convert a final classification head.") 174 | args = parser.parse_args() 175 | convert_roberta_checkpoint_to_pytorch( 176 | args.roberta_checkpoint_path, 177 | args.pytorch_dump_folder_path, 178 | args.classification_head 179 | ) 180 | 181 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | 28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 | help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = parser.parse_args() 75 | 
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = 
"The config json file corresponding to the pre-trained XLNet model. \n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /pytorch_transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 
52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | 107 | class AdamW(Optimizer): 108 | """ Implements Adam algorithm with weight decay fix. 109 | 110 | Parameters: 111 | lr (float): learning rate. Default 1e-3. 112 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 113 | eps (float): Adams epsilon. Default: 1e-6 114 | weight_decay (float): Weight decay. Default: 0.0 115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
116 | """ 117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 118 | if lr < 0.0: 119 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 120 | if not 0.0 <= betas[0] < 1.0: 121 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 122 | if not 0.0 <= betas[1] < 1.0: 123 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 124 | if not 0.0 <= eps: 125 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 126 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 127 | correct_bias=correct_bias) 128 | super(AdamW, self).__init__(params, defaults) 129 | 130 | def step(self, closure=None): 131 | """Performs a single optimization step. 132 | 133 | Arguments: 134 | closure (callable, optional): A closure that reevaluates the model 135 | and returns the loss. 136 | """ 137 | loss = None 138 | if closure is not None: 139 | loss = closure() 140 | 141 | for group in self.param_groups: 142 | for p in group['params']: 143 | if p.grad is None: 144 | continue 145 | grad = p.grad.data 146 | if grad.is_sparse: 147 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 148 | 149 | state = self.state[p] 150 | 151 | # State initialization 152 | if len(state) == 0: 153 | state['step'] = 0 154 | # Exponential moving average of gradient values 155 | state['exp_avg'] = torch.zeros_like(p.data) 156 | # Exponential moving average of squared gradient values 157 | state['exp_avg_sq'] = torch.zeros_like(p.data) 158 | 159 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 160 | beta1, beta2 = group['betas'] 161 | 162 | state['step'] += 1 163 | 164 | # Decay the first and second moment running average coefficient 165 | # In-place operations to update the averages at the same time 166 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 167 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 168 | denom = exp_avg_sq.sqrt().add_(group['eps']) 169 | 170 | step_size = group['lr'] 171 | if group['correct_bias']: # No bias correction for Bert 172 | bias_correction1 = 1.0 - beta1 ** state['step'] 173 | bias_correction2 = 1.0 - beta2 ** state['step'] 174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 175 | 176 | p.data.addcdiv_(-step_size, exp_avg, denom) 177 | 178 | # Just adding the square of the weights to the loss function is *not* 179 | # the correct way of using L2 regularization/weight decay with Adam, 180 | # since that will interact with the m and v parameters in strange ways. 181 | # 182 | # Instead we want to decay the weights in a manner that doesn't interact 183 | # with the m/v parameters. This is equivalent to adding the square 184 | # of the weights to the loss with plain (non-momentum) SGD. 
185 | # Add weight decay at the end (fixed version) 186 | if group['weight_decay'] > 0.0: 187 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/pytorch_transformers/tests/__init__.py -------------------------------------------------------------------------------- /pytorch_transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 
9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 
24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
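The `conftest.py` shown above registers a `--runslow` command-line flag and, when the flag is absent, attaches a skip marker to every collected test whose keywords include `slow`. A minimal sketch of how a test module opts into that mechanism follows; the test names are illustrative and are not part of this repository.

```
import pytest


@pytest.mark.slow
def test_download_pretrained_weights():
    # Always collected, but skipped unless pytest is invoked with --runslow,
    # because pytest_collection_modifyitems adds a skip marker to "slow" tests.
    assert True


def test_fast_unit():
    # No "slow" keyword, so the conftest hook leaves it alone and it always runs.
    assert True
```

Running `pytest pytorch_transformers/tests --runslow` executes both tests; without the flag the first one is reported as skipped.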
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import (AutoConfig, BertConfig, 25 | AutoModel, BertModel, 26 | AutoModelWithLMHead, BertForMaskedLM, 27 | AutoModelForSequenceClassification, BertForSequenceClassification, 28 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 29 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 30 | 31 | from .modeling_common_test import (CommonTestCases, ids_tensor) 32 | from .configuration_common_test import ConfigTester 33 | 34 | 35 | class AutoModelTest(unittest.TestCase): 36 | def test_model_from_pretrained(self): 37 | logging.basicConfig(level=logging.INFO) 38 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | config = AutoConfig.from_pretrained(model_name) 40 | self.assertIsNotNone(config) 41 | self.assertIsInstance(config, BertConfig) 42 | 43 | model = AutoModel.from_pretrained(model_name) 44 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 45 | self.assertIsNotNone(model) 46 | self.assertIsInstance(model, BertModel) 47 | for value in loading_info.values(): 48 | self.assertEqual(len(value), 0) 49 | 50 | def test_lmhead_model_from_pretrained(self): 51 | logging.basicConfig(level=logging.INFO) 52 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 53 | config = AutoConfig.from_pretrained(model_name) 54 | self.assertIsNotNone(config) 55 | self.assertIsInstance(config, BertConfig) 56 | 57 | model = AutoModelWithLMHead.from_pretrained(model_name) 58 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 59 | self.assertIsNotNone(model) 60 | self.assertIsInstance(model, BertForMaskedLM) 61 | 62 | def test_sequence_classification_model_from_pretrained(self): 63 | logging.basicConfig(level=logging.INFO) 64 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 65 | config = AutoConfig.from_pretrained(model_name) 66 | self.assertIsNotNone(config) 67 | self.assertIsInstance(config, BertConfig) 68 | 69 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 70 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 71 | self.assertIsNotNone(model) 72 | self.assertIsInstance(model, BertForSequenceClassification) 73 | 74 | def test_question_answering_model_from_pretrained(self): 75 | logging.basicConfig(level=logging.INFO) 76 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 77 | config = AutoConfig.from_pretrained(model_name) 78 | self.assertIsNotNone(config) 79 | self.assertIsInstance(config, BertConfig) 80 | 81 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 82 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 83 | self.assertIsNotNone(model) 84 | self.assertIsInstance(model, BertForQuestionAnswering) 85 | 86 | 87 | if __name__ == "__main__": 88 | unittest.main() 89 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 25 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 26 | 27 | from .tokenization_tests_commons import TemporaryDirectory 28 | 29 | 30 | def unwrap_schedule(scheduler, num_steps=10): 31 | lrs = [] 32 | for _ in range(num_steps): 33 | scheduler.step() 34 | lrs.append(scheduler.get_lr()) 35 | return lrs 36 | 37 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 38 | lrs = [] 39 | for step in range(num_steps): 40 | scheduler.step() 41 | lrs.append(scheduler.get_lr()) 42 | if step == num_steps // 2: 43 | with TemporaryDirectory() as tmpdirname: 44 | file_name = os.path.join(tmpdirname, 'schedule.bin') 45 | torch.save(scheduler.state_dict(), file_name) 46 | 47 | state_dict = torch.load(file_name) 48 | scheduler.load_state_dict(state_dict) 49 | return lrs 50 | 51 | class OptimizationTest(unittest.TestCase): 52 | 53 | def assertListAlmostEqual(self, list1, list2, tol): 54 | self.assertEqual(len(list1), len(list2)) 55 | for a, b in zip(list1, list2): 56 | self.assertAlmostEqual(a, b, delta=tol) 57 | 58 | def test_adam_w(self): 59 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 60 | target = torch.tensor([0.4, 0.2, -0.5]) 61 | criterion = torch.nn.MSELoss() 62 | # No warmup, constant schedule, no gradient clipping 63 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 64 | for _ in range(100): 65 | loss = criterion(w, target) 66 | loss.backward() 67 | optimizer.step() 68 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 69 | w.grad.zero_() 70 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 71 | 72 | 73 | class ScheduleInitTest(unittest.TestCase): 74 | m = torch.nn.Linear(50, 50) 75 | optimizer = AdamW(m.parameters(), lr=10.) 76 | num_steps = 10 77 | 78 | def assertListAlmostEqual(self, list1, list2, tol): 79 | self.assertEqual(len(list1), len(list2)) 80 | for a, b in zip(list1, list2): 81 | self.assertAlmostEqual(a, b, delta=tol) 82 | 83 | def test_constant_scheduler(self): 84 | scheduler = ConstantLRSchedule(self.optimizer) 85 | lrs = unwrap_schedule(scheduler, self.num_steps) 86 | expected_learning_rates = [10.] 
* self.num_steps 87 | self.assertEqual(len(lrs[0]), 1) 88 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 89 | 90 | scheduler = ConstantLRSchedule(self.optimizer) 91 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 92 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 93 | 94 | def test_warmup_constant_scheduler(self): 95 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 96 | lrs = unwrap_schedule(scheduler, self.num_steps) 97 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 98 | self.assertEqual(len(lrs[0]), 1) 99 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 100 | 101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 102 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 103 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 104 | 105 | def test_warmup_linear_scheduler(self): 106 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 107 | lrs = unwrap_schedule(scheduler, self.num_steps) 108 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 109 | self.assertEqual(len(lrs[0]), 1) 110 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 111 | 112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 113 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 114 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 115 | 116 | def test_warmup_cosine_scheduler(self): 117 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 118 | lrs = unwrap_schedule(scheduler, self.num_steps) 119 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 120 | self.assertEqual(len(lrs[0]), 1) 121 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 122 | 123 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 124 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 125 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 126 | 127 | def test_warmup_cosine_hard_restart_scheduler(self): 128 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 129 | lrs = unwrap_schedule(scheduler, self.num_steps) 130 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 131 | self.assertEqual(len(lrs[0]), 1) 132 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 133 | 134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 135 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 136 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
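The optimization test above fixes the expected behaviour of `AdamW` and the warmup schedules: `WarmupLinearSchedule(optimizer, warmup_steps=2, t_total=10)` ramps the learning rate up during warmup and then decays it linearly to zero. A minimal training-loop sketch using those two classes is given below; the model, batch and hyper-parameters are placeholders rather than values used anywhere in this repository.

```
import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(10, 2)               # stand-in for a BERT-based classifier
num_training_steps = 100
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10,
                                 t_total=num_training_steps)

for step in range(num_training_steps):
    inputs = torch.randn(4, 10)              # dummy batch
    labels = torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()                         # advance the LR schedule once per step
    optimizer.zero_grad()
```

Note that `AdamW` applies decoupled weight decay (the decay term is added directly to the parameters after the Adam update, as in the optimizer code at the top of this section) rather than folding it into the gradients.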
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | 29 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 30 | 31 | tokenizer_class = BertTokenizer 32 | 33 | def setUp(self): 34 | super(BertTokenizationTest, self).setUp() 35 | 36 | vocab_tokens = [ 37 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 38 | "##ing", ",", "low", "lowest", 39 | ] 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 42 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 43 | 44 | def get_tokenizer(self, **kwargs): 45 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 46 | 47 | def get_input_output_texts(self): 48 | input_text = u"UNwant\u00E9d,running" 49 | output_text = u"unwanted, running" 50 | return input_text, output_text 51 | 52 | def test_full_tokenizer(self): 53 | tokenizer = self.tokenizer_class(self.vocab_file) 54 | 55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 57 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 58 | 59 | def test_chinese(self): 60 | tokenizer = BasicTokenizer() 61 | 62 | self.assertListEqual( 63 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 64 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 65 | 66 | def test_basic_tokenizer_lower(self): 67 | tokenizer = BasicTokenizer(do_lower_case=True) 68 | 69 | self.assertListEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 71 | ["hello", "!", "how", "are", "you", "?"]) 72 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 73 | 74 | def test_basic_tokenizer_no_lower(self): 75 | tokenizer = BasicTokenizer(do_lower_case=False) 76 | 77 | self.assertListEqual( 78 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 79 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 80 | 81 | def test_wordpiece_tokenizer(self): 82 | vocab_tokens = [ 83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 84 | "##ing" 85 | ] 86 | 87 | vocab = {} 88 | for (i, token) in enumerate(vocab_tokens): 89 | vocab[token] = i 90 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 91 | 92 | self.assertListEqual(tokenizer.tokenize(""), []) 93 | 94 | self.assertListEqual( 95 | tokenizer.tokenize("unwanted running"), 96 | ["un", "##want", "##ed", "runn", "##ing"]) 97 | 98 | self.assertListEqual( 99 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 100 | 101 | def test_is_whitespace(self): 102 | self.assertTrue(_is_whitespace(u" ")) 103 | self.assertTrue(_is_whitespace(u"\t")) 104 | self.assertTrue(_is_whitespace(u"\r")) 105 | self.assertTrue(_is_whitespace(u"\n")) 106 | self.assertTrue(_is_whitespace(u"\u00A0")) 107 | 108 | self.assertFalse(_is_whitespace(u"A")) 109 | self.assertFalse(_is_whitespace(u"-")) 110 | 111 | def test_is_control(self): 112 | self.assertTrue(_is_control(u"\u0005")) 113 | 114 | self.assertFalse(_is_control(u"A")) 115 | self.assertFalse(_is_control(u" ")) 116 | self.assertFalse(_is_control(u"\t")) 117 | self.assertFalse(_is_control(u"\r")) 118 | 119 | def test_is_punctuation(self): 120 | self.assertTrue(_is_punctuation(u"-")) 121 | self.assertTrue(_is_punctuation(u"$")) 122 | self.assertTrue(_is_punctuation(u"`")) 123 | self.assertTrue(_is_punctuation(u".")) 124 | 125 | self.assertFalse(_is_punctuation(u"A")) 126 | self.assertFalse(_is_punctuation(u" ")) 127 | 128 | def test_sequence_builders(self): 129 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 130 | 131 | text = tokenizer.encode("sequence builders") 132 | text_2 = tokenizer.encode("multi-sequence build") 133 | 134 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 135 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 136 | 137 | assert encoded_sentence == [101] + text + [102] 138 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 139 | 140 | if __name__ == '__main__': 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_dilbert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | from .tokenization_bert_test import BertTokenizationTest 25 | 26 | class DistilBertTokenizationTest(BertTokenizationTest): 27 | 28 | tokenizer_class = DistilBertTokenizer 29 | 30 | def get_tokenizer(self, **kwargs): 31 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 32 | 33 | def test_sequence_builders(self): 34 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 35 | 36 | text = tokenizer.encode("sequence builders") 37 | text_2 = tokenizer.encode("multi-sequence build") 38 | 39 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 40 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 41 | 42 | assert encoded_sentence == [101] + text + [102] 43 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | from io import open 21 | 22 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = GPT2Tokenizer 29 | 30 | def setUp(self): 31 | super(GPT2TokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": ""} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u" lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w", "r", "t", 36 | "lo", "low", "er", 37 | "low", "lowest", "newer", "wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [""] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_roberta_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | from io import open 21 | 22 | from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | tokenizer_class = RobertaTokenizer 28 | 29 | def setUp(self): 30 | super(RobertaTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "\u0120", "\u0120l", "\u0120n", 35 | "\u0120lo", "\u0120low", "er", 36 | "\u0120lowest", "\u0120newer", "\u0120wider", ""] 37 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 38 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 39 | self.special_tokens_map = {"unk_token": ""} 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 44 | fp.write(json.dumps(vocab_tokens) + "\n") 45 | with open(self.merges_file, "w", encoding="utf-8") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | kwargs.update(self.special_tokens_map) 50 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) 51 | 52 | def get_input_output_texts(self): 53 | input_text = u"lower newer" 54 | output_text = u" lower newer" 55 | return input_text, output_text 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 59 | text = "lower newer" 60 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 61 | tokens = tokenizer.tokenize(text) 62 | self.assertListEqual(tokens, bpe_tokens) 63 | 64 | input_tokens = tokens + [tokenizer.unk_token] 65 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 66 | self.assertListEqual( 67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 68 | 69 | def roberta_dict_integration_testing(self): 70 | tokenizer = self.get_tokenizer() 71 | 72 | self.assertListEqual( 73 | tokenizer.encode('Hello world!'), 74 | [0, 31414, 232, 328, 2] 75 | ) 76 | self.assertListEqual( 77 | tokenizer.encode('Hello world! cécé herlolip 418'), 78 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 79 | ) 80 | 81 | def test_sequence_builders(self): 82 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 83 | 84 | text = tokenizer.encode("sequence builders") 85 | text_2 = tokenizer.encode("multi-sequence build") 86 | 87 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 88 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 89 | 90 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 91 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 92 | 93 | assert encoded_sentence == encoded_text_from_decode 94 | assert encoded_pair == encoded_pair_from_decode 95 | 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_tests_commons.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
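The GPT-2, OpenAI GPT and RoBERTa tests above all construct a tiny byte-pair-encoding tokenizer from two files: a JSON vocabulary mapping token strings to ids, and a `merges.txt` listing BPE merge rules in priority order. The sketch below reproduces that fixture setup outside the test harness so the file formats are easier to see; the contents mirror the toy fixtures, and the `<unk>` unknown-token string is an assumed placeholder rather than something taken from this listing.

```
import json
import os
import tempfile

from pytorch_transformers import GPT2Tokenizer

vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
         "\u0120", "\u0120l", "\u0120n", "\u0120lo", "\u0120low", "er",
         "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]

tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.json")
merges_file = os.path.join(tmpdir, "merges.txt")
with open(vocab_file, "w", encoding="utf-8") as fp:
    json.dump(dict(zip(vocab, range(len(vocab)))), fp)
with open(merges_file, "w", encoding="utf-8") as fp:
    fp.write("\n".join(merges))

tokenizer = GPT2Tokenizer(vocab_file, merges_file, unk_token="<unk>")
# Per the test above, "lower newer" is expected to come out as
# ['\u0120low', 'er', '\u0120', 'n', 'e', 'w', 'er']: only the merges listed
# in merges.txt are applied, everything else stays as single symbols.
print(tokenizer.tokenize("lower newer"))
```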
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import sys 19 | from io import open 20 | import tempfile 21 | import shutil 22 | import unittest 23 | 24 | if sys.version_info[0] == 2: 25 | import cPickle as pickle 26 | 27 | class TemporaryDirectory(object): 28 | """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" 29 | def __enter__(self): 30 | self.name = tempfile.mkdtemp() 31 | return self.name 32 | def __exit__(self, exc_type, exc_value, traceback): 33 | shutil.rmtree(self.name) 34 | else: 35 | import pickle 36 | TemporaryDirectory = tempfile.TemporaryDirectory 37 | unicode = str 38 | 39 | 40 | class CommonTestCases: 41 | 42 | class CommonTokenizerTester(unittest.TestCase): 43 | 44 | tokenizer_class = None 45 | 46 | def setUp(self): 47 | self.tmpdirname = tempfile.mkdtemp() 48 | 49 | def tearDown(self): 50 | shutil.rmtree(self.tmpdirname) 51 | 52 | def get_tokenizer(self, **kwargs): 53 | raise NotImplementedError 54 | 55 | def get_input_output_texts(self): 56 | raise NotImplementedError 57 | 58 | def test_tokenizers_common_properties(self): 59 | tokenizer = self.get_tokenizer() 60 | attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token", 61 | "pad_token", "cls_token", "mask_token"] 62 | for attr in attributes_list: 63 | self.assertTrue(hasattr(tokenizer, attr)) 64 | self.assertTrue(hasattr(tokenizer, attr + "_id")) 65 | 66 | self.assertTrue(hasattr(tokenizer, "additional_special_tokens")) 67 | self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids')) 68 | 69 | attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", 70 | "added_tokens_decoder"] 71 | for attr in attributes_list: 72 | self.assertTrue(hasattr(tokenizer, attr)) 73 | 74 | def test_save_and_load_tokenizer(self): 75 | # safety check on max_len default value so we are sure the test works 76 | tokenizer = self.get_tokenizer() 77 | self.assertNotEqual(tokenizer.max_len, 42) 78 | 79 | # Now let's start the test 80 | tokenizer = self.get_tokenizer(max_len=42) 81 | 82 | before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 83 | 84 | with TemporaryDirectory() as tmpdirname: 85 | tokenizer.save_pretrained(tmpdirname) 86 | tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) 87 | 88 | after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 89 | self.assertListEqual(before_tokens, after_tokens) 90 | 91 | self.assertEqual(tokenizer.max_len, 42) 92 | tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43) 93 | self.assertEqual(tokenizer.max_len, 43) 94 | 95 | def test_pickle_tokenizer(self): 96 | tokenizer = self.get_tokenizer() 97 | self.assertIsNotNone(tokenizer) 98 | 99 | text = u"Munich and Berlin are nice cities" 100 | subwords = tokenizer.tokenize(text) 101 | 102 | with TemporaryDirectory() as tmpdirname: 103 | 104 | filename = os.path.join(tmpdirname, u"tokenizer.bin") 105 | pickle.dump(tokenizer, open(filename, "wb")) 106 | 107 | tokenizer_new = pickle.load(open(filename, "rb")) 108 | 109 | subwords_loaded = tokenizer_new.tokenize(text) 110 | 111 | self.assertListEqual(subwords, subwords_loaded) 112 | 113 | 114 | def test_add_tokens_tokenizer(self): 115 | tokenizer = self.get_tokenizer() 116 | 117 | vocab_size = tokenizer.vocab_size 118 | all_size = len(tokenizer) 119 | 120 | self.assertNotEqual(vocab_size, 
0) 121 | self.assertEqual(vocab_size, all_size) 122 | 123 | new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"] 124 | added_toks = tokenizer.add_tokens(new_toks) 125 | vocab_size_2 = tokenizer.vocab_size 126 | all_size_2 = len(tokenizer) 127 | 128 | self.assertNotEqual(vocab_size_2, 0) 129 | self.assertEqual(vocab_size, vocab_size_2) 130 | self.assertEqual(added_toks, len(new_toks)) 131 | self.assertEqual(all_size_2, all_size + len(new_toks)) 132 | 133 | tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l") 134 | out_string = tokenizer.decode(tokens) 135 | 136 | self.assertGreaterEqual(len(tokens), 4) 137 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 138 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 139 | 140 | new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", 141 | 'pad_token': "<<<<<|||>|>>>>|>"} 142 | added_toks_2 = tokenizer.add_special_tokens(new_toks_2) 143 | vocab_size_3 = tokenizer.vocab_size 144 | all_size_3 = len(tokenizer) 145 | 146 | self.assertNotEqual(vocab_size_3, 0) 147 | self.assertEqual(vocab_size, vocab_size_3) 148 | self.assertEqual(added_toks_2, len(new_toks_2)) 149 | self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) 150 | 151 | tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") 152 | out_string = tokenizer.decode(tokens) 153 | 154 | self.assertGreaterEqual(len(tokens), 6) 155 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 156 | self.assertGreater(tokens[0], tokens[1]) 157 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 158 | self.assertGreater(tokens[-2], tokens[-3]) 159 | self.assertEqual(tokens[0], tokenizer.eos_token_id) 160 | self.assertEqual(tokens[-2], tokenizer.pad_token_id) 161 | 162 | 163 | def test_required_methods_tokenizer(self): 164 | tokenizer = self.get_tokenizer() 165 | input_text, output_text = self.get_input_output_texts() 166 | 167 | tokens = tokenizer.tokenize(input_text) 168 | ids = tokenizer.convert_tokens_to_ids(tokens) 169 | ids_2 = tokenizer.encode(input_text) 170 | self.assertListEqual(ids, ids_2) 171 | 172 | tokens_2 = tokenizer.convert_ids_to_tokens(ids) 173 | text_2 = tokenizer.decode(ids) 174 | 175 | self.assertEqual(text_2, output_text) 176 | 177 | self.assertNotEqual(len(tokens_2), 0) 178 | self.assertIsInstance(text_2, (str, unicode)) 179 | 180 | 181 | def test_pretrained_model_lists(self): 182 | weights_list = list(self.tokenizer_class.max_model_input_sizes.keys()) 183 | weights_lists_2 = [] 184 | for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items(): 185 | weights_lists_2.append(list(map_list.keys())) 186 | 187 | for weights_list_2 in weights_lists_2: 188 | self.assertListEqual(weights_list, weights_list_2) 189 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
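The common tokenizer tests above exercise `add_tokens` and `add_special_tokens`, which append new entries after the original vocabulary, so the new ids start at `vocab_size` and `len(tokenizer)` grows while `vocab_size` stays fixed. When tokens are added for fine-tuning, the model's embedding matrix has to be resized to match. A minimal sketch of that pattern follows; the model name is a standard checkpoint and the added token strings are arbitrary examples.

```
from pytorch_transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

num_added = tokenizer.add_tokens(["[TOPIC]", "[EOU]"])    # domain-specific markers
print("added %d tokens, tokenizer length is now %d" % (num_added, len(tokenizer)))

# len(tokenizer) = vocab_size + number of added tokens, so the embedding
# table must grow before the new ids can be looked up.
model.resize_token_embeddings(len(tokenizer))

ids = tokenizer.encode("[TOPIC] where is my package [EOU]")
```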
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES 22 | 23 | from.tokenization_tests_commons import CommonTestCases 24 | 25 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = TransfoXLTokenizer 28 | 29 | def setUp(self): 30 | super(TransfoXLTokenizationTest, self).setUp() 31 | 32 | vocab_tokens = [ 33 | "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", 34 | "running", ",", "low", "l", 35 | ] 36 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 37 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 38 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 39 | 40 | def get_tokenizer(self, **kwargs): 41 | kwargs['lower_case'] = True 42 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 43 | 44 | def get_input_output_texts(self): 45 | input_text = u" UNwanted , running" 46 | output_text = u" unwanted, running" 47 | return input_text, output_text 48 | 49 | def test_full_tokenizer(self): 50 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) 51 | 52 | tokens = tokenizer.tokenize(u" UNwanted , running") 53 | self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) 54 | 55 | self.assertListEqual( 56 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 57 | 58 | def test_full_tokenizer_lower(self): 59 | tokenizer = TransfoXLTokenizer(lower_case=True) 60 | 61 | self.assertListEqual( 62 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | 65 | def test_full_tokenizer_no_lower(self): 66 | tokenizer = TransfoXLTokenizer(lower_case=False) 67 | 68 | self.assertListEqual( 69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 70 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 71 | 72 | 73 | if __name__ == '__main__': 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
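The `test_save_and_load_tokenizer` and `test_pickle_tokenizer` cases above verify that a tokenizer can be persisted and reloaded without changing its encoding. A short sketch of that round trip, with an arbitrary local directory name:

```
import os

from pytorch_transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
before = tokenizer.encode("Munich and Berlin are nice cities")

save_dir = "./gpt2_tokenizer_copy"
os.makedirs(save_dir, exist_ok=True)
tokenizer.save_pretrained(save_dir)   # writes vocab.json, merges.txt and the
                                      # special-token / added-token maps

reloaded = GPT2Tokenizer.from_pretrained(save_dir)
assert reloaded.encode("Munich and Berlin are nice cities") == before
```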
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from pytorch_transformers import PreTrainedTokenizer 23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | class TokenizerUtilsTest(unittest.TestCase): 26 | def check_tokenizer_from_pretrained(self, tokenizer_class): 27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 28 | for model_name in s3_models[:1]: 29 | tokenizer = tokenizer_class.from_pretrained(model_name) 30 | self.assertIsNotNone(tokenizer) 31 | self.assertIsInstance(tokenizer, tokenizer_class) 32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 33 | 34 | for special_tok in tokenizer.all_special_tokens: 35 | if six.PY2: 36 | self.assertIsInstance(special_tok, unicode) 37 | else: 38 | self.assertIsInstance(special_tok, str) 39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 40 | self.assertIsInstance(special_tok_id, int) 41 | 42 | def test_pretrained_tokenizers(self): 43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = XLMTokenizer 28 | 29 | def setUp(self): 30 | super(XLMTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "w", "r", "t", 35 | "lo", "low", "er", 36 | "low", "lowest", "newer", "wider", ""] 37 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 38 | merges = ["l o 123", "lo w 1456", "e r 1789", ""] 39 | 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 42 | with open(self.vocab_file, "w") as fp: 43 | fp.write(json.dumps(vocab_tokens)) 44 | with open(self.merges_file, "w") as fp: 45 | fp.write("\n".join(merges)) 46 | 47 | def get_tokenizer(self, **kwargs): 48 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) 49 | 50 | def get_input_output_texts(self): 51 | input_text = u"lower newer" 52 | output_text = u"lower newer" 53 | return input_text, output_text 54 | 55 | def test_full_tokenizer(self): 56 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ 57 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) 58 | 59 | text = "lower" 60 | bpe_tokens = ["low", "er"] 61 | tokens = tokenizer.tokenize(text) 62 | self.assertListEqual(tokens, bpe_tokens) 63 | 64 | input_tokens = tokens + [""] 65 | input_bpe_tokens = [14, 15, 20] 66 | self.assertListEqual( 67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 68 | 69 | def test_sequence_builders(self): 70 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") 71 | 72 | text = tokenizer.encode("sequence builders") 73 | text_2 = tokenizer.encode("multi-sequence build") 74 | 75 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 76 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 77 | 78 | assert encoded_sentence == [1] + text + [1] 79 | assert encoded_pair == [1] + text + [1] + text_2 + [1] 80 | 81 | if __name__ == '__main__': 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
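The tokenizer utility test above checks that every special token of a pretrained tokenizer is a string with an integer id, and the per-model tests (BERT, XLM, XLNet) each pin down how their sequence builders wrap ids with model-specific special ids. A short sketch using `bert-base-uncased`, assuming the pretrained vocabulary can be downloaded:

```
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Every special token is a string with a corresponding integer id.
for tok in tokenizer.all_special_tokens:
    print(tok, tokenizer.convert_tokens_to_ids(tok))

# The sequence builders used throughout the tests wrap plain ids with the
# model-specific special ids, [CLS]=101 and [SEP]=102 for bert-base-uncased:
ids_a = tokenizer.encode("sequence builders")
ids_b = tokenizer.encode("multi-sequence build")
single = tokenizer.add_special_tokens_single_sentence(ids_a)      # [101] + ids_a + [102]
pair = tokenizer.add_special_tokens_sentences_pair(ids_a, ids_b)  # [101] + a + [102] + b + [102]
```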
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/test_sentencepiece.model') 26 | 27 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = XLNetTokenizer 30 | 31 | def setUp(self): 32 | super(XLNetTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self, **kwargs): 39 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"This is a test" 43 | output_text = u"This is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 61 | ids = tokenizer.convert_tokens_to_ids(tokens) 62 | self.assertListEqual( 63 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 64 | 602, 347, 347, 347, 3, 12, 66, 65 | 46, 72, 80, 6, 0, 4]) 66 | 67 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 69 | u'or', u'n', SPIECE_UNDERLINE + u'in', 70 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 73 | u'', u'.']) 74 | 75 | def test_tokenizer_lower(self): 76 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 77 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 78 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 79 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 80 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 81 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 83 | 84 | def test_tokenizer_no_lower(self): 85 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 86 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 87 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 88 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 89 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 90 | 
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 91 | 92 | def test_sequence_builders(self): 93 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 94 | 95 | text = tokenizer.encode("sequence builders") 96 | text_2 = tokenizer.encode("multi-sequence build") 97 | 98 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 99 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 100 | 101 | assert encoded_sentence == text + [4, 3] 102 | assert encoded_pair == text + [4] + text_2 + [4, 3] 103 | 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Model class. """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import logging 20 | 21 | from .tokenization_bert import BertTokenizer 22 | from .tokenization_openai import OpenAIGPTTokenizer 23 | from .tokenization_gpt2 import GPT2Tokenizer 24 | from .tokenization_transfo_xl import TransfoXLTokenizer 25 | from .tokenization_xlnet import XLNetTokenizer 26 | from .tokenization_xlm import XLMTokenizer 27 | from .tokenization_roberta import RobertaTokenizer 28 | from .tokenization_distilbert import DistilBertTokenizer 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | class AutoTokenizer(object): 33 | r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class 34 | that will be instantiated as one of the tokenizer classes of the library 35 | when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` 36 | class method. 37 | 38 | The `from_pretrained()` method take care of returning the correct tokenizer class instance 39 | using pattern matching on the `pretrained_model_name_or_path` string. 40 | 41 | The tokenizer class to instantiate is selected as the first pattern matching 42 | in the `pretrained_model_name_or_path` string (in the following order): 43 | - contains `distilbert`: DistilBertTokenizer (DistilBert model) 44 | - contains `roberta`: RobertaTokenizer (RoBERTa model) 45 | - contains `bert`: BertTokenizer (Bert model) 46 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 47 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 48 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 49 | - contains `xlnet`: XLNetTokenizer (XLNet model) 50 | - contains `xlm`: XLMTokenizer (XLM model) 51 | 52 | This class cannot be instantiated using `__init__()` (throw an error). 
53 | """ 54 | def __init__(self): 55 | raise EnvironmentError("AutoTokenizer is designed to be instantiated " 56 | "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.") 57 | 58 | @classmethod 59 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 60 | r""" Instantiate a one of the tokenizer classes of the library 61 | from a pre-trained model vocabulary. 62 | 63 | The tokenizer class to instantiate is selected as the first pattern matching 64 | in the `pretrained_model_name_or_path` string (in the following order): 65 | - contains `distilbert`: DistilBertTokenizer (DistilBert model) 66 | - contains `roberta`: RobertaTokenizer (XLM model) 67 | - contains `bert`: BertTokenizer (Bert model) 68 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 69 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 70 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 71 | - contains `xlnet`: XLNetTokenizer (XLNet model) 72 | - contains `xlm`: XLMTokenizer (XLM model) 73 | 74 | Params: 75 | pretrained_model_name_or_path: either: 76 | 77 | - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. 78 | - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. 79 | - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. 80 | 81 | cache_dir: (`optional`) string: 82 | Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. 83 | 84 | force_download: (`optional`) boolean, default False: 85 | Force to (re-)download the vocabulary files and override the cached versions if they exists. 86 | 87 | proxies: (`optional`) dict, default None: 88 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 89 | The proxies are used on each request. 90 | 91 | inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. 92 | 93 | kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. 94 | 95 | Examples:: 96 | 97 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache. 98 | tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. 
tokenizer was saved using `save_pretrained('./test/saved_model/')` 99 | 100 | """ 101 | if 'distilbert' in pretrained_model_name_or_path: 102 | return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 103 | elif 'roberta' in pretrained_model_name_or_path: 104 | return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 105 | elif 'bert' in pretrained_model_name_or_path: 106 | return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 107 | elif 'openai-gpt' in pretrained_model_name_or_path: 108 | return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 109 | elif 'gpt2' in pretrained_model_name_or_path: 110 | return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 111 | elif 'transfo-xl' in pretrained_model_name_or_path: 112 | return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 113 | elif 'xlnet' in pretrained_model_name_or_path: 114 | return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 115 | elif 'xlm' in pretrained_model_name_or_path: 116 | return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 117 | 118 | raise ValueError("Unrecognized model identifier in {}. Should contain one of " 119 | "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " 120 | "'xlm', 'roberta'".format(pretrained_model_name_or_path)) 121 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | 'distilbert-base-uncased': 512, 41 | 'distilbert-base-uncased-distilled-squad': 512, 42 | } 43 | 44 | 45 | class DistilBertTokenizer(BertTokenizer): 46 | r""" 47 | Constructs a DistilBertTokenizer.
48 | :class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 49 | 50 | Args: 51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 54 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 55 | minimum of this value (if specified) and the underlying BERT model's sequence length. 56 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 57 | do_wordpiece_only=False 58 | """ 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 39 | }, 40 | 'merges_file': 41 | { 42 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 43 | }, 44 | } 45 | 46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 47 | 'openai-gpt': 512, 48 | } 49 | 50 | def get_pairs(word): 51 | """ 52 | Return set of symbol pairs in a word. 
53 | word is represented as a tuple of symbols (symbols being variable-length strings) 54 | """ 55 | pairs = set() 56 | prev_char = word[0] 57 | for char in word[1:]: 58 | pairs.add((prev_char, char)) 59 | prev_char = char 60 | return pairs 61 | 62 | def text_standardize(text): 63 | """ 64 | fixes some issues the spacy tokenizer had on books corpus 65 | also does some whitespace standardization 66 | """ 67 | text = text.replace('—', '-') 68 | text = text.replace('–', '-') 69 | text = text.replace('―', '-') 70 | text = text.replace('…', '...') 71 | text = text.replace('´', "'") 72 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 73 | text = re.sub(r'\s*\n\s*', ' \n ', text) 74 | text = re.sub(r'[^\S\n]+', ' ', text) 75 | return text.strip() 76 | 77 | class OpenAIGPTTokenizer(PreTrainedTokenizer): 78 | """ 79 | BPE tokenizer. Peculiarities: 80 | - lower cases all inputs 81 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not. 82 | """ 83 | vocab_files_names = VOCAB_FILES_NAMES 84 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 85 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 86 | 87 | def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): 88 | super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) 89 | 90 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens 91 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens 92 | 93 | try: 94 | import ftfy 95 | from spacy.lang.en import English 96 | _nlp = English() 97 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 98 | self.fix_text = ftfy.fix_text 99 | except ImportError: 100 | logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.") 101 | self.nlp = BasicTokenizer(do_lower_case=True) 102 | self.fix_text = None 103 | 104 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 105 | self.decoder = {v:k for k,v in self.encoder.items()} 106 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 107 | merges = [tuple(merge.split()) for merge in merges] 108 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 109 | self.cache = {} 110 | 111 | @property 112 | def vocab_size(self): 113 | return len(self.encoder) 114 | 115 | def bpe(self, token): 116 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 117 | if token in self.cache: 118 | return self.cache[token] 119 | pairs = get_pairs(word) 120 | 121 | if not pairs: 122 | return token+'</w>' 123 | 124 | while True: 125 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 126 | if bigram not in self.bpe_ranks: 127 | break 128 | first, second = bigram 129 | new_word = [] 130 | i = 0 131 | while i < len(word): 132 | try: 133 | j = word.index(first, i) 134 | new_word.extend(word[i:j]) 135 | i = j 136 | except: 137 | new_word.extend(word[i:]) 138 | break 139 | 140 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 141 | new_word.append(first+second) 142 | i += 2 143 | else: 144 | new_word.append(word[i]) 145 | i += 1 146 | new_word = tuple(new_word) 147 | word = new_word 148 | if len(word) == 1: 149 | break 150 | else: 151 | pairs = get_pairs(word) 152 | word = ' '.join(word) 153 | if word == '\n  </w>': 154 | word = '\n</w>' 155 | self.cache[token] = word 156 | return word 157 | 158 | def
_tokenize(self, text): 159 | """ Tokenize a string. """ 160 | split_tokens = [] 161 | if self.fix_text is None: 162 | # Using BERT's BasicTokenizer 163 | text = self.nlp.tokenize(text) 164 | for token in text: 165 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 166 | else: 167 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 168 | text = self.nlp(text_standardize(self.fix_text(text))) 169 | for token in text: 170 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 171 | return split_tokens 172 | 173 | def _convert_token_to_id(self, token): 174 | """ Converts a token (str/unicode) into an id using the vocab. """ 175 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 176 | 177 | def _convert_id_to_token(self, index): 178 | """Converts an id into a token (BPE) using the vocab.""" 179 | return self.decoder.get(index, self.unk_token) 180 | 181 | def convert_tokens_to_string(self, tokens): 182 | """ Converts a sequence of tokens (string) into a single string. """ 183 | out_string = ''.join(tokens).replace('</w>', ' ').strip() 184 | return out_string 185 | 186 | def save_vocabulary(self, save_directory): 187 | """Save the tokenizer vocabulary and merge files to a directory.""" 188 | if not os.path.isdir(save_directory): 189 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 190 | return 191 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 192 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 193 | 194 | with open(vocab_file, 'w', encoding='utf-8') as f: 195 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 196 | 197 | index = 0 198 | with open(merge_file, "w", encoding="utf-8") as writer: 199 | writer.write(u'#version: 0.2\n') 200 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 201 | if index != token_index: 202 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 203 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 204 | index = token_index 205 | writer.write(' '.join(bpe_tokens) + u'\n') 206 | index += 1 207 | 208 | return vocab_file, merge_file 209 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | """Tokenization classes for RoBERTa.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | from .tokenization_gpt2 import GPT2Tokenizer 27 | 28 | try: 29 | from functools import lru_cache 30 | except ImportError: 31 | # Just a dummy decorator to get the checks to run on python2 32 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 33 | def lru_cache(): 34 | return lambda func: func 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | VOCAB_FILES_NAMES = { 39 | 'vocab_file': 'vocab.json', 40 | 'merges_file': 'merges.txt', 41 | } 42 | 43 | PRETRAINED_VOCAB_FILES_MAP = { 44 | 'vocab_file': 45 | { 46 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 47 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 48 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 49 | }, 50 | 'merges_file': 51 | { 52 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 53 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 54 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 55 | }, 56 | } 57 | 58 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 59 | 'roberta-base': 512, 60 | 'roberta-large': 512, 61 | 'roberta-large-mnli': 512, 62 | } 63 | 64 | 65 | class RobertaTokenizer(GPT2Tokenizer): 66 | """ 67 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: 68 | - Byte-level Byte-Pair-Encoding 69 | - Requires a space to start the input string => will add a space is there isn't. 70 | As a consequence, this tokenizer `encode` and `decode` method will not conserve 71 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello" 72 | """ 73 | vocab_files_names = VOCAB_FILES_NAMES 74 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 75 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 76 | 77 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="", eos_token="", sep_token="", 78 | cls_token="", unk_token="", pad_token='', mask_token='', **kwargs): 79 | super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors, 80 | bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 81 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 82 | mask_token=mask_token, **kwargs) 83 | 84 | def add_special_tokens_single_sentence(self, token_ids): 85 | """ 86 | Adds special tokens to a sequence for sequence classification tasks. 87 | A RoBERTa sequence has the following format: X 88 | """ 89 | return [self.cls_token_id] + token_ids + [self.sep_token_id] 90 | 91 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 92 | """ 93 | Adds special tokens to a sequence pair for sequence classification tasks. 
94 | A RoBERTa sequence pair has the following format: <s> A </s></s> B </s> 95 | """ 96 | sep = [self.sep_token_id] 97 | cls = [self.cls_token_id] 98 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 99 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Tokenization classes for XLNet model.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | from shutil import copyfile 22 | 23 | import unicodedata 24 | import six 25 | 26 | from .tokenization_utils import PreTrainedTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | 'vocab_file': 34 | { 35 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", 36 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", 37 | } 38 | } 39 | 40 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 41 | 'xlnet-base-cased': None, 42 | 'xlnet-large-cased': None, 43 | } 44 | 45 | SPIECE_UNDERLINE = u'▁' 46 | 47 | # Segments (not really needed) 48 | SEG_ID_A = 0 49 | SEG_ID_B = 1 50 | SEG_ID_CLS = 2 51 | SEG_ID_SEP = 3 52 | SEG_ID_PAD = 4 53 | 54 | class XLNetTokenizer(PreTrainedTokenizer): 55 | """ 56 | SentencePiece based tokenizer.
Peculiarities: 57 | 58 | - requires `SentencePiece <https://github.com/google/sentencepiece>`_ 59 | """ 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | 64 | def __init__(self, vocab_file, 65 | do_lower_case=False, remove_space=True, keep_accents=False, 66 | bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>", 67 | pad_token="<pad>", cls_token="<cls>", mask_token="<mask>", 68 | additional_special_tokens=["<eop>", "<eod>"], **kwargs): 69 | super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, 70 | unk_token=unk_token, sep_token=sep_token, 71 | pad_token=pad_token, cls_token=cls_token, 72 | mask_token=mask_token, additional_special_tokens= 73 | additional_special_tokens, **kwargs) 74 | 75 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 76 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 77 | 78 | try: 79 | import sentencepiece as spm 80 | except ImportError: 81 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece " 82 | "pip install sentencepiece") 83 | 84 | self.do_lower_case = do_lower_case 85 | self.remove_space = remove_space 86 | self.keep_accents = keep_accents 87 | self.vocab_file = vocab_file 88 | 89 | self.sp_model = spm.SentencePieceProcessor() 90 | self.sp_model.Load(vocab_file) 91 | 92 | @property 93 | def vocab_size(self): 94 | return len(self.sp_model) 95 | 96 | def __getstate__(self): 97 | state = self.__dict__.copy() 98 | state["sp_model"] = None 99 | return state 100 | 101 | def __setstate__(self, d): 102 | self.__dict__ = d 103 | try: 104 | import sentencepiece as spm 105 | except ImportError: 106 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece " 107 | "pip install sentencepiece") 108 | self.sp_model = spm.SentencePieceProcessor() 109 | self.sp_model.Load(self.vocab_file) 110 | 111 | def preprocess_text(self, inputs): 112 | if self.remove_space: 113 | outputs = ' '.join(inputs.strip().split()) 114 | else: 115 | outputs = inputs 116 | outputs = outputs.replace("``", '"').replace("''", '"') 117 | 118 | if six.PY2 and isinstance(outputs, str): 119 | outputs = outputs.decode('utf-8') 120 | 121 | if not self.keep_accents: 122 | outputs = unicodedata.normalize('NFKD', outputs) 123 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 124 | if self.do_lower_case: 125 | outputs = outputs.lower() 126 | 127 | return outputs 128 | 129 | def _tokenize(self, text, return_unicode=True, sample=False): 130 | """ Tokenize a string.
131 | return_unicode is used only for py2 132 | """ 133 | text = self.preprocess_text(text) 134 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 135 | if six.PY2 and isinstance(text, unicode): 136 | text = text.encode('utf-8') 137 | 138 | if not sample: 139 | pieces = self.sp_model.EncodeAsPieces(text) 140 | else: 141 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) 142 | new_pieces = [] 143 | for piece in pieces: 144 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 145 | cur_pieces = self.sp_model.EncodeAsPieces( 146 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 147 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 148 | if len(cur_pieces[0]) == 1: 149 | cur_pieces = cur_pieces[1:] 150 | else: 151 | cur_pieces[0] = cur_pieces[0][1:] 152 | cur_pieces.append(piece[-1]) 153 | new_pieces.extend(cur_pieces) 154 | else: 155 | new_pieces.append(piece) 156 | 157 | # note(zhiliny): convert back to unicode for py2 158 | if six.PY2 and return_unicode: 159 | ret_pieces = [] 160 | for piece in new_pieces: 161 | if isinstance(piece, str): 162 | piece = piece.decode('utf-8') 163 | ret_pieces.append(piece) 164 | new_pieces = ret_pieces 165 | 166 | return new_pieces 167 | 168 | def _convert_token_to_id(self, token): 169 | """ Converts a token (str/unicode) into an id using the vocab. """ 170 | return self.sp_model.PieceToId(token) 171 | 172 | def _convert_id_to_token(self, index, return_unicode=True): 173 | """Converts an index (integer) into a token (string/unicode) using the vocab.""" 174 | token = self.sp_model.IdToPiece(index) 175 | if six.PY2 and return_unicode and isinstance(token, str): 176 | token = token.decode('utf-8') 177 | return token 178 | 179 | def convert_tokens_to_string(self, tokens): 180 | """Converts a sequence of tokens (strings for sub-words) into a single string.""" 181 | out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() 182 | return out_string 183 | 184 | def add_special_tokens_single_sentence(self, token_ids): 185 | """ 186 | Adds special tokens to a sequence for sequence classification tasks. 187 | An XLNet sequence has the following format: X [SEP][CLS] 188 | """ 189 | sep = [self.sep_token_id] 190 | cls = [self.cls_token_id] 191 | return token_ids + sep + cls 192 | 193 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 194 | """ 195 | Adds special tokens to a sequence pair for sequence classification tasks. 196 | An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS] 197 | """ 198 | sep = [self.sep_token_id] 199 | cls = [self.cls_token_id] 200 | return token_ids_0 + sep + token_ids_1 + sep + cls 201 | 202 | def save_vocabulary(self, save_directory): 203 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 204 | to a directory. 205 | """ 206 | if not os.path.isdir(save_directory): 207 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 208 | return 209 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 210 | 211 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 212 | copyfile(self.vocab_file, out_vocab_file) 213 | 214 | return (out_vocab_file,) 215 | --------------------------------------------------------------------------------
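As a quick orientation to the tokenization modules above, the sketch below shows how the pieces fit together: `AutoTokenizer.from_pretrained` dispatches on the model name string, and each tokenizer's `add_special_tokens_*` methods implement the sequence formats documented in its class. This is a minimal usage sketch, assuming the package `__init__.py` re-exports these classes (as in upstream pytorch-transformers 1.x) and that the pretrained vocabulary files listed in the `PRETRAINED_VOCAB_FILES_MAP` dictionaries can be downloaded; it is not part of the repository's source.

```
# Minimal usage sketch (assumptions noted above).
from pytorch_transformers import AutoTokenizer, RobertaTokenizer, XLNetTokenizer

# AutoTokenizer picks the concrete class from the name string:
# 'distilbert' and 'roberta' are matched before 'bert'.
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# RoBERTa wraps a single sequence as <s> X </s> and a pair as <s> A </s></s> B </s>.
roberta = RobertaTokenizer.from_pretrained('roberta-base')
ids_a = roberta.encode("sequence builders")
ids_b = roberta.encode("multi-sequence build")
pair_ids = roberta.add_special_tokens_sentences_pair(ids_a, ids_b)

# XLNet appends its special tokens instead: A [SEP] B [SEP][CLS],
# as exercised by test_sequence_builders at the top of this section.
xlnet = XLNetTokenizer.from_pretrained('xlnet-base-cased')
single_ids = xlnet.add_special_tokens_single_sentence(xlnet.encode("sequence builders"))
```

Note that the dispatch order in `AutoTokenizer.from_pretrained` matters: the `'distilbert'` and `'roberta'` branches come before `'bert'` precisely because both identifiers contain the substring `'bert'`.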