├── LICENSE ├── README.md ├── backup-models └── readme ├── combine.py ├── data ├── analysis.py ├── preprocess.py ├── report.txt ├── statistc.py ├── submit_example.csv ├── test.csv └── train.csv ├── ensemble_submits ├── ensemble.py └── main.py ├── final_changed.py ├── hubconf.py ├── hubconfs ├── bert_hubconf.py ├── gpt2_hubconf.py ├── gpt_hubconf.py ├── transformer_xl_hubconf.py ├── xlm_hubconf.py └── xlnet_hubconf.1.py ├── pretrained_model └── readme.txt ├── pytorch_transformers ├── __init__.py ├── __main__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── file_utils.cpython-36.pyc │ ├── file_utils.cpython-37.pyc │ ├── modeling_auto.cpython-36.pyc │ ├── modeling_auto.cpython-37.pyc │ ├── modeling_bert.cpython-36.pyc │ ├── modeling_bert.cpython-37.pyc │ ├── modeling_gpt2.cpython-36.pyc │ ├── modeling_gpt2.cpython-37.pyc │ ├── modeling_openai.cpython-36.pyc │ ├── modeling_openai.cpython-37.pyc │ ├── modeling_roberta.cpython-36.pyc │ ├── modeling_roberta.cpython-37.pyc │ ├── modeling_transfo_xl.cpython-36.pyc │ ├── modeling_transfo_xl.cpython-37.pyc │ ├── modeling_transfo_xl_utilities.cpython-36.pyc │ ├── modeling_transfo_xl_utilities.cpython-37.pyc │ ├── modeling_utils.cpython-36.pyc │ ├── modeling_utils.cpython-37.pyc │ ├── modeling_xlm.cpython-36.pyc │ ├── modeling_xlm.cpython-37.pyc │ ├── modeling_xlnet.cpython-36.pyc │ ├── modeling_xlnet.cpython-37.pyc │ ├── optimization.cpython-36.pyc │ ├── optimization.cpython-37.pyc │ ├── tokenization_auto.cpython-36.pyc │ ├── tokenization_auto.cpython-37.pyc │ ├── tokenization_bert.cpython-36.pyc │ ├── tokenization_bert.cpython-37.pyc │ ├── tokenization_gpt2.cpython-36.pyc │ ├── tokenization_gpt2.cpython-37.pyc │ ├── tokenization_openai.cpython-36.pyc │ ├── tokenization_openai.cpython-37.pyc │ ├── tokenization_roberta.cpython-36.pyc │ ├── tokenization_roberta.cpython-37.pyc │ ├── tokenization_transfo_xl.cpython-36.pyc │ ├── tokenization_transfo_xl.cpython-37.pyc │ ├── 
tokenization_utils.cpython-36.pyc │ ├── tokenization_utils.cpython-37.pyc │ ├── tokenization_xlm.cpython-36.pyc │ ├── tokenization_xlm.cpython-37.pyc │ ├── tokenization_xlnet.cpython-36.pyc │ └── tokenization_xlnet.cpython-37.pyc ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf.py ├── convert_roberta_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── convert_xlm_checkpoint_to_pytorch.py ├── convert_xlnet_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling_auto.py ├── modeling_bert.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_roberta.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_utils.py ├── modeling_xlm.py ├── modeling_xlnet.py ├── optimization.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── fixtures │ │ ├── input.txt │ │ ├── sample_text.txt │ │ └── test_sentencepiece.model │ ├── modeling_auto_test.py │ ├── modeling_bert_test.py │ ├── modeling_common_test.py │ ├── modeling_gpt2_test.py │ ├── modeling_openai_test.py │ ├── modeling_roberta_test.py │ ├── modeling_transfo_xl_test.py │ ├── modeling_xlm_test.py │ ├── modeling_xlnet_test.py │ ├── optimization_test.py │ ├── tokenization_auto_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_gpt2_test.py │ ├── tokenization_openai_test.py │ ├── tokenization_roberta_test.py │ ├── tokenization_tests_commons.py │ ├── tokenization_transfo_xl_test.py │ ├── tokenization_utils_test.py │ ├── tokenization_xlm_test.py │ └── tokenization_xlnet_test.py ├── tokenization_auto.py ├── tokenization_bert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_roberta.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py └── tokenization_xlnet.py ├── requirements.txt ├── run_bert.py ├── run_bert.sh ├── run_bert_wwm_ext.sh ├── run_roberta.sh ├── run_roberta_wwm_ext.sh ├── run_xlnet.py ├── 
run_xlnet.sh └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | 作者ustc-linhw 2 | 3 | 本项目用于文本分类任务 4 | 5 | 目前支持的功能如下: 6 | 7 | —— 训练数据集kfold处理 8 | 9 | —— 训练数据集数据信息查看 10 | 11 | —— 使用预训练模型进行文本分类 12 | 13 | —— roberta_wwm_ext_large 14 | 15 | —— roberta_large 16 | 17 | —— xlnet_large (to do) 18 | 19 | —— 对不同模型的结果进行投票ensemble 20 | 21 | —— 训练完成后自动保存模型、配置以及输出结果 22 | 23 | 主要文件目录如下: 24 | 25 | —— backup-models:自动存档目录,输出的模型和结果会自动存档到该目录 26 | 27 | —— data:数据目录,用于存放训练数据,并在该目录下进行数据分析和kfold划分 28 | 29 | —— pretrained_model: 用于存放预训练的模型 30 | 31 | —— run_xxxxx.sh: 训练某个模型所使用的bash文件 32 | 33 | —— run_xxxx.py: 具体的训练代码 34 | 35 | —— ensemble_submits:对输出的result文件进行vote融合结果 36 | 37 | 38 | 具体使用流程 39 | 40 | 41 | 1. 目前为2分类任务;如需用于其他分类任务(如类别数不同),需要修改下述文件: 42 | 43 | —— preprocess.py 44 | 45 | —— run_bert.py 46 | 47 | —— 标签label 48 | 49 | —— 类别数 50 | 51 | —— 类别loss 52 | 53 | —— combine.py 54 | 55 | 56 | 2. cd data && python analysis.py 查看数据集的相关情况 57 | 58 | 3. python preprocess.py 完成数据预处理,并且将数据分成kfold 59 | 60 | 4. 修改run_xxxx.sh文件设置参数 61 | 62 | 注:该模型将文本截成k段(k即split_num参数),分别输入语言模型,然后在顶层用GRU拼接起来。好处在于可以通过设置较小的max_seq_length和较大的k来降低显存占用,因为显存占用随长度呈平方级增长,而随k只呈线性增长。 63 | 64 | 1)实际长度 = max_seq_length * split_num 65 | 66 | 2)实际batch size = per_gpu_train_batch_size * GPU数量 67 | 68 | 3)下方log示例使用的是4卡GPU,因此总batch size为4(即per_gpu_train_batch_size=1)。如果只有1卡,要保持batch size为4,应将per_gpu_train_batch_size设为4,并把max_seq_length设置小一些。 69 | 70 | 4)如果显存太小,可以设置gradient_accumulation_steps参数。比如gradient_accumulation_steps=2、batch size=4时,每次更新会分2次运行、每次batch size为2,累计梯度后再更新,等价于batch size=4,但每次更新的耗时约为原来的两倍。同时迭代次数也要相应增加到原来的两倍,例如将train_steps从5000设为10000。 71 | 72 | 具体batch size可看运行时的log,如: 73 | 74 | 09/06/2019 21:03:41 - INFO - __main__ - ***** Running training ***** 75 | 76 | 09/06/2019 21:03:41 - INFO - __main__ - Num examples = 5872 77 | 78 | 09/06/2019 21:03:41 - INFO - __main__ - Batch size = 4 79 | 80 | 09/06/2019 21:03:41 - INFO - __main__ - Num steps = 5000 81 | 82 | 5. 
最后输出文件会生成result.csv,模型会在对应的模型文件夹中生成,backup文件夹问自动保存对应的模型。 83 | 84 | 85 | -------------------------------------------------------------------------------- /backup-models/readme: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/backup-models/readme -------------------------------------------------------------------------------- /combine.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--model_prefix", default=None, type=str, required=True) 6 | parser.add_argument("--out_path", default=None, type=str, required=True) 7 | parser.add_argument("--fold",default=None,type=int,required=True) 8 | args = parser.parse_args() 9 | 10 | k=args.fold 11 | df=pd.read_csv('./data/submit_example.csv') 12 | df['0']=0 13 | df['1']=0 14 | for i in range(k): 15 | temp=pd.read_csv('{}{}/sub.csv'.format(args.model_prefix,i)) 16 | df['0']+=temp['label_0']/k 17 | df['1']+=temp['label_1']/k 18 | print(df['0'].mean()) 19 | 20 | df['flag']=np.argmax(df[['0','1']].values,- 1) 21 | df[['id','flag']].to_csv(args.out_path,index=False) 22 | -------------------------------------------------------------------------------- /data/analysis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import random 4 | import io 5 | 6 | label_name='flag' 7 | out='report.txt' 8 | 9 | os.system("rm report.txt") 10 | 11 | def analyzefile(file): 12 | df=pd.read_csv(file) 13 | os.system("touch "+out) 14 | os.system("echo '-------------information about "+file+" set------------' >> "+out) 15 | os.system("echo 'the row number of "+file+" is "+str(df.shape[0])+"' >> "+out) 16 | if(file=='train.csv'): 17 | os.system("echo 'the label number of "+file+" 
is\n"+str(df[label_name].value_counts()[:10])+"' >> "+out) 18 | os.system("echo '\n-------------the describe of "+file+" is ------------------ \n"+str(df.describe())+"' >> "+out) 19 | os.system("echo '\n--------------the info of "+file+" data --------------- \n' >> "+ out) 20 | buffer=io.StringIO() 21 | df.info(buf=buffer) 22 | info=buffer.getvalue() 23 | f=open('report.txt','a') 24 | f.write(info) 25 | f.write("\n\n\n") 26 | f.close() 27 | 28 | analyzefile('train.csv') 29 | analyzefile('test.csv') 30 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import random 4 | random.seed(1) 5 | train_df = pd.read_csv('train.csv') 6 | test_df = pd.read_csv('test.csv') 7 | 8 | test_df['content']=test_df['content'].fillna('无。') 9 | train_df['content']=train_df['content'].fillna('无。') 10 | test_df['title']=test_df['title'].fillna('无。') 11 | train_df['title']=train_df['title'].fillna('无。') 12 | test_df['flag']=0 13 | 14 | test_df.to_csv("data/test.csv",index=False) 15 | train_df.to_csv("data/train.csv",index=False) 16 | 17 | index=set(range(train_df.shape[0])) 18 | K_fold=[] 19 | for i in range(5): 20 | if i == 4: 21 | tmp=index 22 | else: 23 | tmp=random.sample(index,int(1.0/5*train_df.shape[0])) 24 | index=index-set(tmp) 25 | print("Number:",len(tmp)) 26 | K_fold.append(tmp) 27 | 28 | 29 | for i in range(5): 30 | print("Fold",i) 31 | os.system("mkdir data_{}".format(i)) 32 | dev_index=list(K_fold[i]) 33 | train_index=[] 34 | for j in range(5): 35 | if j!=i: 36 | train_index+=K_fold[j] 37 | train_df.iloc[train_index].to_csv("data_{}/train.csv".format(i),index=False) 38 | train_df.iloc[dev_index].to_csv("data_{}/dev.csv".format(i),index=False) 39 | test_df.to_csv("data_{}/test.csv".format(i),index=False) 40 | 41 | os.system("mv data_0/dev.csv data/dev.csv") 42 | 
-------------------------------------------------------------------------------- /data/report.txt: -------------------------------------------------------------------------------- 1 | -------------information about train.csv set------------ 2 | the row number of train.csv is 3993 3 | the label number of train.csv is 4 | 0 2985 5 | 1 1008 6 | Name: flag, dtype: int64 7 | 8 | -------------the describe of train.csv is ------------------ 9 | flag 10 | count 3993.000000 11 | mean 0.252442 12 | std 0.434468 13 | min 0.000000 14 | 25% 0.000000 15 | 50% 0.000000 16 | 75% 1.000000 17 | max 1.000000 18 | 19 | --------------the info of train.csv data --------------- 20 | 21 | 22 | RangeIndex: 3993 entries, 0 to 3992 23 | Data columns (total 4 columns): 24 | id 3993 non-null object 25 | flag 3993 non-null int64 26 | title 3993 non-null object 27 | content 3984 non-null object 28 | dtypes: int64(1), object(3) 29 | memory usage: 124.9+ KB 30 | 31 | 32 | 33 | -------------information about test.csv set------------ 34 | the row number of test.csv is 3993 35 | 36 | -------------the describe of test.csv is ------------------ 37 | id title content 38 | count 3993 3993 3988 39 | unique 3993 3888 3903 40 | top ce3c126915494f4ca93a6f1cbd65283f 全新别克GL8符时下潮流 家用商务两不误    41 | freq 1 7 16 42 | 43 | --------------the info of test.csv data --------------- 44 | 45 | 46 | RangeIndex: 3993 entries, 0 to 3992 47 | Data columns (total 3 columns): 48 | id 3993 non-null object 49 | title 3993 non-null object 50 | content 3988 non-null object 51 | dtypes: object(3) 52 | memory usage: 93.7+ KB 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /data/statistc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | #df=pd.read_csv('Train_DataSet_Label.csv') 3 | # df.columns=['id','title','content'] 4 | # df['label']=0 5 | # df[['id','label']].to_csv('submit_example.csv',index=False) 6 | def 
cal(text): 7 | df=pd.read_csv(text) 8 | df0=len(df[df['label']==0]) 9 | print(df0) 10 | df1=len(df[df['label']==1]) 11 | print(df1) 12 | df2=len(df[df['label']==2]) 13 | print(df2) 14 | sum=df1+df2+df0 15 | print(df0/sum,df1/sum,df2/sum) 16 | 17 | cal('data_0/train.csv') 18 | cal('data_0/dev.csv') -------------------------------------------------------------------------------- /ensemble_submits/ensemble.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | #vote 文件 5 | submits_path='./submits' 6 | #需要进行vote的文件 7 | submits = ['0.82414645.csv','0.8172323.csv','0.81546885000.csv'] 8 | #vote时文件的权重 9 | file_weight = [3,2,2] 10 | #vote时标签的权重 11 | label_weight =[1,1,1] 12 | 13 | files = [] 14 | data = [] 15 | for f in submits: 16 | if 'csv' in f: 17 | files.append(f) 18 | data.append(pd.read_csv(submits_path+f).values) 19 | print(len(files)) 20 | output = np.zeros([len(data[0]), 3]) 21 | 22 | for i in range(len(data)): 23 | for j in range(len(data[0])): 24 | if data[i][j][1] == 0: 25 | output[j][0] += file_weight[i]*label_weight 26 | elif data[i][j][1] == 1: 27 | output[j][1] += file_weight[i]*label_weight 28 | elif data[i][j][1] == 2: 29 | output[j][2] += file_weight[i]*label_weight 30 | 31 | #读取提交模板,需要设置 32 | submit = pd.read_csv('sub_teample.csv') 33 | submit['label'] = np.argmax(output, axis = 1) 34 | submit.to_csv('submit.csv',index=None) 35 | 36 | 37 | -------------------------------------------------------------------------------- /ensemble_submits/main.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | #对输出文件进行分析,分析他们的相关程度 7 | 8 | path="./submits" #文件夹目录 9 | raw = pd.read_csv('0.82414645.csv') 10 | raw = raw.drop('label',axis=1) 11 | files= os.listdir(path) #得到文件夹下的所有文件名称 12 | s = [] 13 | for file in files: #遍历文件夹 14 | print(file) 15 | 
if(file.find(".csv")>0): 16 | tmp = pd.read_csv(file) 17 | tmp = tmp[['id', 'label']] 18 | tmp.columns = ['id', file] 19 | raw = pd.merge(raw, tmp, on='id') 20 | 21 | new = raw.drop(['id'],axis=1) 22 | 23 | def test(df): 24 | dfData = df.corr() 25 | print(dfData) 26 | plt.subplots(figsize=(19, 19)) # 设置画面大小 27 | sns.heatmap(dfData, annot=True, vmax=1, square=True, cmap="Blues") 28 | plt.savefig('./SubmitRelation.png') 29 | plt.show() 30 | 31 | test(new) 32 | -------------------------------------------------------------------------------- /final_changed.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import Counter 4 | 5 | train_dataset = pd.read_csv('data/Train_DataSet.csv').values 6 | train_label = pd.read_csv('data/Train_DataSet_Label.csv').values 7 | train_label_dict = {} 8 | # train label 和 train dataset似乎没有对齐,通过dict暴力对齐 9 | for i in range(len(train_label)): 10 | train_label_dict[train_label[i][0]] = train_label[i][1] 11 | test_dataset = pd.read_csv('data/Test_DataSet.csv').values 12 | submission = pd.read_csv('result.csv').values 13 | 14 | changed_num = 0 15 | for i in range(len(test_dataset)): 16 | same_labels = [] 17 | for j in range(len(train_dataset)): 18 | # title或content相同 19 | if train_dataset[j][1] == test_dataset[i][1] : 20 | if len(same_labels) == 0: 21 | print('************************************************************************') 22 | print(str(i) + ': ' + test_dataset[i][1] + ' Train Label: ' + str(train_label_dict[train_dataset[j][0]])) 23 | same_labels.append(train_label_dict[train_dataset[j][0]]) 24 | 25 | if same_labels: 26 | changed_num += 1 27 | same_label_dict = Counter(same_labels) 28 | num_of_label0 = same_label_dict[0] 29 | num_of_label1 = same_label_dict[1] 30 | num_of_label2 = same_label_dict[2] 31 | if num_of_label0 > num_of_label1 and num_of_label0 > num_of_label2: 32 | submission[i][1] = 0 33 | elif num_of_label1 > 
num_of_label0 and num_of_label1 > num_of_label2: 34 | submission[i][1] = 1 35 | elif num_of_label2 > num_of_label0 and num_of_label2 > num_of_label1: 36 | submission[i][1] = 2 37 | # 有相等的,暂时先不改变 38 | else: 39 | pass 40 | 41 | submit = pd.read_csv('data/submit_example.csv') 42 | submit['label'] = submission[:, -1] 43 | submit.to_csv('./result_after.csv', index = None) 44 | 45 | origin_file = pd.read_csv(str('result.csv')).values 46 | changed_file = pd.read_csv(str('result_after.csv')).values 47 | changed = 0 48 | for i in range(len(origin_file)): 49 | if origin_file[i][1] != changed_file[i][1]: 50 | changed += 1 51 | print(str(changed) + ' labels have been changed.') 52 | print(changed_num) -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] 2 | 3 | from hubconfs.bert_hubconf import ( 4 | bertTokenizer, 5 | bertModel, 6 | bertForNextSentencePrediction, 7 | bertForPreTraining, 8 | bertForMaskedLM, 9 | bertForSequenceClassification, 10 | bertForMultipleChoice, 11 | bertForQuestionAnswering, 12 | bertForTokenClassification 13 | ) 14 | from hubconfs.gpt_hubconf import ( 15 | openAIGPTTokenizer, 16 | openAIGPTModel, 17 | openAIGPTLMHeadModel, 18 | openAIGPTDoubleHeadsModel 19 | ) 20 | from hubconfs.gpt2_hubconf import ( 21 | gpt2Tokenizer, 22 | gpt2Model, 23 | gpt2LMHeadModel, 24 | gpt2DoubleHeadsModel 25 | ) 26 | from hubconfs.transformer_xl_hubconf import ( 27 | transformerXLTokenizer, 28 | transformerXLModel, 29 | transformerXLLMHeadModel 30 | ) 31 | -------------------------------------------------------------------------------- /hubconfs/gpt2_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 2 | from pytorch_transformers.modeling_gpt2 import ( 3 | GPT2Model, 4 | 
GPT2LMHeadModel, 5 | GPT2DoubleHeadsModel 6 | ) 7 | 8 | # A lot of models share the same param doc. Use a decorator 9 | # to save typing 10 | gpt2_docstring = """ 11 | Params: 12 | pretrained_model_name_or_path: either: 13 | - a str with the name of a pre-trained model to load selected in the list of: 14 | . `gpt2`, `gpt2-medium` 15 | - a path or url to a pretrained model archive containing: 16 | . `gpt2_config.json` a configuration file for the model 17 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance 18 | - a path or url to a pretrained model archive containing: 19 | . `gpt2_config.json` a configuration file for the model 20 | . a TensorFlow checkpoint with trained weights 21 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 22 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 23 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 24 | *inputs, **kwargs: additional input for the specific GPT-2 class 25 | """ 26 | 27 | 28 | def _append_from_pretrained_docstring(docstr): 29 | def docstring_decorator(fn): 30 | fn.__doc__ = fn.__doc__ + docstr 31 | return fn 32 | return docstring_decorator 33 | 34 | 35 | def gpt2Tokenizer(*args, **kwargs): 36 | """ 37 | Instantiate a GPT-2 BPE tokenizer for OpenAI GPT-2 from a pre-trained/customized vocab file. 38 | Peculiarities: 39 | - Byte-level BPE 40 | 41 | Args: 42 | pretrained_model_name_or_path: Path to pretrained model archive 43 | or one of pre-trained vocab configs below. 44 | * gpt2 45 | Keyword args: 46 | special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...) 47 | Default: None 48 | max_len: An artificial maximum length to truncate tokenized sequences to; 49 | Effective maximum length is always the minimum of this 50 | value (if specified) and the underlying BERT model's 51 | sequence length. 
52 | Default: None 53 | 54 | Example: 55 | import torch 56 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 57 | 58 | text = "Who was Jim Henson ?" 59 | indexed_tokens = tokenizer.encode(tokenized_text) 60 | """ 61 | tokenizer = GPT2Tokenizer.from_pretrained(*args, **kwargs) 62 | return tokenizer 63 | 64 | 65 | @_append_from_pretrained_docstring(gpt2_docstring) 66 | def gpt2Model(*args, **kwargs): 67 | """ 68 | gpt2Model is the basic OpenAI GPT-2 Transformer model based on 69 | identical stacked masked self-attention blocks and pre-trained 70 | on large scale dataset using language modeling signal. 71 | 72 | Example: 73 | # Load the tokenizer 74 | import torch 75 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 76 | 77 | # Prepare tokenized input 78 | text_1 = "Who was Jim Henson ?" 79 | text_2 = "Jim Henson was a puppeteer" 80 | indexed_tokens_1 = tokenizer.encode(text_1) 81 | indexed_tokens_2 = tokenizer.encode(text_2) 82 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 83 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 84 | 85 | # Load gpt2Model 86 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Model', 'gpt2') 87 | model.eval() 88 | 89 | # Predict hidden states features for each layer 90 | # past can be used to reuse precomputed hidden state in a subsequent predictions 91 | with torch.no_grad(): 92 | hidden_states_1, past = model(tokens_tensor_1) 93 | hidden_states_2, past = model(tokens_tensor_2, past=past) 94 | """ 95 | model = GPT2Model.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(gpt2_docstring) 100 | def gpt2LMHeadModel(*args, **kwargs): 101 | """ 102 | gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the 103 | tied (pre-trained) language modeling head on top. 
104 | 105 | Example: 106 | # Load the tokenizer 107 | import torch 108 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 109 | 110 | # Prepare tokenized input 111 | text_1 = "Who was Jim Henson ?" 112 | text_2 = "Jim Henson was a puppeteer" 113 | indexed_tokens_1 = tokenizer.encode(text_1) 114 | indexed_tokens_2 = tokenizer.encode(text_2) 115 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load gpt2LMHeadModel 119 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2LMHeadModel', 'gpt2') 120 | model.eval() 121 | 122 | # Predict hidden states features for each layer 123 | # past can be used to reuse precomputed hidden state in a subsequent predictions 124 | with torch.no_grad(): 125 | predictions_1, past = model(tokens_tensor_1) 126 | predictions_2, past = model(tokens_tensor_2, past=past) 127 | 128 | # Get the predicted last token 129 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 130 | predicted_token = tokenizer.decode([predicted_index]) 131 | assert predicted_token == ' who' 132 | """ 133 | model = GPT2LMHeadModel.from_pretrained(*args, **kwargs) 134 | return model 135 | 136 | 137 | @_append_from_pretrained_docstring(gpt2_docstring) 138 | def gpt2DoubleHeadsModel(*args, **kwargs): 139 | """ 140 | gpt2DoubleHeadsModel is the OpenAI GPT-2 Transformer model with the 141 | tied (pre-trained) language modeling head and a multiple choice 142 | classification head (only initialized, not pre-trained). 143 | 144 | Example: 145 | # Load the tokenizer 146 | import torch 147 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 148 | 149 | # Prepare tokenized input 150 | text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 151 | text2 = "Who was Jim Henson ? 
Jim Henson was a mysterious young man" 152 | tokenized_text1 = tokenizer.tokenize(text1) 153 | tokenized_text2 = tokenizer.tokenize(text2) 154 | indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 155 | indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 156 | tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 157 | mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 158 | 159 | # Load gpt2DoubleHeadsModel 160 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2DoubleHeadsModel', 'gpt2') 161 | model.eval() 162 | 163 | # Predict hidden states features for each layer 164 | with torch.no_grad(): 165 | lm_logits, multiple_choice_logits, presents = model(tokens_tensor, mc_token_ids) 166 | """ 167 | model = GPT2DoubleHeadsModel.from_pretrained(*args, **kwargs) 168 | return model 169 | -------------------------------------------------------------------------------- /hubconfs/gpt_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer 2 | from pytorch_transformers.modeling_openai import ( 3 | OpenAIGPTModel, 4 | OpenAIGPTLMHeadModel, 5 | OpenAIGPTDoubleHeadsModel 6 | ) 7 | 8 | # Dependecies that are not specified in global hubconf.py 9 | specific_dependencies = ['spacy', 'ftfy'] 10 | 11 | # A lot of models share the same param doc. Use a decorator 12 | # to save typing 13 | gpt_docstring = """ 14 | OpenAI GPT use a single embedding matrix to store the word and special embeddings. 15 | Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]... 16 | Special tokens need to be trained during the fine-tuning if you use them. 17 | The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function. 
18 | 19 | The embeddings are ordered as follow in the token embeddings matrice: 20 | [0, ---------------------- 21 | ... -> word embeddings 22 | config.vocab_size - 1, ______________________ 23 | config.vocab_size, 24 | ... -> special embeddings 25 | config.vocab_size + config.n_special - 1] ______________________ 26 | 27 | where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is: 28 | total_tokens_embeddings = config.vocab_size + config.n_special 29 | You should use the associate indices to index the embeddings. 30 | 31 | Params: 32 | pretrained_model_name_or_path: either: 33 | - a str with the name of a pre-trained model to load selected in the list of: 34 | . `openai-gpt` 35 | - a path or url to a pretrained model archive containing: 36 | . `openai_gpt_config.json` a configuration file for the model 37 | . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance 38 | - a path or url to a pretrained model archive containing: 39 | . `openai-gpt-config.json` a configuration file for the model 40 | . a series of NumPy files containing OpenAI TensorFlow trained weights 41 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 42 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 43 | state_dict: an optional state dictionary (collections.OrderedDict object) 44 | to use instead of pre-trained models 45 | *inputs, **kwargs: additional input for the specific OpenAI-GPT class 46 | """ 47 | 48 | 49 | def _append_from_pretrained_docstring(docstr): 50 | def docstring_decorator(fn): 51 | fn.__doc__ = fn.__doc__ + docstr 52 | return fn 53 | return docstring_decorator 54 | 55 | 56 | def openAIGPTTokenizer(*args, **kwargs): 57 | """ 58 | Instantiate a BPE tokenizer for OpenAI GPT from a pre-trained/customized vocab file. 
59 | Peculiarities: 60 | - lower case all inputs 61 | - uses SpaCy tokenizer ('en' model) and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 62 | - argument special_tokens and function set_special_tokens: 63 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 64 | 65 | Args: 66 | pretrained_model_name_or_path: Path to pretrained model archive 67 | or one of pre-trained vocab configs below. 68 | * openai-gpt 69 | Keyword args: 70 | special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...) 71 | Default: None 72 | max_len: An artificial maximum length to truncate tokenized sequences to; 73 | Effective maximum length is always the minimum of this 74 | value (if specified) and the underlying BERT model's 75 | sequence length. 76 | Default: None 77 | 78 | Example: 79 | import torch 80 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 81 | 82 | text = "Who was Jim Henson ? Jim Henson was a puppeteer" 83 | tokenized_text = tokenizer.tokenize(text) 84 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 85 | [763, 509, 4265, 2298, 945, 257, 4265, 2298, 945, 509, 246, 10148, 39041, 483] 86 | """ 87 | tokenizer = OpenAIGPTTokenizer.from_pretrained(*args, **kwargs) 88 | return tokenizer 89 | 90 | 91 | @_append_from_pretrained_docstring(gpt_docstring) 92 | def openAIGPTModel(*args, **kwargs): 93 | """ 94 | OpenAIGPTModel is the basic OpenAI GPT Transformer model based on 95 | identical stacked masked self-attention blocks and pre-trained 96 | on large scale dataset using language modeling signal. 97 | 98 | Example: 99 | # Load the tokenizer 100 | import torch 101 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 102 | 103 | # Prepare tokenized input 104 | text = "Who was Jim Henson ? 
Jim Henson was a puppeteer" 105 | tokenized_text = tokenizer.tokenize(text) 106 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 107 | tokens_tensor = torch.tensor([indexed_tokens]) 108 | 109 | # Load openAIGPTModel 110 | model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTModel', 'openai-gpt') 111 | model.eval() 112 | 113 | # Predict hidden states features for each layer 114 | with torch.no_grad(): 115 | hidden_states = model(tokens_tensor) 116 | """ 117 | model = OpenAIGPTModel.from_pretrained(*args, **kwargs) 118 | return model 119 | 120 | 121 | @_append_from_pretrained_docstring(gpt_docstring) 122 | def openAIGPTLMHeadModel(*args, **kwargs): 123 | """ 124 | OpenAIGPTLMHeadModel is the OpenAI GPT Transformer model with the 125 | tied (pre-trained) language modeling head on top. 126 | 127 | Example: 128 | # Load the tokenizer 129 | import torch 130 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 131 | 132 | # Prepare tokenized input 133 | text = "Who was Jim Henson ? Jim Henson was a puppeteer" 134 | tokenized_text = tokenizer.tokenize(text) 135 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 136 | tokens_tensor = torch.tensor([indexed_tokens]) 137 | 138 | # Load openAIGPTLMHeadModel 139 | model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTLMHeadModel', 'openai-gpt') 140 | model.eval() 141 | 142 | # Predict hidden states features for each layer 143 | with torch.no_grad(): 144 | predictions = model(tokens_tensor) 145 | 146 | # Get the predicted last token 147 | predicted_index = torch.argmax(predictions[0, -1, :]).item() 148 | predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 149 | '.' 
150 | """ 151 | model = OpenAIGPTLMHeadModel.from_pretrained(*args, **kwargs) 152 | return model 153 | 154 | 155 | @_append_from_pretrained_docstring(gpt_docstring) 156 | def openAIGPTDoubleHeadsModel(*args, **kwargs): 157 | """ 158 | OpenAIGPTDoubleHeadsModel is the OpenAI GPT Transformer model with the 159 | tied (pre-trained) language modeling head and a multiple choice 160 | classification head (only initialized, not pre-trained). 161 | 162 | Example: 163 | # Load the tokenizer 164 | import torch 165 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 166 | 167 | # Prepare tokenized input 168 | text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 169 | text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 170 | tokenized_text1 = tokenizer.tokenize(text1) 171 | tokenized_text2 = tokenizer.tokenize(text2) 172 | indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 173 | indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 174 | tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 175 | mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 176 | 177 | # Load openAIGPTDoubleHeadsModel 178 | model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTDoubleHeadsModel', 'openai-gpt') 179 | model.eval() 180 | 181 | # Predict hidden states features for each layer 182 | with torch.no_grad(): 183 | lm_logits, multiple_choice_logits = model(tokens_tensor, mc_token_ids) 184 | """ 185 | model = OpenAIGPTDoubleHeadsModel.from_pretrained(*args, **kwargs) 186 | return model 187 | -------------------------------------------------------------------------------- /hubconfs/transformer_xl_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer 2 | from pytorch_transformers.modeling_transfo_xl import ( 3 | TransfoXLModel, 
4 | TransfoXLLMHeadModel 5 | ) 6 | 7 | # A lot of models share the same param doc. Use a decorator 8 | # to save typing 9 | transformer_xl_docstring = """ 10 | Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that: 11 | - you don't need to specify positioning embeddings indices 12 | - the tokens in the vocabulary have to be sorted to decreasing frequency. 13 | 14 | Params: 15 | pretrained_model_name_or_path: either: 16 | - a str with the name of a pre-trained model to load selected in the list of: 17 | . `transfo-xl-wt103` 18 | - a path or url to a pretrained model archive containing: 19 | . `transfo_xl_config.json` a configuration file for the model 20 | . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance 21 | - a path or url to a pretrained model archive containing: 22 | . `transfo_xl_config.json` a configuration file for the model 23 | . `model.chkpt` a TensorFlow checkpoint 24 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 25 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 26 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 27 | *inputs, **kwargs: additional input for the specific TransformerXL class 28 | """ 29 | 30 | 31 | def _append_from_pretrained_docstring(docstr): 32 | def docstring_decorator(fn): 33 | fn.__doc__ = fn.__doc__ + docstr 34 | return fn 35 | return docstring_decorator 36 | 37 | 38 | def transformerXLTokenizer(*args, **kwargs): 39 | """ 40 | Instantiate a Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 41 | 42 | Args: 43 | pretrained_model_name_or_path: Path to pretrained model archive 44 | or one of pre-trained vocab configs below. 
45 | * transfo-xl-wt103 46 | 47 | Example: 48 | import torch 49 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 50 | 51 | text = "Who was Jim Henson ?" 52 | tokenized_text = tokenizer.tokenize(tokenized_text) 53 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 54 | """ 55 | tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs) 56 | return tokenizer 57 | 58 | 59 | @_append_from_pretrained_docstring(transformer_xl_docstring) 60 | def transformerXLModel(*args, **kwargs): 61 | """ 62 | transformerXLModel is the basic Transformer XL model. 63 | 64 | Example: 65 | # Load the tokenizer 66 | import torch 67 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 68 | 69 | # Prepare tokenized input 70 | text_1 = "Who was Jim Henson ?" 71 | text_2 = "Jim Henson was a puppeteer" 72 | tokenized_text_1 = tokenizer.tokenize(text_1) 73 | tokenized_text_2 = tokenizer.tokenize(text_2) 74 | indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1) 75 | indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2) 76 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 77 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 78 | 79 | # Load transformerXLModel 80 | model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103') 81 | model.eval() 82 | 83 | # Predict hidden states features for each layer 84 | # We can re-use the memory cells in a subsequent call to attend a longer context 85 | with torch.no_grad(): 86 | hidden_states_1, mems_1 = model(tokens_tensor_1) 87 | hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1) 88 | """ 89 | model = TransfoXLModel.from_pretrained(*args, **kwargs) 90 | return model 91 | 92 | 93 | @_append_from_pretrained_docstring(transformer_xl_docstring) 94 | def transformerXLLMHeadModel(*args, **kwargs): 95 | """ 96 | transformerXLModel is the 
basic Transformer XL model with the 97 | tied (pre-trained) language modeling head on top. 98 | 99 | Example: 100 | # Load the tokenizer 101 | import torch 102 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 103 | 104 | # Prepare tokenized input 105 | text_1 = "Who was Jim Henson ?" 106 | text_2 = "Jim Henson was a puppeteer" 107 | tokenized_text_1 = tokenizer.tokenize(text_1) 108 | tokenized_text_2 = tokenizer.tokenize(text_2) 109 | indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1) 110 | indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2) 111 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 112 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 113 | 114 | # Load transformerXLLMHeadModel 115 | model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103') 116 | model.eval() 117 | 118 | # Predict hidden states features for each layer 119 | # We can re-use the memory cells in a subsequent call to attend a longer context 120 | with torch.no_grad(): 121 | predictions_1, mems_1 = model(tokens_tensor_1) 122 | predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1) 123 | 124 | # Get the predicted last token 125 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 126 | predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 127 | assert predicted_token == 'who' 128 | """ 129 | model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs) 130 | return model 131 | -------------------------------------------------------------------------------- /hubconfs/xlm_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlm import XLMTokenizer 2 | from pytorch_transformers.modeling_xlm import ( 3 | XLMConfig, 4 | XLMModel, 5 | XLMWithLMHeadModel, 6 | XLMForSequenceClassification, 7 | XLMForQuestionAnswering 8 | ) 9 | 10 | # 
A lot of models share the same param doc. Use a decorator 11 | # to save typing 12 | xlm_start_docstring = """ 13 | Model class adapted from the XLM Transformer model of 14 | "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau 15 | Paper: https://arxiv.org/abs/1901.07291 16 | Original code: https://github.com/facebookresearch/XLM 17 | 18 | Example: 19 | # Load the tokenizer 20 | import torch 21 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 22 | 23 | # Prepare tokenized input 24 | text_1 = "Who was Jim Henson ?" 25 | text_2 = "Jim Henson was a puppeteer" 26 | indexed_tokens_1 = tokenizer.encode(text_1) 27 | indexed_tokens_2 = tokenizer.encode(text_2) 28 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 29 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 30 | """ 31 | 32 | # A lot of models share the same param doc. Use a decorator 33 | # to save typing 34 | xlm_end_docstring = """ 35 | Params: 36 | pretrained_model_name_or_path: either: 37 | - a str with the name of a pre-trained model to load selected in the list of: 38 | . `xlm-mlm-en-2048` 39 | - a path or url to a pretrained model archive containing: 40 | . `config.json` a configuration file for the model 41 | . `pytorch_model.bin` a PyTorch dump created using the `convert_xlm_checkpoint_to_pytorch` conversion script 42 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 
43 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 44 | *inputs, **kwargs: additional input for the specific XLM class 45 | """ 46 | 47 | 48 | def _begin_with_docstring(docstr): 49 | def docstring_decorator(fn): 50 | fn.__doc__ = fn.__doc__ + docstr 51 | return fn 52 | return docstring_decorator 53 | 54 | def _end_with_docstring(docstr): 55 | def docstring_decorator(fn): 56 | fn.__doc__ = fn.__doc__ + docstr 57 | return fn 58 | return docstring_decorator 59 | 60 | 61 | def xlmTokenizer(*args, **kwargs): 62 | """ 63 | Instantiate a XLM BPE tokenizer for XLM from a pre-trained vocab file. 64 | 65 | Args: 66 | pretrained_model_name_or_path: Path to pretrained model archive 67 | or one of pre-trained vocab configs below. 68 | * xlm-mlm-en-2048 69 | Keyword args: 70 | special_tokens: Special tokens in vocabulary that are not pretrained 71 | Default: None 72 | max_len: An artificial maximum length to truncate tokenized sequences to; 73 | Effective maximum length is always the minimum of this 74 | value (if specified) and the underlying model's 75 | sequence length. 76 | Default: None 77 | 78 | Example: 79 | import torch 80 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 81 | 82 | text = "Who was Jim Henson ?" 
83 | indexed_tokens = tokenizer.encode(tokenized_text) 84 | """ 85 | tokenizer = XLMTokenizer.from_pretrained(*args, **kwargs) 86 | return tokenizer 87 | 88 | 89 | @_begin_with_docstring(xlm_start_docstring) 90 | @_end_with_docstring(xlm_end_docstring) 91 | def xlmModel(*args, **kwargs): 92 | """ 93 | # Load xlmModel 94 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlmModel', 'xlm-mlm-en-2048') 95 | model.eval() 96 | 97 | # Predict hidden states features for each layer 98 | with torch.no_grad(): 99 | hidden_states_1, mems = model(tokens_tensor_1) 100 | hidden_states_2, mems = model(tokens_tensor_2, past=mems) 101 | """ 102 | model = XLMModel.from_pretrained(*args, **kwargs) 103 | return model 104 | 105 | 106 | @_begin_with_docstring(xlm_start_docstring) 107 | @_end_with_docstring(xlm_end_docstring) 108 | def xlmLMHeadModel(*args, **kwargs): 109 | """ 110 | # Prepare tokenized input 111 | text_1 = "Who was Jim Henson ?" 112 | text_2 = "Jim Henson was a puppeteer" 113 | indexed_tokens_1 = tokenizer.encode(text_1) 114 | indexed_tokens_2 = tokenizer.encode(text_2) 115 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load xlnetLMHeadModel 119 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetLMHeadModel', 'xlm-mlm-en-2048') 120 | model.eval() 121 | 122 | # Predict hidden states features for each layer 123 | with torch.no_grad(): 124 | predictions_1, mems = model(tokens_tensor_1) 125 | predictions_2, mems = model(tokens_tensor_2, mems=mems) 126 | 127 | # Get the predicted last token 128 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 129 | predicted_token = tokenizer.decode([predicted_index]) 130 | assert predicted_token == ' who' 131 | """ 132 | model = XLMWithLMHeadModel.from_pretrained(*args, **kwargs) 133 | return model 134 | 135 | 136 | # @_end_with_docstring(xlnet_docstring) 137 | # def xlnetForSequenceClassification(*args, **kwargs): 138 | # 
""" 139 | # xlnetModel is the basic XLNet Transformer model from 140 | # "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 141 | # by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 142 | 143 | # Example: 144 | # # Load the tokenizer 145 | # import torch 146 | # tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlm-mlm-en-2048') 147 | 148 | # # Prepare tokenized input 149 | # text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 150 | # text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 151 | # tokenized_text1 = tokenizer.tokenize(text1) 152 | # tokenized_text2 = tokenizer.tokenize(text2) 153 | # indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 154 | # indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 155 | # tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 156 | # mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 157 | 158 | # # Load xlnetForSequenceClassification 159 | # model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetForSequenceClassification', 'xlm-mlm-en-2048') 160 | # model.eval() 161 | 162 | # # Predict sequence classes logits 163 | # with torch.no_grad(): 164 | # lm_logits, mems = model(tokens_tensor) 165 | # """ 166 | # model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs) 167 | # return model 168 | -------------------------------------------------------------------------------- /hubconfs/xlnet_hubconf.1.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlnet import XLNetTokenizer 2 | from pytorch_transformers.modeling_xlnet import ( 3 | XLNetConfig, 4 | XLNetModel, 5 | XLNetLMHeadModel, 6 | # XLNetForSequenceClassification 7 | ) 8 | 9 | # A lot of models share the same param doc. 
Use a decorator 10 | # to save typing 11 | xlnet_docstring = """ 12 | Params: 13 | pretrained_model_name_or_path: either: 14 | - a str with the name of a pre-trained model to load selected in the list of: 15 | . `xlnet-large-cased` 16 | - a path or url to a pretrained model archive containing: 17 | . `config.json` a configuration file for the model 18 | . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance 19 | - a path or url to a pretrained model archive containing: 20 | . `xlnet_config.json` a configuration file for the model 21 | . `model.chkpt` a TensorFlow checkpoint 22 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 23 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 24 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 25 | *inputs, **kwargs: additional input for the specific XLNet class 26 | """ 27 | 28 | 29 | def _append_from_pretrained_docstring(docstr): 30 | def docstring_decorator(fn): 31 | fn.__doc__ = fn.__doc__ + docstr 32 | return fn 33 | return docstring_decorator 34 | 35 | 36 | def xlnetTokenizer(*args, **kwargs): 37 | """ 38 | Instantiate a XLNet sentencepiece tokenizer for XLNet from a pre-trained vocab file. 39 | Peculiarities: 40 | - require Google sentencepiece (https://github.com/google/sentencepiece) 41 | 42 | Args: 43 | pretrained_model_name_or_path: Path to pretrained model archive 44 | or one of pre-trained vocab configs below. 45 | * xlnet-large-cased 46 | Keyword args: 47 | special_tokens: Special tokens in vocabulary that are not pretrained 48 | Default: None 49 | max_len: An artificial maximum length to truncate tokenized sequences to; 50 | Effective maximum length is always the minimum of this 51 | value (if specified) and the underlying model's 52 | sequence length. 
53 | Default: None 54 | 55 | Example: 56 | import torch 57 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 58 | 59 | text = "Who was Jim Henson ?" 60 | indexed_tokens = tokenizer.encode(tokenized_text) 61 | """ 62 | tokenizer = XLNetTokenizer.from_pretrained(*args, **kwargs) 63 | return tokenizer 64 | 65 | 66 | @_append_from_pretrained_docstring(xlnet_docstring) 67 | def xlnetModel(*args, **kwargs): 68 | """ 69 | xlnetModel is the basic XLNet Transformer model from 70 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 71 | by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 72 | 73 | Example: 74 | # Load the tokenizer 75 | import torch 76 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 77 | 78 | # Prepare tokenized input 79 | text_1 = "Who was Jim Henson ?" 80 | text_2 = "Jim Henson was a puppeteer" 81 | indexed_tokens_1 = tokenizer.encode(text_1) 82 | indexed_tokens_2 = tokenizer.encode(text_2) 83 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 84 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 85 | 86 | # Load xlnetModel 87 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetModel', 'xlnet-large-cased') 88 | model.eval() 89 | 90 | # Predict hidden states features for each layer 91 | with torch.no_grad(): 92 | hidden_states_1, mems = model(tokens_tensor_1) 93 | hidden_states_2, mems = model(tokens_tensor_2, past=mems) 94 | """ 95 | model = XLNetModel.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(xlnet_docstring) 100 | def xlnetLMHeadModel(*args, **kwargs): 101 | """ 102 | xlnetModel is the basic XLNet Transformer model from 103 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 104 | by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. 
Le 105 | with a tied (pre-trained) language modeling head on top. 106 | 107 | Example: 108 | # Load the tokenizer 109 | import torch 110 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 111 | 112 | # Prepare tokenized input 113 | text_1 = "Who was Jim Henson ?" 114 | text_2 = "Jim Henson was a puppeteer" 115 | indexed_tokens_1 = tokenizer.encode(text_1) 116 | indexed_tokens_2 = tokenizer.encode(text_2) 117 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 118 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 119 | 120 | # Load xlnetLMHeadModel 121 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetLMHeadModel', 'xlnet-large-cased') 122 | model.eval() 123 | 124 | # Predict hidden states features for each layer 125 | with torch.no_grad(): 126 | predictions_1, mems = model(tokens_tensor_1) 127 | predictions_2, mems = model(tokens_tensor_2, mems=mems) 128 | 129 | # Get the predicted last token 130 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 131 | predicted_token = tokenizer.decode([predicted_index]) 132 | assert predicted_token == ' who' 133 | """ 134 | model = XLNetLMHeadModel.from_pretrained(*args, **kwargs) 135 | return model 136 | 137 | 138 | # @_append_from_pretrained_docstring(xlnet_docstring) 139 | # def xlnetForSequenceClassification(*args, **kwargs): 140 | # """ 141 | # xlnetModel is the basic XLNet Transformer model from 142 | # "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 143 | # by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 144 | 145 | # Example: 146 | # # Load the tokenizer 147 | # import torch 148 | # tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 149 | 150 | # # Prepare tokenized input 151 | # text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 152 | # text2 = "Who was Jim Henson ? 
Jim Henson was a mysterious young man" 153 | # tokenized_text1 = tokenizer.tokenize(text1) 154 | # tokenized_text2 = tokenizer.tokenize(text2) 155 | # indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 156 | # indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 157 | # tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 158 | # mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 159 | 160 | # # Load xlnetForSequenceClassification 161 | # model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetForSequenceClassification', 'xlnet-large-cased') 162 | # model.eval() 163 | 164 | # # Predict sequence classes logits 165 | # with torch.no_grad(): 166 | # lm_logits, mems = model(tokens_tensor) 167 | # """ 168 | # model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs) 169 | # return model 170 | -------------------------------------------------------------------------------- /pretrained_model/readme.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pretrained_model/readme.txt -------------------------------------------------------------------------------- /pytorch_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | from .tokenization_auto import AutoTokenizer 3 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 4 | from .tokenization_openai import OpenAIGPTTokenizer 5 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 6 | from .tokenization_gpt2 import GPT2Tokenizer 7 | from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE 8 | from .tokenization_xlm import XLMTokenizer 9 | from .tokenization_roberta import RobertaTokenizer 10 | 11 | from .tokenization_utils import (PreTrainedTokenizer) 12 
| 13 | from .modeling_auto import (AutoConfig, AutoModel) 14 | 15 | from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining, 16 | BertForMaskedLM, BertForNextSentencePrediction, 17 | BertForSequenceClassification, BertForMultipleChoice, 18 | BertForTokenClassification, BertForQuestionAnswering, 19 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 20 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 21 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel, 22 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 23 | load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, 24 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) 25 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, 26 | load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, 27 | TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 28 | from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model, 29 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 30 | load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, 31 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) 32 | from .modeling_xlnet import (XLNetConfig, 33 | XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, 34 | XLNetForSequenceClassification, XLNetForQuestionAnswering, 35 | load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 36 | XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) 37 | from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, 38 | XLMWithLMHeadModel, XLMForSequenceClassification, 39 | XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, 40 | XLM_PRETRAINED_MODEL_ARCHIVE_MAP) 41 | from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, 42 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) 43 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, 44 | PretrainedConfig, 
PreTrainedModel, prune_layer, Conv1D) 45 | 46 | from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, 47 | WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 48 | 49 | from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path) 50 | -------------------------------------------------------------------------------- /pytorch_transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: 5 | print( 6 | "Should be used as one of: \n" 7 | ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" 8 | ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" 9 | ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" 10 | ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" 11 | ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" 12 | ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") 13 | else: 14 | if sys.argv[1] == "bert": 15 | try: 16 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 17 | except ImportError: 18 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 19 | "In that case, it requires TensorFlow to be installed. 
Please see " 20 | "https://www.tensorflow.org/install/ for installation instructions.") 21 | raise 22 | 23 | if len(sys.argv) != 5: 24 | # pylint: disable=line-too-long 25 | print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 26 | else: 27 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 28 | TF_CONFIG = sys.argv.pop() 29 | TF_CHECKPOINT = sys.argv.pop() 30 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 31 | elif sys.argv[1] == "gpt": 32 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 33 | if len(sys.argv) < 4 or len(sys.argv) > 5: 34 | # pylint: disable=line-too-long 35 | print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") 36 | else: 37 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 38 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 39 | if len(sys.argv) == 5: 40 | OPENAI_GPT_CONFIG = sys.argv[4] 41 | else: 42 | OPENAI_GPT_CONFIG = "" 43 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 44 | OPENAI_GPT_CONFIG, 45 | PYTORCH_DUMP_OUTPUT) 46 | elif sys.argv[1] == "transfo_xl": 47 | try: 48 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 49 | except ImportError: 50 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 51 | "In that case, it requires TensorFlow to be installed. 
Please see " 52 | "https://www.tensorflow.org/install/ for installation instructions.") 53 | raise 54 | if len(sys.argv) < 4 or len(sys.argv) > 5: 55 | # pylint: disable=line-too-long 56 | print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 57 | else: 58 | if 'ckpt' in sys.argv[2].lower(): 59 | TF_CHECKPOINT = sys.argv[2] 60 | TF_DATASET_FILE = "" 61 | else: 62 | TF_DATASET_FILE = sys.argv[2] 63 | TF_CHECKPOINT = "" 64 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 65 | if len(sys.argv) == 5: 66 | TF_CONFIG = sys.argv[4] 67 | else: 68 | TF_CONFIG = "" 69 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 70 | elif sys.argv[1] == "gpt2": 71 | try: 72 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 73 | except ImportError: 74 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 75 | "In that case, it requires TensorFlow to be installed. Please see " 76 | "https://www.tensorflow.org/install/ for installation instructions.") 77 | raise 78 | 79 | if len(sys.argv) < 4 or len(sys.argv) > 5: 80 | # pylint: disable=line-too-long 81 | print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 82 | else: 83 | TF_CHECKPOINT = sys.argv[2] 84 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 85 | if len(sys.argv) == 5: 86 | TF_CONFIG = sys.argv[4] 87 | else: 88 | TF_CONFIG = "" 89 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 90 | elif sys.argv[1] == "xlnet": 91 | try: 92 | from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch 93 | except ImportError: 94 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 95 | "In that case, it requires TensorFlow to be installed. 
Please see " 96 | "https://www.tensorflow.org/install/ for installation instructions.") 97 | raise 98 | 99 | if len(sys.argv) < 5 or len(sys.argv) > 6: 100 | # pylint: disable=line-too-long 101 | print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") 102 | else: 103 | TF_CHECKPOINT = sys.argv[2] 104 | TF_CONFIG = sys.argv[3] 105 | PYTORCH_DUMP_OUTPUT = sys.argv[4] 106 | if len(sys.argv) == 6: 107 | FINETUNING_TASK = sys.argv[5] 108 | else: 109 | FINETUNING_TASK = None 110 | 111 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, 112 | TF_CONFIG, 113 | PYTORCH_DUMP_OUTPUT, 114 | FINETUNING_TASK) 115 | elif sys.argv[1] == "xlm": 116 | from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch 117 | 118 | if len(sys.argv) != 4: 119 | # pylint: disable=line-too-long 120 | print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") 121 | else: 122 | XLM_CHECKPOINT_PATH = sys.argv[2] 123 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 124 | 125 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/file_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/file_utils.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_auto.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_auto.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_auto.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_auto.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_bert.cpython-36.pyc 
-------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_bert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_bert.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_openai.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_openai.cpython-37.pyc 
-------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_roberta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_roberta.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_roberta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_roberta.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_transfo_xl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_transfo_xl_utilities.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_utils.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_xlm.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_xlm.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_xlnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_xlnet.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/modeling_xlnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/modeling_xlnet.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/optimization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/optimization.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_auto.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_auto.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_auto.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_auto.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_bert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_bert.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_bert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_bert.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_gpt2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_gpt2.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_gpt2.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_gpt2.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_openai.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_openai.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_openai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_openai.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_roberta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_roberta.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_roberta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_roberta.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_transfo_xl.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_transfo_xl.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_transfo_xl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_transfo_xl.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_utils.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_xlm.cpython-36.pyc -------------------------------------------------------------------------------- 
/pytorch_transformers/__pycache__/tokenization_xlm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_xlm.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_xlnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_xlnet.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_transformers/__pycache__/tokenization_xlnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/__pycache__/tokenization_xlnet.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained 
OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_pytorch_checkpoint_to_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from pytorch_transformers.modeling import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | 
tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the 
pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | 28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from pytorch_transformers.modeling_transfo_xl import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | 
corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 
106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 
| help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = parser.parse_args() 75 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | 
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. \n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /pytorch_transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 
52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 
89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | 107 | class AdamW(Optimizer): 108 | """ Implements Adam algorithm with weight decay fix. 109 | 110 | Parameters: 111 | lr (float): learning rate. Default 1e-3. 112 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 113 | eps (float): Adams epsilon. Default: 1e-6 114 | weight_decay (float): Weight decay. Default: 0.0 115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
116 | """ 117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 118 | if lr < 0.0: 119 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 120 | if not 0.0 <= betas[0] < 1.0: 121 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 122 | if not 0.0 <= betas[1] < 1.0: 123 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 124 | if not 0.0 <= eps: 125 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 126 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 127 | correct_bias=correct_bias) 128 | super(AdamW, self).__init__(params, defaults) 129 | 130 | def step(self, closure=None): 131 | """Performs a single optimization step. 132 | 133 | Arguments: 134 | closure (callable, optional): A closure that reevaluates the model 135 | and returns the loss. 136 | """ 137 | loss = None 138 | if closure is not None: 139 | loss = closure() 140 | 141 | for group in self.param_groups: 142 | for p in group['params']: 143 | if p.grad is None: 144 | continue 145 | grad = p.grad.data 146 | if grad.is_sparse: 147 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 148 | 149 | state = self.state[p] 150 | 151 | # State initialization 152 | if len(state) == 0: 153 | state['step'] = 0 154 | # Exponential moving average of gradient values 155 | state['exp_avg'] = torch.zeros_like(p.data) 156 | # Exponential moving average of squared gradient values 157 | state['exp_avg_sq'] = torch.zeros_like(p.data) 158 | 159 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 160 | beta1, beta2 = group['betas'] 161 | 162 | state['step'] += 1 163 | 164 | # Decay the first and second moment running average coefficient 165 | # In-place operations to update the averages at the same time 166 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 167 | 
exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 168 | denom = exp_avg_sq.sqrt().add_(group['eps']) 169 | 170 | step_size = group['lr'] 171 | if group['correct_bias']: # No bias correction for Bert 172 | bias_correction1 = 1.0 - beta1 ** state['step'] 173 | bias_correction2 = 1.0 - beta2 ** state['step'] 174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 175 | 176 | p.data.addcdiv_(-step_size, exp_avg, denom) 177 | 178 | # Just adding the square of the weights to the loss function is *not* 179 | # the correct way of using L2 regularization/weight decay with Adam, 180 | # since that will interact with the m and v parameters in strange ways. 181 | # 182 | # Instead we want to decay the weights in a manner that doesn't interact 183 | # with the m/v parameters. This is equivalent to adding the square 184 | # of the weights to the loss with plain (non-momentum) SGD. 185 | # Add weight decay at the end (fixed version) 186 | if group['weight_decay'] > 0.0: 187 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 188 | 189 | return loss 190 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/tests/__init__.py -------------------------------------------------------------------------------- /pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | 
return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 
34 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaow/TextClassification/aa479ae0941c008602631c50124d8c07d159bfb1/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | 27 | from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) 28 | 29 | 30 | class AutoModelTest(unittest.TestCase): 31 | def test_model_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 34 | config = AutoConfig.from_pretrained(model_name) 35 | self.assertIsNotNone(config) 36 | self.assertIsInstance(config, BertConfig) 37 | 38 | model = AutoModel.from_pretrained(model_name) 39 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 40 | self.assertIsNotNone(model) 41 | self.assertIsInstance(model, BertModel) 42 | for value in loading_info.values(): 43 | self.assertEqual(len(value), 0) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (GPT2Config, GPT2Model, 24 | GPT2LMHeadModel, GPT2DoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class GPT2ModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 36 | lm_head_model_class=GPT2LMHeadModel, 37 | double_head_model_class=GPT2DoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=True) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 43 | lm_head_model_class=GPT2LMHeadModel, 44 | double_head_model_class=GPT2DoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, 24 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class OpenAIModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 36 | lm_head_model_class=OpenAIGPTLMHeadModel, 37 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=False) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 43 | lm_head_model_class=OpenAIGPTLMHeadModel, 44 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language 
Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 25 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 26 | 27 | from .tokenization_tests_commons import TemporaryDirectory 28 | 29 | 30 | def unwrap_schedule(scheduler, num_steps=10): 31 | lrs = [] 32 | for _ in range(num_steps): 33 | scheduler.step() 34 | lrs.append(scheduler.get_lr()) 35 | return lrs 36 | 37 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 38 | lrs = [] 39 | for step in range(num_steps): 40 | scheduler.step() 41 | lrs.append(scheduler.get_lr()) 42 | if step == num_steps // 2: 43 | with TemporaryDirectory() as tmpdirname: 44 | file_name = os.path.join(tmpdirname, 'schedule.bin') 45 | torch.save(scheduler.state_dict(), file_name) 46 | 47 | state_dict = torch.load(file_name) 48 | scheduler.load_state_dict(state_dict) 49 | return lrs 50 | 51 | class OptimizationTest(unittest.TestCase): 52 | 53 | def assertListAlmostEqual(self, list1, list2, tol): 54 | self.assertEqual(len(list1), len(list2)) 55 | for a, b in zip(list1, list2): 56 | self.assertAlmostEqual(a, b, delta=tol) 57 | 58 | def test_adam_w(self): 59 | w = 
torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 60 | target = torch.tensor([0.4, 0.2, -0.5]) 61 | criterion = torch.nn.MSELoss() 62 | # No warmup, constant schedule, no gradient clipping 63 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 64 | for _ in range(100): 65 | loss = criterion(w, target) 66 | loss.backward() 67 | optimizer.step() 68 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 69 | w.grad.zero_() 70 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 71 | 72 | 73 | class ScheduleInitTest(unittest.TestCase): 74 | m = torch.nn.Linear(50, 50) 75 | optimizer = AdamW(m.parameters(), lr=10.) 76 | num_steps = 10 77 | 78 | def assertListAlmostEqual(self, list1, list2, tol): 79 | self.assertEqual(len(list1), len(list2)) 80 | for a, b in zip(list1, list2): 81 | self.assertAlmostEqual(a, b, delta=tol) 82 | 83 | def test_constant_scheduler(self): 84 | scheduler = ConstantLRSchedule(self.optimizer) 85 | lrs = unwrap_schedule(scheduler, self.num_steps) 86 | expected_learning_rates = [10.] 
* self.num_steps 87 | self.assertEqual(len(lrs[0]), 1) 88 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 89 | 90 | scheduler = ConstantLRSchedule(self.optimizer) 91 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 92 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 93 | 94 | def test_warmup_constant_scheduler(self): 95 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 96 | lrs = unwrap_schedule(scheduler, self.num_steps) 97 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 98 | self.assertEqual(len(lrs[0]), 1) 99 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 100 | 101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 102 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 103 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 104 | 105 | def test_warmup_linear_scheduler(self): 106 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 107 | lrs = unwrap_schedule(scheduler, self.num_steps) 108 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 109 | self.assertEqual(len(lrs[0]), 1) 110 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 111 | 112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 113 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 114 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 115 | 116 | def test_warmup_cosine_scheduler(self): 117 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 118 | lrs = unwrap_schedule(scheduler, self.num_steps) 119 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 120 | self.assertEqual(len(lrs[0]), 1) 121 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 122 | 123 | scheduler = 
WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 124 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 125 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 126 | 127 | def test_warmup_cosine_hard_restart_scheduler(self): 128 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 129 | lrs = unwrap_schedule(scheduler, self.num_steps) 130 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 131 | self.assertEqual(len(lrs[0]), 1) 132 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 133 | 134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 135 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 136 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | 29 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 30 | 31 | tokenizer_class = BertTokenizer 32 | 33 | def setUp(self): 34 | super(BertTokenizationTest, self).setUp() 35 | 36 | vocab_tokens = [ 37 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 38 | "##ing", ",", "low", "lowest", 39 | ] 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 42 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 43 | 44 | def get_tokenizer(self): 45 | return BertTokenizer.from_pretrained(self.tmpdirname) 46 | 47 | def get_input_output_texts(self): 48 | input_text = u"UNwant\u00E9d,running" 49 | output_text = u"unwanted, running" 50 | return input_text, output_text 51 | 52 | def test_full_tokenizer(self): 53 | tokenizer = BertTokenizer(self.vocab_file) 54 | 55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 57 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 
5, 10, 8, 9]) 58 | 59 | def test_chinese(self): 60 | tokenizer = BasicTokenizer() 61 | 62 | self.assertListEqual( 63 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 64 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 65 | 66 | def test_basic_tokenizer_lower(self): 67 | tokenizer = BasicTokenizer(do_lower_case=True) 68 | 69 | self.assertListEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 71 | ["hello", "!", "how", "are", "you", "?"]) 72 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 73 | 74 | def test_basic_tokenizer_no_lower(self): 75 | tokenizer = BasicTokenizer(do_lower_case=False) 76 | 77 | self.assertListEqual( 78 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 79 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 80 | 81 | def test_wordpiece_tokenizer(self): 82 | vocab_tokens = [ 83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 84 | "##ing" 85 | ] 86 | 87 | vocab = {} 88 | for (i, token) in enumerate(vocab_tokens): 89 | vocab[token] = i 90 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 91 | 92 | self.assertListEqual(tokenizer.tokenize(""), []) 93 | 94 | self.assertListEqual( 95 | tokenizer.tokenize("unwanted running"), 96 | ["un", "##want", "##ed", "runn", "##ing"]) 97 | 98 | self.assertListEqual( 99 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 100 | 101 | def test_is_whitespace(self): 102 | self.assertTrue(_is_whitespace(u" ")) 103 | self.assertTrue(_is_whitespace(u"\t")) 104 | self.assertTrue(_is_whitespace(u"\r")) 105 | self.assertTrue(_is_whitespace(u"\n")) 106 | self.assertTrue(_is_whitespace(u"\u00A0")) 107 | 108 | self.assertFalse(_is_whitespace(u"A")) 109 | self.assertFalse(_is_whitespace(u"-")) 110 | 111 | def test_is_control(self): 112 | self.assertTrue(_is_control(u"\u0005")) 113 | 114 | self.assertFalse(_is_control(u"A")) 115 | self.assertFalse(_is_control(u" ")) 116 | self.assertFalse(_is_control(u"\t")) 117 | self.assertFalse(_is_control(u"\r")) 118 | 119 
| def test_is_punctuation(self): 120 | self.assertTrue(_is_punctuation(u"-")) 121 | self.assertTrue(_is_punctuation(u"$")) 122 | self.assertTrue(_is_punctuation(u"`")) 123 | self.assertTrue(_is_punctuation(u".")) 124 | 125 | self.assertFalse(_is_punctuation(u"A")) 126 | self.assertFalse(_is_punctuation(u" ")) 127 | 128 | def test_sequence_builders(self): 129 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 130 | 131 | text = tokenizer.encode("sequence builders") 132 | text_2 = tokenizer.encode("multi-sequence build") 133 | 134 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 135 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 136 | 137 | assert encoded_sentence == [101] + text + [102] 138 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 139 | 140 | if __name__ == '__main__': 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = GPT2Tokenizer 28 | 29 | def setUp(self): 30 | super(GPT2TokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "lo", "low", "er", 35 | "low", "lowest", "newer", "wider", ""] 36 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 37 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 38 | self.special_tokens_map = {"unk_token": ""} 39 | 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 42 | with open(self.vocab_file, "w") as fp: 43 | fp.write(json.dumps(vocab_tokens)) 44 | with open(self.merges_file, "w") as fp: 45 | fp.write("\n".join(merges)) 46 | 47 | def get_tokenizer(self): 48 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) 49 | 50 | def get_input_output_texts(self): 51 | input_text = u"lower newer" 52 | output_text = u"lowernewer" 53 | return input_text, output_text 54 | 55 | def test_full_tokenizer(self): 56 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 57 | text = "lower" 58 | bpe_tokens = ["low", "er"] 59 | tokens = tokenizer.tokenize(text) 60 | self.assertListEqual(tokens, bpe_tokens) 61 | 62 | input_tokens = tokens + [tokenizer.unk_token] 63 | input_bpe_tokens = [13, 12, 17] 64 | self.assertListEqual( 65 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 66 | 67 | 68 | if __name__ == '__main__': 69 
| unittest.main() 70 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w", "r", "t", 36 | "lo", "low", "er", 37 | "low", "lowest", "newer", "wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self): 49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [""] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_roberta_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | 21 | from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | 25 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | tokenizer_class = RobertaTokenizer 27 | 28 | def setUp(self): 29 | super(RobertaTokenizationTest, self).setUp() 30 | 31 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 32 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 33 | "lo", "low", "er", 34 | "low", "lowest", "newer", "wider", ""] 35 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 36 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 37 | self.special_tokens_map = {"unk_token": ""} 38 | 39 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 40 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 41 | with open(self.vocab_file, "w") as fp: 42 | fp.write(json.dumps(vocab_tokens)) 43 | with open(self.merges_file, "w") as fp: 44 | fp.write("\n".join(merges)) 45 | 46 | def get_tokenizer(self): 47 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) 48 | 49 | def get_input_output_texts(self): 50 | input_text = u"lower newer" 51 | output_text = u"lowernewer" 52 | return input_text, output_text 53 | 54 | def test_full_tokenizer(self): 55 | 
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 56 | text = "lower" 57 | bpe_tokens = ["low", "er"] 58 | tokens = tokenizer.tokenize(text) 59 | self.assertListEqual(tokens, bpe_tokens) 60 | 61 | input_tokens = tokens + [tokenizer.unk_token] 62 | input_bpe_tokens = [13, 12, 17] 63 | self.assertListEqual( 64 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 65 | 66 | def roberta_dict_integration_testing(self): 67 | tokenizer = self.get_tokenizer() 68 | 69 | self.assertListEqual( 70 | tokenizer.encode('Hello world!'), 71 | [0, 31414, 232, 328, 2] 72 | ) 73 | self.assertListEqual( 74 | tokenizer.encode('Hello world! cécé herlolip 418'), 75 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 76 | ) 77 | 78 | def test_sequence_builders(self): 79 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 80 | 81 | text = tokenizer.encode("sequence builders") 82 | text_2 = tokenizer.encode("multi-sequence build") 83 | 84 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 85 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 86 | 87 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 88 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 89 | 90 | assert encoded_sentence == encoded_text_from_decode 91 | assert encoded_pair == encoded_pair_from_decode 92 | 93 | 94 | if __name__ == '__main__': 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_tests_commons.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import sys 19 | from io import open 20 | import tempfile 21 | import shutil 22 | import unittest 23 | 24 | if sys.version_info[0] == 2: 25 | import cPickle as pickle 26 | 27 | class TemporaryDirectory(object): 28 | """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" 29 | def __enter__(self): 30 | self.name = tempfile.mkdtemp() 31 | return self.name 32 | def __exit__(self, exc_type, exc_value, traceback): 33 | shutil.rmtree(self.name) 34 | else: 35 | import pickle 36 | TemporaryDirectory = tempfile.TemporaryDirectory 37 | unicode = str 38 | 39 | 40 | class CommonTestCases: 41 | 42 | class CommonTokenizerTester(unittest.TestCase): 43 | 44 | tokenizer_class = None 45 | 46 | def setUp(self): 47 | self.tmpdirname = tempfile.mkdtemp() 48 | 49 | def tearDown(self): 50 | shutil.rmtree(self.tmpdirname) 51 | 52 | def get_tokenizer(self): 53 | raise NotImplementedError 54 | 55 | def get_input_output_texts(self): 56 | raise NotImplementedError 57 | 58 | def test_save_and_load_tokenizer(self): 59 | tokenizer = self.get_tokenizer() 60 | 61 | before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 62 | 63 | with TemporaryDirectory() as tmpdirname: 64 | tokenizer.save_pretrained(tmpdirname) 65 | tokenizer = tokenizer.from_pretrained(tmpdirname) 66 | 67 | after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 68 | self.assertListEqual(before_tokens, 
after_tokens) 69 | 70 | def test_pickle_tokenizer(self): 71 | tokenizer = self.get_tokenizer() 72 | self.assertIsNotNone(tokenizer) 73 | 74 | text = u"Munich and Berlin are nice cities" 75 | subwords = tokenizer.tokenize(text) 76 | 77 | with TemporaryDirectory() as tmpdirname: 78 | 79 | filename = os.path.join(tmpdirname, u"tokenizer.bin") 80 | pickle.dump(tokenizer, open(filename, "wb")) 81 | 82 | tokenizer_new = pickle.load(open(filename, "rb")) 83 | 84 | subwords_loaded = tokenizer_new.tokenize(text) 85 | 86 | self.assertListEqual(subwords, subwords_loaded) 87 | 88 | 89 | def test_add_tokens_tokenizer(self): 90 | tokenizer = self.get_tokenizer() 91 | 92 | vocab_size = tokenizer.vocab_size 93 | all_size = len(tokenizer) 94 | 95 | self.assertNotEqual(vocab_size, 0) 96 | self.assertEqual(vocab_size, all_size) 97 | 98 | new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] 99 | added_toks = tokenizer.add_tokens(new_toks) 100 | vocab_size_2 = tokenizer.vocab_size 101 | all_size_2 = len(tokenizer) 102 | 103 | self.assertNotEqual(vocab_size_2, 0) 104 | self.assertEqual(vocab_size, vocab_size_2) 105 | self.assertEqual(added_toks, len(new_toks)) 106 | self.assertEqual(all_size_2, all_size + len(new_toks)) 107 | 108 | tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") 109 | self.assertGreaterEqual(len(tokens), 4) 110 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 111 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 112 | 113 | new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", 114 | 'pad_token': "<<<<<|||>|>>>>|>"} 115 | added_toks_2 = tokenizer.add_special_tokens(new_toks_2) 116 | vocab_size_3 = tokenizer.vocab_size 117 | all_size_3 = len(tokenizer) 118 | 119 | self.assertNotEqual(vocab_size_3, 0) 120 | self.assertEqual(vocab_size, vocab_size_3) 121 | self.assertEqual(added_toks_2, len(new_toks_2)) 122 | self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) 123 | 124 | tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low 
cccccccccdddddddd <<<<<|||>|>>>>|> l") 125 | 126 | self.assertGreaterEqual(len(tokens), 6) 127 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 128 | self.assertGreater(tokens[0], tokens[1]) 129 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 130 | self.assertGreater(tokens[-2], tokens[-3]) 131 | self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) 132 | self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) 133 | 134 | 135 | def test_required_methods_tokenizer(self): 136 | tokenizer = self.get_tokenizer() 137 | input_text, output_text = self.get_input_output_texts() 138 | 139 | tokens = tokenizer.tokenize(input_text) 140 | ids = tokenizer.convert_tokens_to_ids(tokens) 141 | ids_2 = tokenizer.encode(input_text) 142 | self.assertListEqual(ids, ids_2) 143 | 144 | tokens_2 = tokenizer.convert_ids_to_tokens(ids) 145 | text_2 = tokenizer.decode(ids) 146 | 147 | self.assertEqual(text_2, output_text) 148 | 149 | self.assertNotEqual(len(tokens_2), 0) 150 | self.assertIsInstance(text_2, (str, unicode)) 151 | 152 | 153 | def test_pretrained_model_lists(self): 154 | weights_list = list(self.tokenizer_class.max_model_input_sizes.keys()) 155 | weights_lists_2 = [] 156 | for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items(): 157 | weights_lists_2.append(list(map_list.keys())) 158 | 159 | for weights_list_2 in weights_lists_2: 160 | self.assertListEqual(weights_list, weights_list_2) 161 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES 22 | 23 | from.tokenization_tests_commons import CommonTestCases 24 | 25 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = TransfoXLTokenizer 28 | 29 | def setUp(self): 30 | super(TransfoXLTokenizationTest, self).setUp() 31 | 32 | vocab_tokens = [ 33 | "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", 34 | "running", ",", "low", "l", 35 | ] 36 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 37 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 38 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 39 | 40 | def get_tokenizer(self): 41 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True) 42 | 43 | def get_input_output_texts(self): 44 | input_text = u" UNwanted , running" 45 | output_text = u" unwanted, running" 46 | return input_text, output_text 47 | 48 | def test_full_tokenizer(self): 49 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) 50 | 51 | tokens = tokenizer.tokenize(u" UNwanted , running") 52 | self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) 53 | 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 56 | 57 | def test_full_tokenizer_lower(self): 
58 | tokenizer = TransfoXLTokenizer(lower_case=True) 59 | 60 | self.assertListEqual( 61 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 62 | ["hello", "!", "how", "are", "you", "?"]) 63 | 64 | def test_full_tokenizer_no_lower(self): 65 | tokenizer = TransfoXLTokenizer(lower_case=False) 66 | 67 | self.assertListEqual( 68 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 69 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from pytorch_transformers import PreTrainedTokenizer 23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | class TokenizerUtilsTest(unittest.TestCase): 26 | def check_tokenizer_from_pretrained(self, tokenizer_class): 27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 28 | for model_name in s3_models[:1]: 29 | tokenizer = tokenizer_class.from_pretrained(model_name) 30 | self.assertIsNotNone(tokenizer) 31 | self.assertIsInstance(tokenizer, tokenizer_class) 32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 33 | 34 | for special_tok in tokenizer.all_special_tokens: 35 | if six.PY2: 36 | self.assertIsInstance(special_tok, unicode) 37 | else: 38 | self.assertIsInstance(special_tok, str) 39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 40 | self.assertIsInstance(special_tok_id, int) 41 | 42 | def test_pretrained_tokenizers(self): 43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = XLMTokenizer 28 | 29 | def setUp(self): 30 | super(XLMTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 34 | "w", "r", "t", 35 | "lo", "low", "er", 36 | "low", "lowest", "newer", "wider", ""] 37 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 38 | merges = ["l o 123", "lo w 1456", "e r 1789", ""] 39 | 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 42 | with open(self.vocab_file, "w") as fp: 43 | fp.write(json.dumps(vocab_tokens)) 44 | with open(self.merges_file, "w") as fp: 45 | fp.write("\n".join(merges)) 46 | 47 | def get_tokenizer(self): 48 | return XLMTokenizer.from_pretrained(self.tmpdirname) 49 | 50 | def get_input_output_texts(self): 51 | input_text = u"lower newer" 52 | output_text = u"lower newer" 53 | return input_text, output_text 54 | 55 | def test_full_tokenizer(self): 56 | """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ 57 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) 58 | 59 | text = "lower" 60 | bpe_tokens = ["low", "er"] 61 | tokens = tokenizer.tokenize(text) 62 | self.assertListEqual(tokens, bpe_tokens) 63 | 64 | input_tokens = tokens + [""] 65 | input_bpe_tokens = [14, 15, 20] 66 | self.assertListEqual( 67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 68 | 69 | def test_sequence_builders(self): 70 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") 71 | 72 | text = tokenizer.encode("sequence builders") 73 | text_2 = tokenizer.encode("multi-sequence build") 74 | 75 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 76 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 77 | 78 | assert encoded_sentence == [1] + text + [1] 79 | assert encoded_pair == [1] + text + [1] + text_2 + [1] 80 | 81 | if __name__ == '__main__': 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/test_sentencepiece.model') 26 | 27 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = XLNetTokenizer 30 | 31 | def setUp(self): 32 | super(XLNetTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self): 39 | return XLNetTokenizer.from_pretrained(self.tmpdirname) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"This is a test" 43 | output_text = u"This is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 61 | ids = tokenizer.convert_tokens_to_ids(tokens) 62 | self.assertListEqual( 63 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 64 | 602, 347, 347, 347, 3, 12, 66, 65 | 46, 72, 80, 6, 0, 4]) 66 
| 67 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 69 | u'or', u'n', SPIECE_UNDERLINE + u'in', 70 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 73 | u'', u'.']) 74 | 75 | def test_tokenizer_lower(self): 76 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 77 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 78 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 79 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 80 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 81 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 83 | 84 | def test_tokenizer_no_lower(self): 85 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 86 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 87 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 88 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 89 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 90 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 91 | 92 | def test_sequence_builders(self): 93 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 94 | 95 | text = tokenizer.encode("sequence builders") 96 | text_2 = tokenizer.encode("multi-sequence build") 97 | 98 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 99 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 100 | 101 | assert 
encoded_sentence == text + [4, 3] 102 | assert encoded_pair == text + [4] + text_2 + [4, 3] 103 | 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Model class. """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import logging 20 | 21 | from .tokenization_bert import BertTokenizer 22 | from .tokenization_openai import OpenAIGPTTokenizer 23 | from .tokenization_gpt2 import GPT2Tokenizer 24 | from .tokenization_transfo_xl import TransfoXLTokenizer 25 | from .tokenization_xlnet import XLNetTokenizer 26 | from .tokenization_xlm import XLMTokenizer 27 | from .tokenization_roberta import RobertaTokenizer 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class AutoTokenizer(object): 32 | r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class 33 | that will be instantiated as one of the tokenizer classes of the library 34 | when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` 35 | class method. 
36 | 37 | The `from_pretrained()` method take care of returning the correct tokenizer class instance 38 | using pattern matching on the `pretrained_model_name_or_path` string. 39 | 40 | The tokenizer class to instantiate is selected as the first pattern matching 41 | in the `pretrained_model_name_or_path` string (in the following order): 42 | - contains `bert`: BertTokenizer (Bert model) 43 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 44 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 45 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 46 | - contains `xlnet`: XLNetTokenizer (XLNet model) 47 | - contains `xlm`: XLMTokenizer (XLM model) 48 | - contains `roberta`: RobertaTokenizer (RoBERTa model) 49 | 50 | This class cannot be instantiated using `__init__()` (throw an error). 51 | """ 52 | def __init__(self): 53 | raise EnvironmentError("AutoTokenizer is designed to be instantiated " 54 | "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.") 55 | 56 | @classmethod 57 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 58 | r""" Instantiate a one of the tokenizer classes of the library 59 | from a pre-trained model vocabulary. 
60 | 61 | The tokenizer class to instantiate is selected as the first pattern matching 62 | in the `pretrained_model_name_or_path` string (in the following order): 63 | - contains `bert`: BertTokenizer (Bert model) 64 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 65 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 66 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 67 | - contains `xlnet`: XLNetTokenizer (XLNet model) 68 | - contains `xlm`: XLMTokenizer (XLM model) 69 | - contains `roberta`: RobertaTokenizer (XLM model) 70 | 71 | Params: 72 | **pretrained_model_name_or_path**: either: 73 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache 74 | or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). 75 | - a path to a `directory` containing a configuration file saved 76 | using the `save_pretrained(save_directory)` method. 77 | - a path or url to a saved configuration `file`. 78 | **cache_dir**: (`optional`) string: 79 | Path to a directory in which a downloaded pre-trained model 80 | configuration should be cached if the standard cache should not be used. 81 | 82 | Examples:: 83 | 84 | config = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache. 85 | config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. 
tokenizer was saved using `save_pretrained('./test/saved_model/')` 86 | 87 | """ 88 | if 'roberta' in pretrained_model_name_or_path: 89 | return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 90 | elif 'bert' in pretrained_model_name_or_path: 91 | return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 92 | elif 'openai-gpt' in pretrained_model_name_or_path: 93 | return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 94 | elif 'gpt2' in pretrained_model_name_or_path: 95 | return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 96 | elif 'transfo-xl' in pretrained_model_name_or_path: 97 | return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 98 | elif 'xlnet' in pretrained_model_name_or_path: 99 | return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 100 | elif 'xlm' in pretrained_model_name_or_path: 101 | return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 102 | 103 | raise ValueError("Unrecognized model identifier in {}. Should contains one of " 104 | "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " 105 | "'xlm', 'roberta'".format(pretrained_model_name_or_path)) 106 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .tokenization_utils import PreTrainedTokenizer 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | VOCAB_FILES_NAMES = { 39 | 'vocab_file': 'vocab.json', 40 | 'merges_file': 'merges.txt', 41 | } 42 | 43 | PRETRAINED_VOCAB_FILES_MAP = { 44 | 'vocab_file': 45 | { 46 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 47 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 48 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", 49 | }, 50 | 'merges_file': 51 | { 52 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 53 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 54 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", 55 | }, 56 | } 57 | 58 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 59 | 'gpt2': 1024, 60 | 'gpt2-medium': 1024, 
61 | 'gpt2-large': 1024, 62 | } 63 | 64 | @lru_cache() 65 | def bytes_to_unicode(): 66 | """ 67 | Returns list of utf-8 byte and a corresponding list of unicode strings. 68 | The reversible bpe codes work on unicode strings. 69 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 70 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 71 | This is a signficant percentage of your normal, say, 32K bpe vocab. 72 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 73 | And avoids mapping to whitespace/control characters the bpe code barfs on. 74 | """ 75 | _chr = unichr if sys.version_info[0] == 2 else chr 76 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 77 | cs = bs[:] 78 | n = 0 79 | for b in range(2**8): 80 | if b not in bs: 81 | bs.append(b) 82 | cs.append(2**8+n) 83 | n += 1 84 | cs = [_chr(n) for n in cs] 85 | return dict(zip(bs, cs)) 86 | 87 | def get_pairs(word): 88 | """Return set of symbol pairs in a word. 89 | 90 | Word is represented as tuple of symbols (symbols being variable-length strings). 91 | """ 92 | pairs = set() 93 | prev_char = word[0] 94 | for char in word[1:]: 95 | pairs.add((prev_char, char)) 96 | prev_char = char 97 | return pairs 98 | 99 | class GPT2Tokenizer(PreTrainedTokenizer): 100 | """ 101 | GPT-2 BPE tokenizer. 
Peculiarities: 102 | - Byte-level BPE 103 | """ 104 | vocab_files_names = VOCAB_FILES_NAMES 105 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 106 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 107 | 108 | def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", 109 | bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): 110 | super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) 111 | 112 | self.encoder = json.load(open(vocab_file)) 113 | self.decoder = {v:k for k,v in self.encoder.items()} 114 | self.errors = errors # how to handle errors in decoding 115 | self.byte_encoder = bytes_to_unicode() 116 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 117 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 118 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 119 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 120 | self.cache = {} 121 | 122 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 123 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 124 | 125 | @property 126 | def vocab_size(self): 127 | return len(self.encoder) 128 | 129 | def bpe(self, token): 130 | if token in self.cache: 131 | return self.cache[token] 132 | word = tuple(token) 133 | pairs = get_pairs(word) 134 | 135 | if not pairs: 136 | return token 137 | 138 | while True: 139 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 140 | if bigram not in self.bpe_ranks: 141 | break 142 | first, second = bigram 143 | new_word = [] 144 | i = 0 145 | while i < len(word): 146 | try: 147 | j = word.index(first, i) 148 | new_word.extend(word[i:j]) 149 | i = j 150 | except: 151 | new_word.extend(word[i:]) 152 | break 153 | 154 | if word[i] == first and i < len(word)-1 and word[i+1] == 
second: 155 | new_word.append(first+second) 156 | i += 2 157 | else: 158 | new_word.append(word[i]) 159 | i += 1 160 | new_word = tuple(new_word) 161 | word = new_word 162 | if len(word) == 1: 163 | break 164 | else: 165 | pairs = get_pairs(word) 166 | word = ' '.join(word) 167 | self.cache[token] = word 168 | return word 169 | 170 | def _tokenize(self, text): 171 | """ Tokenize a string. """ 172 | bpe_tokens = [] 173 | for token in re.findall(self.pat, text): 174 | if sys.version_info[0] == 2: 175 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 176 | else: 177 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 178 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 179 | return bpe_tokens 180 | 181 | def _convert_token_to_id(self, token): 182 | """ Converts a token (str/unicode) in an id using the vocab. """ 183 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 184 | 185 | def _convert_id_to_token(self, index): 186 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 187 | return self.decoder.get(index) 188 | 189 | def convert_tokens_to_string(self, tokens): 190 | """ Converts a sequence of tokens (string) in a single string. 
""" 191 | text = ''.join(tokens) 192 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 193 | return text 194 | 195 | def save_vocabulary(self, save_directory): 196 | """Save the tokenizer vocabulary and merge files to a directory.""" 197 | if not os.path.isdir(save_directory): 198 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 199 | return 200 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 201 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 202 | 203 | with open(vocab_file, 'w', encoding='utf-8') as f: 204 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 205 | 206 | index = 0 207 | with open(merge_file, "w", encoding="utf-8") as writer: 208 | writer.write(u'#version: 0.2\n') 209 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 210 | if index != token_index: 211 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 212 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 213 | index = token_index 214 | writer.write(' '.join(bpe_tokens) + u'\n') 215 | index += 1 216 | 217 | return vocab_file, merge_file 218 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 39 | }, 40 | 'merges_file': 41 | { 42 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 43 | }, 44 | } 45 | 46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 47 | 'openai-gpt': 512, 48 | } 49 | 50 | def get_pairs(word): 51 | """ 52 | Return set of symbol pairs in a word. 
53 | word is represented as tuple of symbols (symbols being variable-length strings) 54 | """ 55 | pairs = set() 56 | prev_char = word[0] 57 | for char in word[1:]: 58 | pairs.add((prev_char, char)) 59 | prev_char = char 60 | return pairs 61 | 62 | def text_standardize(text): 63 | """ 64 | fixes some issues the spacy tokenizer had on books corpus 65 | also does some whitespace standardization 66 | """ 67 | text = text.replace('—', '-') 68 | text = text.replace('–', '-') 69 | text = text.replace('―', '-') 70 | text = text.replace('…', '...') 71 | text = text.replace('´', "'") 72 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 73 | text = re.sub(r'\s*\n\s*', ' \n ', text) 74 | text = re.sub(r'[^\S\n]+', ' ', text) 75 | return text.strip() 76 | 77 | class OpenAIGPTTokenizer(PreTrainedTokenizer): 78 | """ 79 | BPE tokenizer. Peculiarities: 80 | - lower case all inputs 81 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 
82 | """ 83 | vocab_files_names = VOCAB_FILES_NAMES 84 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 85 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 86 | 87 | def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): 88 | super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) 89 | 90 | try: 91 | import ftfy 92 | from spacy.lang.en import English 93 | _nlp = English() 94 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 95 | self.fix_text = ftfy.fix_text 96 | except ImportError: 97 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 98 | self.nlp = BasicTokenizer(do_lower_case=True) 99 | self.fix_text = None 100 | 101 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 102 | self.decoder = {v:k for k,v in self.encoder.items()} 103 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 104 | merges = [tuple(merge.split()) for merge in merges] 105 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 106 | self.cache = {} 107 | 108 | @property 109 | def vocab_size(self): 110 | return len(self.encoder) 111 | 112 | def bpe(self, token): 113 | word = tuple(token[:-1]) + (token[-1] + '',) 114 | if token in self.cache: 115 | return self.cache[token] 116 | pairs = get_pairs(word) 117 | 118 | if not pairs: 119 | return token+'' 120 | 121 | while True: 122 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 123 | if bigram not in self.bpe_ranks: 124 | break 125 | first, second = bigram 126 | new_word = [] 127 | i = 0 128 | while i < len(word): 129 | try: 130 | j = word.index(first, i) 131 | new_word.extend(word[i:j]) 132 | i = j 133 | except: 134 | new_word.extend(word[i:]) 135 | break 136 | 137 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 138 | new_word.append(first+second) 139 | i += 2 140 | else: 141 | new_word.append(word[i]) 142 | i += 1 143 | new_word = tuple(new_word) 144 | word = 
new_word 145 | if len(word) == 1: 146 | break 147 | else: 148 | pairs = get_pairs(word) 149 | word = ' '.join(word) 150 | if word == '\n ': 151 | word = '\n' 152 | self.cache[token] = word 153 | return word 154 | 155 | def _tokenize(self, text): 156 | """ Tokenize a string. """ 157 | split_tokens = [] 158 | if self.fix_text is None: 159 | # Using BERT's BasicTokenizer 160 | text = self.nlp.tokenize(text) 161 | for token in text: 162 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 163 | else: 164 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 165 | text = self.nlp(text_standardize(self.fix_text(text))) 166 | for token in text: 167 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 168 | return split_tokens 169 | 170 | def _convert_token_to_id(self, token): 171 | """ Converts a token (str/unicode) in an id using the vocab. """ 172 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 173 | 174 | def _convert_id_to_token(self, index): 175 | """Converts an id in a token (BPE) using the vocab.""" 176 | return self.decoder.get(index, self.unk_token) 177 | 178 | def convert_tokens_to_string(self, tokens): 179 | """ Converts a sequence of tokens (string) in a single string. 
""" 180 | out_string = ''.join(tokens).replace('', ' ').strip() 181 | return out_string 182 | 183 | def save_vocabulary(self, save_directory): 184 | """Save the tokenizer vocabulary and merge files to a directory.""" 185 | if not os.path.isdir(save_directory): 186 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 187 | return 188 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 189 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 190 | 191 | with open(vocab_file, 'w', encoding='utf-8') as f: 192 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 193 | 194 | index = 0 195 | with open(merge_file, "w", encoding="utf-8") as writer: 196 | writer.write(u'#version: 0.2\n') 197 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 198 | if index != token_index: 199 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 200 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 201 | index = token_index 202 | writer.write(' '.join(bpe_tokens) + u'\n') 203 | index += 1 204 | 205 | return vocab_file, merge_file 206 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RoBERTa.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | from .tokenization_gpt2 import bytes_to_unicode, get_pairs 27 | from .tokenization_utils import PreTrainedTokenizer 28 | 29 | try: 30 | from functools import lru_cache 31 | except ImportError: 32 | # Just a dummy decorator to get the checks to run on python2 33 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 34 | def lru_cache(): 35 | return lambda func: func 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | VOCAB_FILES_NAMES = { 40 | 'vocab_file': 'vocab.json', 41 | 'merges_file': 'merges.txt', 42 | } 43 | 44 | PRETRAINED_VOCAB_FILES_MAP = { 45 | 'vocab_file': 46 | { 47 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 48 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 49 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 50 | }, 51 | 'merges_file': 52 | { 53 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 54 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 55 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 56 | }, 57 | } 58 | 59 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 60 | 'roberta-base': 512, 61 | 'roberta-large': 512, 62 | 'roberta-large-mnli': 512, 63 | } 64 | 65 | 66 | class RobertaTokenizer(PreTrainedTokenizer): 67 | """ 68 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. 
Peculiarities: Byte-level BPE 69 | """ 70 | vocab_files_names = VOCAB_FILES_NAMES 71 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 72 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 73 | 74 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="", eos_token="", sep_token="", 75 | cls_token="", unk_token="", pad_token='', mask_token='', **kwargs): 76 | super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 77 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 78 | mask_token=mask_token, **kwargs) 79 | 80 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.errors = errors # how to handle errors in decoding 83 | self.byte_encoder = bytes_to_unicode() 84 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 85 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 86 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 87 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 88 | self.cache = {} 89 | 90 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 91 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 92 | 93 | @property 94 | def vocab_size(self): 95 | return len(self.encoder) 96 | 97 | def bpe(self, token): 98 | if token in self.cache: 99 | return self.cache[token] 100 | word = tuple(token) 101 | pairs = get_pairs(word) 102 | 103 | if not pairs: 104 | return token 105 | 106 | while True: 107 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 108 | if bigram not in self.bpe_ranks: 109 | break 110 | first, second = bigram 111 | new_word = [] 112 | i = 0 113 | while i < len(word): 114 | try: 115 | j = word.index(first, i) 116 | new_word.extend(word[i:j]) 117 | i = j 118 | except: 119 | 
new_word.extend(word[i:]) 120 | break 121 | 122 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 123 | new_word.append(first+second) 124 | i += 2 125 | else: 126 | new_word.append(word[i]) 127 | i += 1 128 | new_word = tuple(new_word) 129 | word = new_word 130 | if len(word) == 1: 131 | break 132 | else: 133 | pairs = get_pairs(word) 134 | word = ' '.join(word) 135 | self.cache[token] = word 136 | return word 137 | 138 | def _tokenize(self, text): 139 | """ Tokenize a string. """ 140 | bpe_tokens = [] 141 | for token in re.findall(self.pat, text): 142 | if sys.version_info[0] == 2: 143 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 144 | else: 145 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 146 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 147 | return bpe_tokens 148 | 149 | def _convert_token_to_id(self, token): 150 | """ Converts a token (str/unicode) in an id using the vocab. """ 151 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 152 | 153 | def _convert_id_to_token(self, index): 154 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 155 | return self.decoder.get(index) 156 | 157 | def convert_tokens_to_string(self, tokens): 158 | """ Converts a sequence of tokens (string) in a single string. """ 159 | text = ''.join(tokens) 160 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 161 | return text 162 | 163 | def add_special_tokens_single_sentence(self, token_ids): 164 | """ 165 | Adds special tokens to a sequence for sequence classification tasks. 
166 | A RoBERTa sequence has the following format: [CLS] X [SEP] 167 | """ 168 | return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] 169 | 170 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 171 | """ 172 | Adds special tokens to a sequence pair for sequence classification tasks. 173 | A RoBERTa sequence pair has the following format: [CLS] A [SEP][SEP] B [SEP] 174 | """ 175 | sep = [self._convert_token_to_id(self.sep_token)] 176 | cls = [self._convert_token_to_id(self.cls_token)] 177 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 178 | 179 | def save_vocabulary(self, save_directory): 180 | """Save the tokenizer vocabulary and merge files to a directory.""" 181 | if not os.path.isdir(save_directory): 182 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 183 | return 184 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 185 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 186 | 187 | with open(vocab_file, 'w', encoding='utf-8') as f: 188 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 189 | 190 | index = 0 191 | with open(merge_file, "w", encoding="utf-8") as writer: 192 | writer.write(u'#version: 0.2\n') 193 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 194 | if index != token_index: 195 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
196 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 197 | index = token_index 198 | writer.write(' '.join(bpe_tokens) + u'\n') 199 | index += 1 200 | 201 | return vocab_file, merge_file 202 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ Tokenization classes for XLNet model.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | from shutil import copyfile 22 | 23 | import unicodedata 24 | import six 25 | 26 | from .tokenization_utils import PreTrainedTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | 'vocab_file': 34 | { 35 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", 36 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", 37 | } 38 | } 39 | 40 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 41 | 'xlnet-base-cased': None, 42 | 'xlnet-large-cased': None, 43 | } 44 | 45 | SPIECE_UNDERLINE = u'▁' 46 | 47 | # Segments (not really needed) 48 | SEG_ID_A = 0 49 | SEG_ID_B = 1 50 | SEG_ID_CLS = 2 51 | SEG_ID_SEP = 3 52 | SEG_ID_PAD = 4 53 | 54 | class XLNetTokenizer(PreTrainedTokenizer): 55 | """ 56 | SentencePiece based tokenizer. 
Peculiarities: 57 | 58 | - requires `SentencePiece <https://github.com/google/sentencepiece>`_ 59 | """ 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | 64 | def __init__(self, vocab_file, max_len=None, 65 | do_lower_case=False, remove_space=True, keep_accents=False, 66 | bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>", 67 | pad_token="<pad>", cls_token="<cls>", mask_token="<mask>", 68 | additional_special_tokens=["<eop>", "<eod>"], **kwargs): 69 | super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, 70 | unk_token=unk_token, sep_token=sep_token, 71 | pad_token=pad_token, cls_token=cls_token, 72 | mask_token=mask_token, additional_special_tokens= 73 | additional_special_tokens, **kwargs) 74 | try: 75 | import sentencepiece as spm 76 | except ImportError: 77 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 78 | "pip install sentencepiece") 79 | 80 | self.do_lower_case = do_lower_case 81 | self.remove_space = remove_space 82 | self.keep_accents = keep_accents 83 | self.vocab_file = vocab_file 84 | 85 | self.sp_model = spm.SentencePieceProcessor() 86 | self.sp_model.Load(vocab_file) 87 | 88 | @property 89 | def vocab_size(self): 90 | return len(self.sp_model) 91 | 92 | def __getstate__(self): 93 | state = self.__dict__.copy() 94 | state["sp_model"] = None 95 | return state 96 | 97 | def __setstate__(self, d): 98 | self.__dict__ = d 99 | try: 100 | import sentencepiece as spm 101 | except ImportError: 102 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 103 | "pip install sentencepiece") 104 | self.sp_model = spm.SentencePieceProcessor() 105 | self.sp_model.Load(self.vocab_file) 106 | 107 | def preprocess_text(self, inputs): 108 | if self.remove_space: 109 | outputs = ' '.join(inputs.strip().split()) 110 | else: 111 | outputs = inputs 112 |
outputs = outputs.replace("``", '"').replace("''", '"') 113 | 114 | if six.PY2 and isinstance(outputs, str): 115 | outputs = outputs.decode('utf-8') 116 | 117 | if not self.keep_accents: 118 | outputs = unicodedata.normalize('NFKD', outputs) 119 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 120 | if self.do_lower_case: 121 | outputs = outputs.lower() 122 | 123 | return outputs 124 | 125 | def _tokenize(self, text, return_unicode=True, sample=False): 126 | """ Tokenize a string. 127 | return_unicode is used only for py2 128 | """ 129 | text = self.preprocess_text(text) 130 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 131 | if six.PY2 and isinstance(text, unicode): 132 | text = text.encode('utf-8') 133 | 134 | if not sample: 135 | pieces = self.sp_model.EncodeAsPieces(text) 136 | else: 137 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) 138 | new_pieces = [] 139 | for piece in pieces: 140 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 141 | cur_pieces = self.sp_model.EncodeAsPieces( 142 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 143 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 144 | if len(cur_pieces[0]) == 1: 145 | cur_pieces = cur_pieces[1:] 146 | else: 147 | cur_pieces[0] = cur_pieces[0][1:] 148 | cur_pieces.append(piece[-1]) 149 | new_pieces.extend(cur_pieces) 150 | else: 151 | new_pieces.append(piece) 152 | 153 | # note(zhiliny): convert back to unicode for py2 154 | if six.PY2 and return_unicode: 155 | ret_pieces = [] 156 | for piece in new_pieces: 157 | if isinstance(piece, str): 158 | piece = piece.decode('utf-8') 159 | ret_pieces.append(piece) 160 | new_pieces = ret_pieces 161 | 162 | return new_pieces 163 | 164 | def _convert_token_to_id(self, token): 165 | """ Converts a token (str/unicode) in an id using the vocab. 
""" 166 | return self.sp_model.PieceToId(token) 167 | 168 | def _convert_id_to_token(self, index, return_unicode=True): 169 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 170 | token = self.sp_model.IdToPiece(index) 171 | if six.PY2 and return_unicode and isinstance(token, str): 172 | token = token.decode('utf-8') 173 | return token 174 | 175 | def convert_tokens_to_string(self, tokens): 176 | """Converts a sequence of tokens (strings for sub-words) in a single string.""" 177 | out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() 178 | return out_string 179 | 180 | def add_special_tokens_single_sentence(self, token_ids): 181 | """ 182 | Adds special tokens to a sequence pair for sequence classification tasks. 183 | An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS] 184 | """ 185 | sep = [self._convert_token_to_id(self.sep_token)] 186 | cls = [self._convert_token_to_id(self.cls_token)] 187 | return token_ids + sep + cls 188 | 189 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 190 | """ 191 | Adds special tokens to a sequence for sequence classification tasks. 192 | An XLNet sequence has the following format: X [SEP][CLS] 193 | """ 194 | sep = [self._convert_token_to_id(self.sep_token)] 195 | cls = [self._convert_token_to_id(self.cls_token)] 196 | return token_ids_0 + sep + token_ids_1 + sep + cls 197 | 198 | def save_vocabulary(self, save_directory): 199 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 200 | to a directory. 
201 | """ 202 | if not os.path.isdir(save_directory): 203 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 204 | return 205 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 206 | 207 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 208 | copyfile(self.vocab_file, out_vocab_file) 209 | 210 | return (out_vocab_file,) 211 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | torch>=1.0.0 3 | -------------------------------------------------------------------------------- /run_bert.sh: -------------------------------------------------------------------------------- 1 | DATE=$(date +%Y%m%d%H%M) 2 | mkdir models/$DATE 3 | export CUDA_VISIBLE_DEVICES=0 4 | for((i=0;i<5;i++)); 5 | do 6 | 7 | python run_bert.py \ 8 | --model_type bert \ 9 | --model_name_or_path ../premodels/chinese_wwm_ex_bert \ 10 | --do_train \ 11 | --do_eval \ 12 | --do_test \ 13 | --data_dir ../data/rawdata/guoday/data_$i \ 14 | --output_dir ./models/$DATE/model_bert$i \ 15 | --max_seq_length 128 \ 16 | --split_num 3 \ 17 | --lstm_hidden_size 512 \ 18 | --lstm_layers 3 \ 19 | --lstm_dropout 0.1 \ 20 | --eval_steps 200 \ 21 | --per_gpu_train_batch_size 4 \ 22 | --gradient_accumulation_steps 4 \ 23 | --warmup_steps 0 \ 24 | --per_gpu_eval_batch_size 32 \ 25 | --learning_rate 5e-6 \ 26 | --adam_epsilon 1e-6 \ 27 | --weight_decay 0.01 \ 28 | --train_steps 20000 29 | 30 | done 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /run_bert_wwm_ext.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | for((i=0;i<5;i++)); 3 | do 4 | 5 | python run_bert.py \ 6 | --model_type bert \ 7 | --model_name_or_path ./chinese_wwm_ex_bert \ 8 | --do_train \ 
9 | --do_eval \ 10 | --do_test \ 11 | --data_dir ./data/data_$i \ 12 | --output_dir ./model_bert_wwm_ext$i \ 13 | --max_seq_length 256 \ 14 | --split_num 3 \ 15 | --lstm_hidden_size 512 \ 16 | --lstm_layers 1 \ 17 | --lstm_dropout 0.1 \ 18 | --eval_steps 200 \ 19 | --per_gpu_train_batch_size 1 \ 20 | --gradient_accumulation_steps 1 \ 21 | --warmup_steps 0 \ 22 | --per_gpu_eval_batch_size 32 \ 23 | --learning_rate 5e-6 \ 24 | --adam_epsilon 1e-6 \ 25 | --weight_decay 0 \ 26 | --train_steps 5000 ; 27 | 28 | done 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /run_roberta.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | fold=1 3 | rm -rf model_roberta 4 | for((i=0;i<$fold;i++)); 5 | do 6 | 7 | python run_bert.py \ 8 | --seed 4321 \ 9 | --model_type bert \ 10 | --model_name_or_path pretrained_model/chinese_roberta \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_test \ 14 | --data_dir ./data/data_$i \ 15 | --output_dir ./model_roberta/model_roberta$i \ 16 | --max_seq_length 168 \ 17 | --split_num 5 \ 18 | --lstm_hidden_size 1024 \ 19 | --lstm_layers 3 \ 20 | --lstm_dropout 0.5 \ 21 | --eval_steps 200 \ 22 | --per_gpu_train_batch_size 4 \ 23 | --gradient_accumulation_steps 4 \ 24 | --warmup_steps 0 \ 25 | --per_gpu_eval_batch_size 32 \ 26 | --learning_rate 2e-6 \ 27 | --adam_epsilon 1e-6 \ 28 | --weight_decay 0.005 \ 29 | --train_steps 25000 30 | 31 | done 32 | 33 | echo "save models into backup fold" 34 | t=`date +%Y%m%d%H%M%S` 35 | cp -rf ./model_roberta backup-models/roberta_models/models-$t 36 | echo "done" 37 | 38 | rm result.csv 39 | echo "combine result" 40 | python combine.py --model_prefix model_roberta/model_roberta --out_path result.csv --fold $fold 41 | echo "done" 42 | 43 | echo "save result into backup fold" 44 | cp result.csv backup-models/roberta_models/models-$t 45 | echo "done" 46 | 47 | echo "save run script into 
backup fold" 48 | cp run_roberta.sh backup-models/roberta_models/models-$t 49 | 50 | -------------------------------------------------------------------------------- /run_roberta_wwm_ext.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | fold=1 3 | rm -rf model_roberta_wwm_ext 4 | for((i=0;i<$fold;i++)); 5 | do 6 | 7 | python run_bert.py \ 8 | --seed 4321 \ 9 | --model_type bert \ 10 | --model_name_or_path ./pretrained_model/chinese_roberta_wwm_ext \ 11 | --do_test \ 12 | --do_train \ 13 | --do_eval \ 14 | --data_dir ./data/data_$i \ 15 | --output_dir ./model_roberta_wwm_ext/model_roberta_wwm_ext$i \ 16 | --max_seq_length 510 \ 17 | --split_num 1 \ 18 | --lstm_hidden_size 1024 \ 19 | --lstm_layers 1 \ 20 | --lstm_dropout 0.5 \ 21 | --eval_steps 200 \ 22 | --per_gpu_train_batch_size 4 \ 23 | --gradient_accumulation_steps 4 \ 24 | --warmup_steps 0 \ 25 | --per_gpu_eval_batch_size 32 \ 26 | --learning_rate 3e-6 \ 27 | --adam_epsilon 1e-6 \ 28 | --weight_decay 0.007 \ 29 | --train_steps 30000 30 | 31 | done 32 | 33 | 34 | echo "save models into backup fold" 35 | t=`date +%Y%m%d%H%M%S` 36 | cp -rf ./model_roberta_wwm_ext backup-models/roberta_wwm_ext_models/models-$t 37 | echo "done" 38 | 39 | rm result.csv 40 | echo "combine result" 41 | python combine.py --model_prefix model_roberta_wwm_ext/model_roberta_wwm_ext --out_path result.csv --fold $fold 42 | echo "done" 43 | 44 | echo "save result into backup fold" 45 | cp result.csv backup-models/roberta_wwm_ext_models/models-$t 46 | echo "done" 47 | 48 | echo "save run script into backup fold" 49 | cp run_roberta_wwm_ext.sh backup-models/roberta_wwm_ext_models/models-$t 50 | 51 | -------------------------------------------------------------------------------- /run_xlnet.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | for((i=0;i<5;i++)); 3 | do 4 | 5 | python run_xlnet.py \ 
6 | --model_type xlnet \ 7 | --model_name_or_path ./chinese_xlnet_mid \ 8 | --do_train \ 9 | --do_eval \ 10 | --do_test \ 11 | --data_dir ./data/data_$i \ 12 | --output_dir ./model_xlnet$i \ 13 | --max_seq_length 150 \ 14 | --split_num 10 \ 15 | --lstm_hidden_size 512 \ 16 | --lstm_layers 1 \ 17 | --lstm_dropout 0.1 \ 18 | --eval_steps 200 \ 19 | --per_gpu_train_batch_size 1 \ 20 | --gradient_accumulation_steps 1 \ 21 | --warmup_steps 0 \ 22 | --per_gpu_eval_batch_size 64 \ 23 | --learning_rate 5e-6 \ 24 | --adam_epsilon 1e-6 \ 25 | --weight_decay 0 \ 26 | --train_steps 4000 \ 27 | --report_steps 200 ; 28 | 29 | done 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py and setup.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 21 | 22 | 5. Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 
26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi pytorch-transformers 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | setup( 40 | name="pytorch_transformers", 41 | version="1.1.0", 42 | author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors", 43 | author_email="thomas@huggingface.co", 44 | description="Repository of pre-trained NLP Transformer models: BERT & RoBERTa, GPT & GPT-2, Transformer-XL, XLNet and XLM", 45 | long_description=open("README.md", "r", encoding='utf-8').read(), 46 | long_description_content_type="text/markdown", 47 | keywords='NLP deep learning transformer pytorch BERT GPT GPT-2 google openai CMU', 48 | license='Apache', 49 | url="https://github.com/huggingface/pytorch-transformers", 50 | packages=find_packages(exclude=["*.tests", "*.tests.*", 51 | "tests.*", "tests"]), 52 | install_requires=['torch>=1.0.0', 53 | 'numpy', 54 | 'boto3', 55 | 'requests', 56 | 'tqdm', 57 | 'regex', 58 | 'sentencepiece'], 59 | entry_points={ 60 | 'console_scripts': [ 61 | "pytorch_transformers=pytorch_transformers.__main__:main", 62 | ] 63 | }, 64 | # python_requires='>=3.5.0', 65 | tests_require=['pytest'], 66 | classifiers=[ 67 | 'Intended Audience :: Science/Research', 68 | 'License :: OSI Approved :: Apache Software License', 69 | 'Programming Language :: Python :: 3', 70 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 71 | ], 72 | ) 73 | --------------------------------------------------------------------------------