├── .idea
│   ├── TSBERT.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   └── vcs.xml
├── README.md
├── examples
│   ├── DATASET
│   │   ├── ch
│   │   │   ├── cut_result_zys_reference_v2.txt
│   │   │   └── document_zys_2.json
│   │   └── en
│   │       ├── document_en_dataset_2.json
│   │       └── en_dataset_reference_2.txt
│   ├── data
│   │   ├── alime
│   │   │   └── put_E-commerce_data_here.txt
│   │   ├── douban
│   │   │   └── put_douban_data_here.txt
│   │   └── ubuntu
│   │       └── put_ubuntu_data_here.txt
│   ├── data_process.py
│   ├── run_TSbert_v3.py
│   ├── segmentation_BERTCLS.py
│   ├── utils_TSbert_v3.py
│   └── utils_segmentation.py
└── pytorch_transformers
    ├── __init__.py
    ├── __main__.py
    ├── configuration_auto.py
    ├── configuration_bert.py
    ├── configuration_distilbert.py
    ├── configuration_gpt2.py
    ├── configuration_openai.py
    ├── configuration_roberta.py
    ├── configuration_transfo_xl.py
    ├── configuration_utils.py
    ├── configuration_xlm.py
    ├── configuration_xlnet.py
    ├── convert_gpt2_checkpoint_to_pytorch.py
    ├── convert_openai_checkpoint_to_pytorch.py
    ├── convert_pytorch_checkpoint_to_tf.py
    ├── convert_roberta_checkpoint_to_pytorch.py
    ├── convert_tf_checkpoint_to_pytorch.py
    ├── convert_transfo_xl_checkpoint_to_pytorch.py
    ├── convert_xlm_checkpoint_to_pytorch.py
    ├── convert_xlnet_checkpoint_to_pytorch.py
    ├── file_utils.py
    ├── modeling_TSbert.py
    ├── modeling_TSbert_v3.py
    ├── modeling_auto.py
    ├── modeling_bert.py
    ├── modeling_distilbert.py
    ├── modeling_gpt2.py
    ├── modeling_openai.py
    ├── modeling_roberta.py
    ├── modeling_transfo_xl.py
    ├── modeling_transfo_xl_utilities.py
    ├── modeling_utils.py
    ├── modeling_xlm.py
    ├── modeling_xlnet.py
    ├── optimization.py
    ├── optimization_bert.py
    ├── tests
    │   ├── __init__.py
    │   ├── configuration_common_test.py
    │   ├── conftest.py
    │   ├── fixtures
    │   │   ├── input.txt
    │   │   ├── sample_text.txt
    │   │   └── test_sentencepiece.model
    │   ├── modeling_auto_test.py
    │   ├── modeling_bert_test.py
    │   ├── modeling_common_test.py
    │   ├── modeling_distilbert_test.py
    │   ├── modeling_gpt2_test.py
    │   ├── modeling_openai_test.py
    │   ├── modeling_roberta_test.py
    │   ├── modeling_transfo_xl_test.py
    │   ├── modeling_xlm_test.py
    │   ├── modeling_xlnet_test.py
    │   ├── optimization_test.py
    │   ├── tokenization_auto_test.py
    │   ├── tokenization_bert_test.py
    │   ├── tokenization_dilbert_test.py
    │   ├── tokenization_gpt2_test.py
    │   ├── tokenization_openai_test.py
    │   ├── tokenization_roberta_test.py
    │   ├── tokenization_tests_commons.py
    │   ├── tokenization_transfo_xl_test.py
    │   ├── tokenization_utils_test.py
    │   ├── tokenization_xlm_test.py
    │   └── tokenization_xlnet_test.py
    ├── tokenization_auto.py
    ├── tokenization_bert.py
    ├── tokenization_distilbert.py
    ├── tokenization_gpt2.py
    ├── tokenization_openai.py
    ├── tokenization_roberta.py
    ├── tokenization_transfo_xl.py
    ├── tokenization_utils.py
    ├── tokenization_xlm.py
    └── tokenization_xlnet.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset
2 | Please download the datasets to the corresponding directories under "data"
3 |
4 | E-commerce
5 | https://drive.google.com/file/d/154J-neBo20ABtSmJDvm7DK0eTuieAuvw/view?usp=sharing.
6 |
7 | Ubuntu
8 | https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntudata.zip?dl=0
9 |
10 | Douban
11 | https://www.dropbox.com/s/90t0qtji9ow20ca/DoubanConversaionCorpus.zip?dl=0&file_subpath=%2FDoubanConversaionCorpus
12 |
13 | Our own dataset for segmentation is under the DATASET directory
14 |
15 | ## Source Code
16 | * prepare data
17 |
18 | Generate cutlist.txt:
19 |
20 | python segmentation_BERTCLS.py --datapath=data/xxx/xxx.txt
21 |
22 | Gather the segmented data into data/xxx/xxxseg.txt:
23 |
24 | Set interval = 2 for train.txt and interval = 10 for test.txt.
25 | Set the corresponding datafile and dataset in data_process.py, then run:
26 |
27 | python data_process.py
28 |
29 | * train
30 |
31 | python run_TSbert_v3.py --task=alime --do_train --train_batch_size=20 --learning_rate=2e-5
32 |
33 | The processed data will be saved in data/alime/input_cache_v3.
34 |
35 | The model will be saved in data/alime/model_save_v3, and the training log will also be saved in log.txt.
36 |
37 |
38 | * eval
39 |
40 | python run_TSbert_v3.py --task=xxx
41 |
42 | You can also load our trained model for testing: https://drive.google.com/drive/folders/1_sRSmwlaAK_TPaVYYNhao81rXUW92z98?usp=sharing
43 |
44 | ### Environment
45 | We use the PyTorch version of pre-trained BERT from https://github.com/huggingface/transformers
46 |
47 | torch>=1.0.0
48 |
49 | Packages: tqdm, boto3, requests, regex, sacremoses, openpyxl, numpy, sentencepiece
50 |
51 | ### Reference
52 |
53 | If you use this code, please cite our paper:
54 | ```
55 | @article{xu2020topic,
56 | title={Topic-aware multi-turn dialogue modeling},
57 | author={Xu, Yi and Zhao, Hai and Zhang, Zhuosheng},
58 | journal={arXiv preprint arXiv:2009.12539},
59 | year={2020}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
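For convenience, the README's environment setup and prepare/train/eval steps can be chained as below. This is a minimal sketch, assuming the E-commerce (alime) task, that the commands are run from the `examples` directory, and that the dataset files have already been downloaded into `data/alime`; the `interval` and file settings still have to be edited inside `data_process.py` as described in the README.

```bash
# Sketch of the full pipeline from the README (paths and arguments are illustrative).
pip install "torch>=1.0.0" tqdm boto3 requests regex sacremoses openpyxl numpy sentencepiece

cd examples

# 1. prepare data: generate the cut list, then gather the segmented data
#    (set interval = 2 for train.txt, interval = 10 for test.txt in data_process.py)
python segmentation_BERTCLS.py --datapath=data/alime/train.txt
python data_process.py

# 2. train: data cached in data/alime/input_cache_v3, model saved in data/alime/model_save_v3
python run_TSbert_v3.py --task=alime --do_train --train_batch_size=20 --learning_rate=2e-5

# 3. eval
python run_TSbert_v3.py --task=alime
```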
/examples/data/alime/put_E-commerce_data_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/alime/put_E-commerce_data_here.txt
--------------------------------------------------------------------------------
/examples/data/douban/put_douban_data_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/douban/put_douban_data_here.txt
--------------------------------------------------------------------------------
/examples/data/ubuntu/put_ubuntu_data_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/examples/data/ubuntu/put_ubuntu_data_here.txt
--------------------------------------------------------------------------------
/examples/data_process.py:
--------------------------------------------------------------------------------
1 | import json
2 | input_file = "data/alime/train.txt"
3 | output_file_seg = "data/alime/train_seg.txt"
4 | cut_list_file = "data/alime/cutlist_train.json"
5 |
6 | final_outputfile="data/alime/trainseg.txt"
7 | interval = 2
8 |
9 | def generate_output_file_seg():
10 | contexts=[]
11 | num=0
12 | with open(input_file,'r',encoding='utf-8') as rf:
13 | for line in rf:
14 | line=line.strip()
15 | if(line and num%interval==0):
16 | contexts.append(line.split("\t")[1:-1])
17 | num += 1
18 | print(num)
19 | with open(cut_list_file, 'r', encoding='utf-8') as rf:
20 | cutlist = json.loads(rf.read())
21 |
22 | print(len(contexts))
23 | print(len(cutlist))
24 |
25 | with open(output_file_seg,'w',encoding='utf-8') as wf:
26 | for context, cutl in zip(contexts, cutlist):
27 | seg_list = []
28 | seg=""
29 | index_1 = 0
30 | index_2 = 0
31 | for index, utt in enumerate(context):
32 | if(seg):
33 | seg=seg+" "+utt
34 | else:
35 | seg=utt
36 | if (index_1 == cutl[index_2]):
37 | index_2 += 1
38 | seg_list.append(seg)
39 | seg=""
40 | index_1 += 1
41 | # a="\t".join(seg_list)+'\n'
42 | wf.write("\t".join(seg_list)+'\n')
43 |
44 |
45 | def write_segfile(inputfile_utt, inputfile_seg, outputfile, interval):
46 | lab_res_list=[]
47 | with open(inputfile_utt, 'r', encoding='utf-8') as rf:
48 | for line in rf:
49 | line=line.strip()
50 | if(line):
51 | lab=line.split('\t')[0]
52 | res=line.split('\t')[-1]
53 | lab_res_list.append([lab,None,res])
54 | print(len(lab_res_list))
55 | seg_list = []
56 | with open(inputfile_seg, 'r', encoding='utf-8') as rf:
57 | for line in rf:
58 | line = line.strip()
59 | if(line):
60 | seg_list.append(line)
61 | print(len(seg_list))
62 | i = 0
63 | for seg in seg_list:
64 | for _ in range(interval):
65 | lab_res_list[i][1] = seg
66 | i += 1
67 | print(i)
68 |
69 | with open(outputfile,'w',encoding='utf-8') as wf:
70 | for lab_seg_res in lab_res_list:
71 | wf.write('\t'.join(lab_seg_res)+'\n')
72 |
73 |
74 | generate_output_file_seg()
75 | write_segfile(input_file, output_file_seg, final_outputfile, interval)
--------------------------------------------------------------------------------
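As a reading aid for data_process.py above, the toy snippet below (hypothetical dialogue, not taken from any of the corpora) mirrors how each input line is parsed: the first tab-separated field is the label, the last field is the response candidate, and everything in between is the dialogue context, which is shared by every group of `interval` consecutive lines.

```python
# Toy illustration of the tab-separated format consumed by data_process.py:
# label \t utterance_1 \t ... \t utterance_n \t response  (values below are made up).
line = "1\thi\tcan i help you\tis this item in stock\tyes it is"

fields = line.strip().split("\t")
label, context, response = fields[0], fields[1:-1], fields[-1]

print(label)     # '1'
print(context)   # ['hi', 'can i help you', 'is this item in stock']
print(response)  # 'yes it is'
```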
/examples/segmentation_BERTCLS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import json
4 | import argparse
5 | import numpy as np
6 | import os
7 |
8 | from pytorch_transformers import BertConfig, BertModel, BertTokenizer
9 | from utils_segmentation import convert_examples_to_features, read_expamples_2
10 | WINDOW_SIZE = 2
11 | SEGMENT_JUMP_STEP = 2
12 | SIMILARITY_THRESHOLD = 0.6
13 | MAX_SEGMENT_ROUND = 6
14 | MAX_SEQ_LENGTH = 50
15 | MODEL_CLASSES = {
16 | 'bert': (BertConfig, BertModel, BertTokenizer),
17 | }
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 | def similarity(A, B):
20 | return np.dot(A, B) / (np.linalg.norm(A) * np.linalg.norm(B))
21 |
22 |
23 | def generate_vectors_2(model,examples,tokenizer,device):
24 | features = convert_examples_to_features(examples, MAX_SEQ_LENGTH, tokenizer,
25 | cls_token=tokenizer.cls_token,
26 | cls_token_segment_id=0,
27 | sep_token=tokenizer.sep_token,
28 | )
29 |
30 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)
31 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long).to(device)
32 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long).to(device)
33 |
34 | vectors = model(input_ids=all_input_ids, attention_mask=all_input_mask, token_type_ids=all_segment_ids)[1]
35 | return vectors.cpu().detach().numpy()
36 |
37 | def segmentation(documents):
38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 | config_class, model_class, tokenizer_class = MODEL_CLASSES['bert']
40 | config = config_class.from_pretrained(args.berttype)
41 | tokenizer = tokenizer_class.from_pretrained(args.berttype,do_lower_case=True)
42 | model = model_class.from_pretrained(args.berttype ,config=config).to(device)#, from_tf=bool('.ckpt' in args.model_name_or_path),
43 | model.eval()
44 |
45 | all_cut_list = []
46 | pbar = tqdm(total=len(documents))
47 | for document_o in documents:
48 | if(len(document_o)%2):
49 | document=document_o[1:]
50 | else:
51 | document = document_o
52 | cut_index=0
53 | cut_list = []
54 | while(cut_index0):
62 | index=WINDOW_SIZE
63 | while(index>0):
64 | left_sent+=document[cut_index-index]
65 | index-=1
66 |
67 | else:
68 | temp_index=0
69 | while(temp_index right_value else right_value
104 | if(not left_sent and not right_sent): # guard against there being no reference window on either side, i.e. len(document)<=MAX_SEGMENT_ROUND
105 |
106 | larger_value=SIMILARITY_THRESHOLD # if even the minimum similarity for a middle cut is greater than 0.8, the conversation is not segmented; a middle cut is only made when the similarity falls below
107 | # this threshold.
108 | if(larger_value
--------------------------------------------------------------------------------
/pytorch_transformers/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
5 | print(
6 | "Should be used as one of: \n"
7 | ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
8 | ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
9 | ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
10 | ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
11 | ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
12 | ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
13 | else:
14 | if sys.argv[1] == "bert":
15 | try:
16 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
17 | except ImportError:
18 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. "
19 | "In that case, it requires TensorFlow to be installed. Please see "
20 | "https://www.tensorflow.org/install/ for installation instructions.")
21 | raise
22 |
23 | if len(sys.argv) != 5:
24 | # pylint: disable=line-too-long
25 | print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
26 | else:
27 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
28 | TF_CONFIG = sys.argv.pop()
29 | TF_CHECKPOINT = sys.argv.pop()
30 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
31 | elif sys.argv[1] == "gpt":
32 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
33 | if len(sys.argv) < 4 or len(sys.argv) > 5:
34 | # pylint: disable=line-too-long
35 | print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
36 | else:
37 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
38 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
39 | if len(sys.argv) == 5:
40 | OPENAI_GPT_CONFIG = sys.argv[4]
41 | else:
42 | OPENAI_GPT_CONFIG = ""
43 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
44 | OPENAI_GPT_CONFIG,
45 | PYTORCH_DUMP_OUTPUT)
46 | elif sys.argv[1] == "transfo_xl":
47 | try:
48 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
49 | except ImportError:
50 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. "
51 | "In that case, it requires TensorFlow to be installed. Please see "
52 | "https://www.tensorflow.org/install/ for installation instructions.")
53 | raise
54 | if len(sys.argv) < 4 or len(sys.argv) > 5:
55 | # pylint: disable=line-too-long
56 | print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
57 | else:
58 | if 'ckpt' in sys.argv[2].lower():
59 | TF_CHECKPOINT = sys.argv[2]
60 | TF_DATASET_FILE = ""
61 | else:
62 | TF_DATASET_FILE = sys.argv[2]
63 | TF_CHECKPOINT = ""
64 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
65 | if len(sys.argv) == 5:
66 | TF_CONFIG = sys.argv[4]
67 | else:
68 | TF_CONFIG = ""
69 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
70 | elif sys.argv[1] == "gpt2":
71 | try:
72 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
73 | except ImportError:
74 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. "
75 | "In that case, it requires TensorFlow to be installed. Please see "
76 | "https://www.tensorflow.org/install/ for installation instructions.")
77 | raise
78 |
79 | if len(sys.argv) < 4 or len(sys.argv) > 5:
80 | # pylint: disable=line-too-long
81 | print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
82 | else:
83 | TF_CHECKPOINT = sys.argv[2]
84 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
85 | if len(sys.argv) == 5:
86 | TF_CONFIG = sys.argv[4]
87 | else:
88 | TF_CONFIG = ""
89 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
90 | elif sys.argv[1] == "xlnet":
91 | try:
92 | from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
93 | except ImportError:
94 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. "
95 | "In that case, it requires TensorFlow to be installed. Please see "
96 | "https://www.tensorflow.org/install/ for installation instructions.")
97 | raise
98 |
99 | if len(sys.argv) < 5 or len(sys.argv) > 6:
100 | # pylint: disable=line-too-long
101 | print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
102 | else:
103 | TF_CHECKPOINT = sys.argv[2]
104 | TF_CONFIG = sys.argv[3]
105 | PYTORCH_DUMP_OUTPUT = sys.argv[4]
106 | if len(sys.argv) == 6:
107 | FINETUNING_TASK = sys.argv[5]
108 | else:
109 | FINETUNING_TASK = None
110 |
111 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
112 | TF_CONFIG,
113 | PYTORCH_DUMP_OUTPUT,
114 | FINETUNING_TASK)
115 | elif sys.argv[1] == "xlm":
116 | from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
117 |
118 | if len(sys.argv) != 4:
119 | # pylint: disable=line-too-long
120 | print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
121 | else:
122 | XLM_CHECKPOINT_PATH = sys.argv[2]
123 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
124 |
125 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
--------------------------------------------------------------------------------
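__main__.py above is driven purely by positional command-line arguments. A hypothetical invocation of the BERT branch (all paths are placeholders, and TensorFlow must be installed for the import to succeed) would look like:

```bash
# Convert a TensorFlow BERT checkpoint to a PyTorch dump, per the usage string above.
python -m pytorch_transformers bert \
    ./tf_checkpoint/bert_model.ckpt \
    ./tf_checkpoint/bert_config.json \
    ./pytorch_dump/pytorch_model.bin
```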
/pytorch_transformers/configuration_auto.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Auto Model class. """
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import logging
20 |
21 | from .configuration_bert import BertConfig
22 | from .configuration_openai import OpenAIGPTConfig
23 | from .configuration_gpt2 import GPT2Config
24 | from .configuration_transfo_xl import TransfoXLConfig
25 | from .configuration_xlnet import XLNetConfig
26 | from .configuration_xlm import XLMConfig
27 | from .configuration_roberta import RobertaConfig
28 | from .configuration_distilbert import DistilBertConfig
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | class AutoConfig(object):
34 | r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class
35 | that will be instantiated as one of the configuration classes of the library
36 | when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
37 | class method.
38 |
39 | The `from_pretrained()` method takes care of returning the correct configuration class instance
40 | using pattern matching on the `pretrained_model_name_or_path` string.
41 |
42 | The base model class to instantiate is selected as the first pattern matching
43 | in the `pretrained_model_name_or_path` string (in the following order):
44 | - contains `distilbert`: DistilBertConfig (DistilBERT model)
45 | - contains `bert`: BertConfig (Bert model)
46 | - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
47 | - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
48 | - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
49 | - contains `xlnet`: XLNetConfig (XLNet model)
50 | - contains `xlm`: XLMConfig (XLM model)
51 | - contains `roberta`: RobertaConfig (RoBERTa model)
52 |
53 | This class cannot be instantiated using `__init__()` (doing so throws an error).
54 | """
55 | def __init__(self):
56 | raise EnvironmentError("AutoConfig is designed to be instantiated "
57 | "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
58 |
59 | @classmethod
60 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
61 | r""" Instantiate one of the configuration classes of the library
62 | from a pre-trained model configuration.
63 |
64 | The configuration class to instantiate is selected as the first pattern matching
65 | in the `pretrained_model_name_or_path` string (in the following order):
66 | - contains `distilbert`: DistilBertConfig (DistilBERT model)
67 | - contains `bert`: BertConfig (Bert model)
68 | - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
69 | - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
70 | - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
71 | - contains `xlnet`: XLNetConfig (XLNet model)
72 | - contains `xlm`: XLMConfig (XLM model)
73 | - contains `roberta`: RobertaConfig (RoBERTa model)
74 |
75 | Params:
76 | pretrained_model_name_or_path: either:
77 |
78 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
79 | - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
80 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
81 |
82 | cache_dir: (`optional`) string:
83 | Path to a directory in which a downloaded pre-trained model
84 | configuration should be cached if the standard cache should not be used.
85 |
86 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
87 |
88 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
89 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
90 |
91 | force_download: (`optional`) boolean, default False:
92 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
93 |
94 | proxies: (`optional`) dict, default None:
95 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
96 | The proxies are used on each request.
97 |
98 | return_unused_kwargs: (`optional`) bool:
99 |
100 | - If False, then this function returns just the final configuration object.
101 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
102 |
103 | Examples::
104 |
105 | config = AutoConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
106 | config = AutoConfig.from_pretrained('./test/bert_saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
107 | config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')
108 | config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
109 | assert config.output_attention == True
110 | config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True,
111 | foo=False, return_unused_kwargs=True)
112 | assert config.output_attention == True
113 | assert unused_kwargs == {'foo': False}
114 |
115 | """
116 | if 'distilbert' in pretrained_model_name_or_path:
117 | return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
118 | elif 'roberta' in pretrained_model_name_or_path:
119 | return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
120 | elif 'bert' in pretrained_model_name_or_path:
121 | return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
122 | elif 'openai-gpt' in pretrained_model_name_or_path:
123 | return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
124 | elif 'gpt2' in pretrained_model_name_or_path:
125 | return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
126 | elif 'transfo-xl' in pretrained_model_name_or_path:
127 | return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
128 | elif 'xlnet' in pretrained_model_name_or_path:
129 | return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
130 | elif 'xlm' in pretrained_model_name_or_path:
131 | return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
132 |
133 | raise ValueError("Unrecognized model identifier in {}. Should contain one of "
134 | "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
135 | "'xlm', 'roberta'".format(pretrained_model_name_or_path))
136 |
--------------------------------------------------------------------------------
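Because `AutoConfig.from_pretrained` dispatches by ordered substring matching, names such as `distilbert-base-uncased` and `roberta-base` (both of which contain the substring `bert`) must be checked before the plain `bert` branch, which is exactly the order used above. A small sketch, assuming the shortcut configurations can still be downloaded:

```python
from pytorch_transformers import AutoConfig, DistilBertConfig, RobertaConfig

# Both names contain 'bert', but the earlier branches win.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
assert isinstance(config, DistilBertConfig)

config = AutoConfig.from_pretrained('roberta-base')
assert isinstance(config, RobertaConfig)
```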
/pytorch_transformers/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
43 | }
44 |
45 |
46 | class BertConfig(PretrainedConfig):
47 | r"""
48 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
49 | `BertModel`.
50 |
51 |
52 | Arguments:
53 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
54 | hidden_size: Size of the encoder layers and the pooler layer.
55 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
56 | num_attention_heads: Number of attention heads for each attention layer in
57 | the Transformer encoder.
58 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
59 | layer in the Transformer encoder.
60 | hidden_act: The non-linear activation function (function or string) in the
61 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
62 | hidden_dropout_prob: The dropout probability for all fully connected
63 | layers in the embeddings, encoder, and pooler.
64 | attention_probs_dropout_prob: The dropout ratio for the attention
65 | probabilities.
66 | max_position_embeddings: The maximum sequence length that this model might
67 | ever be used with. Typically set this to something large just in case
68 | (e.g., 512 or 1024 or 2048).
69 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
70 | `BertModel`.
71 | initializer_range: The stddev of the truncated_normal_initializer for
72 | initializing all weight matrices.
73 | layer_norm_eps: The epsilon used by LayerNorm.
74 | """
75 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
76 |
77 | def __init__(self,
78 | vocab_size_or_config_json_file=30522,
79 | hidden_size=768,
80 | num_hidden_layers=12,
81 | num_attention_heads=12,
82 | intermediate_size=3072,
83 | hidden_act="gelu",
84 | hidden_dropout_prob=0.1,
85 | attention_probs_dropout_prob=0.1,
86 | max_position_embeddings=512,
87 | type_vocab_size=2,
88 | initializer_range=0.02,
89 | layer_norm_eps=1e-12,
90 | **kwargs):
91 | super(BertConfig, self).__init__(**kwargs)
92 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
93 | and isinstance(vocab_size_or_config_json_file, unicode)):
94 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
95 | json_config = json.loads(reader.read())
96 | for key, value in json_config.items():
97 | self.__dict__[key] = value
98 | elif isinstance(vocab_size_or_config_json_file, int):
99 | self.vocab_size = vocab_size_or_config_json_file
100 | self.hidden_size = hidden_size
101 | self.num_hidden_layers = num_hidden_layers
102 | self.num_attention_heads = num_attention_heads
103 | self.hidden_act = hidden_act
104 | self.intermediate_size = intermediate_size
105 | self.hidden_dropout_prob = hidden_dropout_prob
106 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
107 | self.max_position_embeddings = max_position_embeddings
108 | self.type_vocab_size = type_vocab_size
109 | self.initializer_range = initializer_range
110 | self.layer_norm_eps = layer_norm_eps
111 | else:
112 | raise ValueError("First argument must be either a vocabulary size (int)"
113 | " or the path to a pretrained model config file (str)")
114 |
--------------------------------------------------------------------------------
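A brief sketch of the two construction paths handled in `BertConfig.__init__` above: an integer first argument sets `vocab_size` together with the keyword defaults, while a path to a JSON file copies every key from that file. The round trip below relies on `save_pretrained`/`from_pretrained` inherited from `PretrainedConfig` (configuration_utils.py); the directory name is made up.

```python
import os
from pytorch_transformers import BertConfig

# Integer first argument: vocabulary size plus the keyword defaults.
config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=6)

# Round trip through a JSON file on disk (hypothetical directory).
os.makedirs("./my_bert_config", exist_ok=True)
config.save_pretrained("./my_bert_config")             # writes ./my_bert_config/config.json
reloaded = BertConfig.from_pretrained("./my_bert_config")
assert reloaded.num_hidden_layers == 6
```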
/pytorch_transformers/configuration_distilbert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ DistilBERT model configuration """
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import sys
20 | import json
21 | import logging
22 | from io import open
23 |
24 | from .configuration_utils import PretrainedConfig
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json"
31 | }
32 |
33 |
34 | class DistilBertConfig(PretrainedConfig):
35 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
36 |
37 | def __init__(self,
38 | vocab_size_or_config_json_file=30522,
39 | max_position_embeddings=512,
40 | sinusoidal_pos_embds=True,
41 | n_layers=6,
42 | n_heads=12,
43 | dim=768,
44 | hidden_dim=4*768,
45 | dropout=0.1,
46 | attention_dropout=0.1,
47 | activation='gelu',
48 | initializer_range=0.02,
49 | tie_weights_=True,
50 | qa_dropout=0.1,
51 | seq_classif_dropout=0.2,
52 | **kwargs):
53 | super(DistilBertConfig, self).__init__(**kwargs)
54 |
55 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
56 | and isinstance(vocab_size_or_config_json_file, unicode)):
57 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
58 | json_config = json.loads(reader.read())
59 | for key, value in json_config.items():
60 | self.__dict__[key] = value
61 | elif isinstance(vocab_size_or_config_json_file, int):
62 | self.vocab_size = vocab_size_or_config_json_file
63 | self.max_position_embeddings = max_position_embeddings
64 | self.sinusoidal_pos_embds = sinusoidal_pos_embds
65 | self.n_layers = n_layers
66 | self.n_heads = n_heads
67 | self.dim = dim
68 | self.hidden_dim = hidden_dim
69 | self.dropout = dropout
70 | self.attention_dropout = attention_dropout
71 | self.activation = activation
72 | self.initializer_range = initializer_range
73 | self.tie_weights_ = tie_weights_
74 | self.qa_dropout = qa_dropout
75 | self.seq_classif_dropout = seq_classif_dropout
76 | else:
77 | raise ValueError("First argument must be either a vocabulary size (int)"
78 | " or the path to a pretrained model config file (str)")
79 | @property
80 | def hidden_size(self):
81 | return self.dim
82 |
83 | @property
84 | def num_attention_heads(self):
85 | return self.n_heads
86 |
87 | @property
88 | def num_hidden_layers(self):
89 | return self.n_layers
90 |
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT-2 configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
32 |
33 | class GPT2Config(PretrainedConfig):
34 | """Configuration class to store the configuration of a `GPT2Model`.
35 |
36 | Args:
37 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
38 | n_positions: Number of positional embeddings.
39 | n_ctx: Size of the causal mask (usually same as n_positions).
40 | n_embd: Dimensionality of the embeddings and hidden states.
41 | n_layer: Number of hidden layers in the Transformer encoder.
42 | n_head: Number of attention heads for each attention layer in
43 | the Transformer encoder.
44 | layer_norm_epsilon: epsilon to use in the layer norm layers
45 | resid_pdrop: The dropout probability for all fully connected
46 | layers in the embeddings, encoder, and pooler.
47 | attn_pdrop: The dropout ratio for the attention
48 | probabilities.
49 | embd_pdrop: The dropout ratio for the embeddings.
50 | initializer_range: The stddev of the truncated_normal_initializer for
51 | initializing all weight matrices.
52 | """
53 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
54 |
55 | def __init__(
56 | self,
57 | vocab_size_or_config_json_file=50257,
58 | n_positions=1024,
59 | n_ctx=1024,
60 | n_embd=768,
61 | n_layer=12,
62 | n_head=12,
63 | resid_pdrop=0.1,
64 | embd_pdrop=0.1,
65 | attn_pdrop=0.1,
66 | layer_norm_epsilon=1e-5,
67 | initializer_range=0.02,
68 |
69 | num_labels=1,
70 | summary_type='cls_index',
71 | summary_use_proj=True,
72 | summary_activation=None,
73 | summary_proj_to_labels=True,
74 | summary_first_dropout=0.1,
75 | **kwargs
76 | ):
77 | """Constructs GPT2Config.
78 |
79 | Args:
80 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
81 | n_positions: Number of positional embeddings.
82 | n_ctx: Size of the causal mask (usually same as n_positions).
83 | n_embd: Dimensionality of the embeddings and hidden states.
84 | n_layer: Number of hidden layers in the Transformer encoder.
85 | n_head: Number of attention heads for each attention layer in
86 | the Transformer encoder.
87 | layer_norm_epsilon: epsilon to use in the layer norm layers
88 | resid_pdrop: The dropout probability for all fully connected
89 | layers in the embeddings, encoder, and pooler.
90 | attn_pdrop: The dropout ratio for the attention
91 | probabilities.
92 | embd_pdrop: The dropout ratio for the embeddings.
93 | initializer_range: The stddev of the truncated_normal_initializer for
94 | initializing all weight matrices.
95 | """
96 | super(GPT2Config, self).__init__(**kwargs)
97 |
98 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
99 | and isinstance(vocab_size_or_config_json_file, unicode)):
100 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
101 | json_config = json.loads(reader.read())
102 | for key, value in json_config.items():
103 | self.__dict__[key] = value
104 | elif isinstance(vocab_size_or_config_json_file, int):
105 | self.vocab_size = vocab_size_or_config_json_file
106 | self.n_ctx = n_ctx
107 | self.n_positions = n_positions
108 | self.n_embd = n_embd
109 | self.n_layer = n_layer
110 | self.n_head = n_head
111 | self.resid_pdrop = resid_pdrop
112 | self.embd_pdrop = embd_pdrop
113 | self.attn_pdrop = attn_pdrop
114 | self.layer_norm_epsilon = layer_norm_epsilon
115 | self.initializer_range = initializer_range
116 |
117 | self.num_labels = num_labels
118 | self.summary_type = summary_type
119 | self.summary_use_proj = summary_use_proj
120 | self.summary_activation = summary_activation
121 | self.summary_first_dropout = summary_first_dropout
122 | self.summary_proj_to_labels = summary_proj_to_labels
123 | else:
124 | raise ValueError(
125 | "First argument must be either a vocabulary size (int)"
126 | "or the path to a pretrained model config file (str)"
127 | )
128 |
129 | @property
130 | def max_position_embeddings(self):
131 | return self.n_positions
132 |
133 | @property
134 | def hidden_size(self):
135 | return self.n_embd
136 |
137 | @property
138 | def num_attention_heads(self):
139 | return self.n_head
140 |
141 | @property
142 | def num_hidden_layers(self):
143 | return self.n_layer
144 |
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
31 | }
32 |
33 | class OpenAIGPTConfig(PretrainedConfig):
34 | """
35 | Configuration class to store the configuration of a `OpenAIGPTModel`.
36 |
37 | Args:
38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
39 | n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
40 | n_positions: Number of positional embeddings.
41 | n_ctx: Size of the causal mask (usually same as n_positions).
42 | n_embd: Dimensionality of the embeddings and hidden states.
43 | n_layer: Number of hidden layers in the Transformer encoder.
44 | n_head: Number of attention heads for each attention layer in
45 | the Transformer encoder.
46 | afn: The non-linear activation function (function or string) in the
47 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
48 | resid_pdrop: The dropout probability for all fully connected
49 | layers in the embeddings, encoder, and pooler.
50 | attn_pdrop: The dropout ratio for the attention
51 | probabilities.
52 | embd_pdrop: The dropout ratio for the embeddings.
53 | layer_norm_epsilon: epsilon to use in the layer norm layers
54 | initializer_range: The stddev of the truncated_normal_initializer for
55 | initializing all weight matrices.
56 | predict_special_tokens: should we predict special tokens (when the model has a LM head)
57 | """
58 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
59 |
60 | def __init__(
61 | self,
62 | vocab_size_or_config_json_file=40478,
63 | n_positions=512,
64 | n_ctx=512,
65 | n_embd=768,
66 | n_layer=12,
67 | n_head=12,
68 | afn="gelu",
69 | resid_pdrop=0.1,
70 | embd_pdrop=0.1,
71 | attn_pdrop=0.1,
72 | layer_norm_epsilon=1e-5,
73 | initializer_range=0.02,
74 | predict_special_tokens=True,
75 |
76 | num_labels=1,
77 | summary_type='cls_index',
78 | summary_use_proj=True,
79 | summary_activation=None,
80 | summary_proj_to_labels=True,
81 | summary_first_dropout=0.1,
82 | **kwargs
83 | ):
84 | """Constructs OpenAIGPTConfig.
85 | """
86 | super(OpenAIGPTConfig, self).__init__(**kwargs)
87 |
88 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
89 | and isinstance(vocab_size_or_config_json_file, unicode)):
90 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
91 | json_config = json.loads(reader.read())
92 | for key, value in json_config.items():
93 | self.__dict__[key] = value
94 | elif isinstance(vocab_size_or_config_json_file, int):
95 | self.vocab_size = vocab_size_or_config_json_file
96 | self.n_ctx = n_ctx
97 | self.n_positions = n_positions
98 | self.n_embd = n_embd
99 | self.n_layer = n_layer
100 | self.n_head = n_head
101 | self.afn = afn
102 | self.resid_pdrop = resid_pdrop
103 | self.embd_pdrop = embd_pdrop
104 | self.attn_pdrop = attn_pdrop
105 | self.layer_norm_epsilon = layer_norm_epsilon
106 | self.initializer_range = initializer_range
107 | self.predict_special_tokens = predict_special_tokens
108 |
109 | self.num_labels = num_labels
110 | self.summary_type = summary_type
111 | self.summary_use_proj = summary_use_proj
112 | self.summary_activation = summary_activation
113 | self.summary_first_dropout = summary_first_dropout
114 | self.summary_proj_to_labels = summary_proj_to_labels
115 | else:
116 | raise ValueError(
117 | "First argument must be either a vocabulary size (int)"
118 | "or the path to a pretrained model config file (str)"
119 | )
120 |
121 | @property
122 | def max_position_embeddings(self):
123 | return self.n_positions
124 |
125 | @property
126 | def hidden_size(self):
127 | return self.n_embd
128 |
129 | @property
130 | def num_attention_heads(self):
131 | return self.n_head
132 |
133 | @property
134 | def num_hidden_layers(self):
135 | return self.n_layer
136 |
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ RoBERTa configuration """
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import logging
22 |
23 | from .configuration_bert import BertConfig
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
31 | }
32 |
33 |
34 | class RobertaConfig(BertConfig):
35 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
36 |
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_transfo_xl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Transformer XL configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
31 | }
32 |
33 | class TransfoXLConfig(PretrainedConfig):
34 | """Configuration class to store the configuration of a `TransfoXLModel`.
35 |
36 | Args:
37 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file.
38 | cutoffs: cutoffs for the adaptive softmax
39 | d_model: Dimensionality of the model's hidden states.
40 | d_embed: Dimensionality of the embeddings
41 | d_head: Dimensionality of the model's heads.
42 | div_val: dividend value for adaptive input and softmax
43 | pre_lnorm: apply LayerNorm to the input instead of the output
44 | d_inner: Inner dimension in FF
45 | n_layer: Number of hidden layers in the Transformer encoder.
46 | n_head: Number of attention heads for each attention layer in
47 | the Transformer encoder.
48 | tgt_len: number of tokens to predict
49 | ext_len: length of the extended context
50 | mem_len: length of the retained previous heads
51 | same_length: use the same attn length for all tokens
52 | proj_share_all_but_first: True to share all but first projs, False not to share.
53 | attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
54 | clamp_len: use the same pos embeddings after clamp_len
55 | sample_softmax: number of samples in sampled softmax
56 | adaptive: use adaptive softmax
57 | tie_weight: tie the word embedding and softmax weights
58 | dropout: The dropout probability for all fully connected
59 | layers in the embeddings, encoder, and pooler.
60 | dropatt: The dropout ratio for the attention probabilities.
61 | untie_r: untie relative position biases
62 | embd_pdrop: The dropout ratio for the embeddings.
63 | init: parameter initializer to use
64 | init_range: parameters initialized by U(-init_range, init_range).
65 | proj_init_std: parameters initialized by N(0, proj_init_std)
66 | init_std: parameters initialized by N(0, init_std)
67 | """
68 | pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
69 |
70 | def __init__(self,
71 | vocab_size_or_config_json_file=267735,
72 | cutoffs=[20000, 40000, 200000],
73 | d_model=1024,
74 | d_embed=1024,
75 | n_head=16,
76 | d_head=64,
77 | d_inner=4096,
78 | div_val=4,
79 | pre_lnorm=False,
80 | n_layer=18,
81 | tgt_len=128,
82 | ext_len=0,
83 | mem_len=1600,
84 | clamp_len=1000,
85 | same_length=True,
86 | proj_share_all_but_first=True,
87 | attn_type=0,
88 | sample_softmax=-1,
89 | adaptive=True,
90 | tie_weight=True,
91 | dropout=0.1,
92 | dropatt=0.0,
93 | untie_r=True,
94 | init="normal",
95 | init_range=0.01,
96 | proj_init_std=0.01,
97 | init_std=0.02,
98 | **kwargs):
99 | """Constructs TransfoXLConfig.
100 | """
101 | super(TransfoXLConfig, self).__init__(**kwargs)
102 |
103 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
104 | and isinstance(vocab_size_or_config_json_file, unicode)):
105 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
106 | json_config = json.loads(reader.read())
107 | for key, value in json_config.items():
108 | self.__dict__[key] = value
109 | elif isinstance(vocab_size_or_config_json_file, int):
110 | self.n_token = vocab_size_or_config_json_file
111 | self.cutoffs = []
112 | self.cutoffs.extend(cutoffs)
113 | self.tie_weight = tie_weight
114 | if proj_share_all_but_first:
115 | self.tie_projs = [False] + [True] * len(self.cutoffs)
116 | else:
117 | self.tie_projs = [False] + [False] * len(self.cutoffs)
118 | self.d_model = d_model
119 | self.d_embed = d_embed
120 | self.d_head = d_head
121 | self.d_inner = d_inner
122 | self.div_val = div_val
123 | self.pre_lnorm = pre_lnorm
124 | self.n_layer = n_layer
125 | self.n_head = n_head
126 | self.tgt_len = tgt_len
127 | self.ext_len = ext_len
128 | self.mem_len = mem_len
129 | self.same_length = same_length
130 | self.attn_type = attn_type
131 | self.clamp_len = clamp_len
132 | self.sample_softmax = sample_softmax
133 | self.adaptive = adaptive
134 | self.dropout = dropout
135 | self.dropatt = dropatt
136 | self.untie_r = untie_r
137 | self.init = init
138 | self.init_range = init_range
139 | self.proj_init_std = proj_init_std
140 | self.init_std = init_std
141 | else:
142 | raise ValueError("First argument must be either a vocabulary size (int)"
143 | " or the path to a pretrained model config file (str)")
144 |
145 | @property
146 | def max_position_embeddings(self):
147 | return self.tgt_len + self.ext_len + self.mem_len
148 |
149 | @property
150 | def vocab_size(self):
151 | return self.n_token
152 |
153 | @vocab_size.setter
154 | def vocab_size(self, value):
155 | self.n_token = value
156 |
157 | @property
158 | def hidden_size(self):
159 | return self.d_model
160 |
161 | @property
162 | def num_attention_heads(self):
163 | return self.n_head
164 |
165 | @property
166 | def num_hidden_layers(self):
167 | return self.n_layer
168 |
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_xlm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ XLM configuration """
16 | from __future__ import absolute_import, division, print_function, unicode_literals
17 |
18 | import json
19 | import logging
20 | import sys
21 | from io import open
22 |
23 | from .configuration_utils import PretrainedConfig
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
29 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
30 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
31 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
32 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
33 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
34 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
35 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
36 | 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
37 | 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
38 | }
39 |
40 |
41 | class XLMConfig(PretrainedConfig):
42 | """Configuration class to store the configuration of a `XLMModel`.
43 |
44 |         Args:
45 |             vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`,
46 |                 or the path to a pretrained model config file (str).
47 |             emb_dim: Dimensionality of the encoder layers and the pooler layer.
48 |             n_layers: Number of hidden layers in the Transformer encoder.
49 |             n_heads: Number of attention heads for each attention layer in
50 |                 the Transformer encoder.
51 |             dropout: The dropout probability for all fully connected
52 |                 layers in the embeddings, encoder, and pooler.
53 |             attention_dropout: The dropout ratio for the attention probabilities.
54 |             gelu_activation: Whether to use GELU (True) or ReLU (False) as the
55 |                 non-linear activation in the feed-forward layers.
56 |             sinusoidal_embeddings: Whether to use sinusoidal positional embeddings
57 |                 instead of learned ones.
58 |             causal: Whether to use a causal (unidirectional) attention mask.
59 |             asm: Whether to use an adaptive softmax projection layer instead of a
60 |                 linear layer for the language modeling head.
61 |             n_langs: Number of languages handled by the model (1 for monolingual models).
62 |             use_lang_emb: Whether to use language embeddings (multilingual models).
63 |             max_position_embeddings: The maximum sequence length that this model might
64 |                 ever be used with. Typically set this to something large just in case
65 |                 (e.g., 512 or 1024 or 2048).
66 |             embed_init_std: The stddev of the normal initializer for the embedding matrices.
67 |             layer_norm_eps: The epsilon used by LayerNorm.
68 |             init_std: The stddev of the normal initializer for all weight matrices
69 |                 except the embeddings.
70 |             bos_index, eos_index, pad_index, unk_index, mask_index: Indices of the
71 |                 special tokens in the vocabulary.
72 |             is_encoder: Whether the model is used as an encoder (True) or a decoder (False).
73 |             finetuning_task: Name of the task on which the model was fine-tuned, if any.
74 |             num_labels: Number of labels for a classification head.
75 |             summary_type: How the sequence summary is computed for sequence-level tasks
76 |                 (e.g. 'first', 'last', 'mean' or 'cls_index').
77 |             summary_use_proj: Whether to add a projection after the sequence summary.
78 |             summary_activation: Activation applied to the summary projection ('tanh' or None).
79 |             summary_proj_to_labels: Whether the summary projection outputs num_labels
80 |                 dimensions (True) or the hidden size (False).
81 |             summary_first_dropout: Dropout probability applied before the summary projection.
82 |             start_n_top, end_n_top: Beam sizes used in the SQuAD evaluation heads.
83 |     """
84 | pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
85 |
86 | def __init__(self,
87 | vocab_size_or_config_json_file=30145,
88 | emb_dim=2048,
89 | n_layers=12,
90 | n_heads=16,
91 | dropout=0.1,
92 | attention_dropout=0.1,
93 | gelu_activation=True,
94 | sinusoidal_embeddings=False,
95 | causal=False,
96 | asm=False,
97 | n_langs=1,
98 | use_lang_emb=True,
99 | max_position_embeddings=512,
100 | embed_init_std=2048 ** -0.5,
101 | layer_norm_eps=1e-12,
102 | init_std=0.02,
103 | bos_index=0,
104 | eos_index=1,
105 | pad_index=2,
106 | unk_index=3,
107 | mask_index=5,
108 | is_encoder=True,
109 |
110 | finetuning_task=None,
111 | num_labels=2,
112 | summary_type='first',
113 | summary_use_proj=True,
114 | summary_activation=None,
115 | summary_proj_to_labels=True,
116 | summary_first_dropout=0.1,
117 | start_n_top=5,
118 | end_n_top=5,
119 | **kwargs):
120 | """Constructs XLMConfig.
121 | """
122 | super(XLMConfig, self).__init__(**kwargs)
123 |
124 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
125 | and isinstance(vocab_size_or_config_json_file, unicode)):
126 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
127 | json_config = json.loads(reader.read())
128 | for key, value in json_config.items():
129 | self.__dict__[key] = value
130 | elif isinstance(vocab_size_or_config_json_file, int):
131 | self.n_words = vocab_size_or_config_json_file
132 | self.emb_dim = emb_dim
133 | self.n_layers = n_layers
134 | self.n_heads = n_heads
135 | self.dropout = dropout
136 | self.attention_dropout = attention_dropout
137 | self.gelu_activation = gelu_activation
138 | self.sinusoidal_embeddings = sinusoidal_embeddings
139 | self.causal = causal
140 | self.asm = asm
141 | self.n_langs = n_langs
142 | self.use_lang_emb = use_lang_emb
143 | self.layer_norm_eps = layer_norm_eps
144 | self.bos_index = bos_index
145 | self.eos_index = eos_index
146 | self.pad_index = pad_index
147 | self.unk_index = unk_index
148 | self.mask_index = mask_index
149 | self.is_encoder = is_encoder
150 | self.max_position_embeddings = max_position_embeddings
151 | self.embed_init_std = embed_init_std
152 | self.init_std = init_std
153 | self.finetuning_task = finetuning_task
154 | self.num_labels = num_labels
155 | self.summary_type = summary_type
156 | self.summary_use_proj = summary_use_proj
157 | self.summary_activation = summary_activation
158 | self.summary_proj_to_labels = summary_proj_to_labels
159 | self.summary_first_dropout = summary_first_dropout
160 | self.start_n_top = start_n_top
161 | self.end_n_top = end_n_top
162 | else:
163 | raise ValueError("First argument must be either a vocabulary size (int)"
164 | " or the path to a pretrained model config file (str)")
165 |
166 | @property
167 | def vocab_size(self):
168 | return self.n_words
169 |
170 | @vocab_size.setter
171 | def vocab_size(self, value):
172 | self.n_words = value
173 |
174 | @property
175 | def hidden_size(self):
176 | return self.emb_dim
177 |
178 | @property
179 | def num_attention_heads(self):
180 | return self.n_heads
181 |
182 | @property
183 | def num_hidden_layers(self):
184 | return self.n_layers
185 |
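186 |
187 | if __name__ == "__main__":
188 |     # Minimal usage sketch (the keyword values below are illustrative, not the
189 |     # settings of any released checkpoint): build a config from an integer
190 |     # vocabulary size and read the generic properties shared by all config classes.
191 |     config = XLMConfig(vocab_size_or_config_json_file=30145, emb_dim=1024,
192 |                        n_layers=6, n_heads=8)
193 |     print(config.vocab_size, config.hidden_size,
194 |           config.num_attention_heads, config.num_hidden_layers)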
--------------------------------------------------------------------------------
/pytorch_transformers/configuration_xlnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ XLNet configuration """
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import json
20 | import logging
21 | import sys
22 | from io import open
23 |
24 | from .configuration_utils import PretrainedConfig
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
29 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
30 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
31 | }
32 |
33 |
34 | class XLNetConfig(PretrainedConfig):
35 | """Configuration class to store the configuration of a ``XLNetModel``.
36 |
37 | Args:
38 | vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
39 | d_model: Size of the encoder layers and the pooler layer.
40 | n_layer: Number of hidden layers in the Transformer encoder.
41 | n_head: Number of attention heads for each attention layer in
42 | the Transformer encoder.
43 | d_inner: The size of the "intermediate" (i.e., feed-forward)
44 | layer in the Transformer encoder.
45 | ff_activation: The non-linear activation function (function or string) in the
46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
47 | untie_r: untie relative position biases
48 | attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
49 |
50 |         dropout: The dropout probability for all fully connected
51 | layers in the embeddings, encoder, and pooler.
52 | dropatt: The dropout ratio for the attention
53 | probabilities.
54 |         initializer_range: The stddev of the truncated_normal_initializer for
55 | initializing all weight matrices.
56 | layer_norm_eps: The epsilon used by LayerNorm.
57 |
58 | dropout: float, dropout rate.
59 | dropatt: float, dropout rate on attention probabilities.
60 | init: str, the initialization scheme, either "normal" or "uniform".
61 | init_range: float, initialize the parameters with a uniform distribution
62 | in [-init_range, init_range]. Only effective when init="uniform".
63 | init_std: float, initialize the parameters with a normal distribution
64 | with mean 0 and stddev init_std. Only effective when init="normal".
65 | mem_len: int, the number of tokens to cache.
66 |         reuse_len: int, the number of tokens in the current batch to be cached
67 | and reused in the future.
68 | bi_data: bool, whether to use bidirectional input pipeline.
69 | Usually set to True during pretraining and False during finetuning.
70 | clamp_len: int, clamp all relative distances larger than clamp_len.
71 | -1 means no clamping.
72 | same_length: bool, whether to use the same attention length for each token.
73 |         finetuning_task: name of the GLUE task on which the model was fine-tuned, if any
74 | """
75 | pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
76 |
77 | def __init__(self,
78 | vocab_size_or_config_json_file=32000,
79 | d_model=1024,
80 | n_layer=24,
81 | n_head=16,
82 | d_inner=4096,
83 | ff_activation="gelu",
84 | untie_r=True,
85 | attn_type="bi",
86 |
87 | initializer_range=0.02,
88 | layer_norm_eps=1e-12,
89 |
90 | dropout=0.1,
91 | mem_len=None,
92 | reuse_len=None,
93 | bi_data=False,
94 | clamp_len=-1,
95 | same_length=False,
96 |
97 | finetuning_task=None,
98 | num_labels=2,
99 | summary_type='last',
100 | summary_use_proj=True,
101 | summary_activation='tanh',
102 | summary_last_dropout=0.1,
103 | start_n_top=5,
104 | end_n_top=5,
105 | **kwargs):
106 | """Constructs XLNetConfig.
107 | """
108 | super(XLNetConfig, self).__init__(**kwargs)
109 |
110 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
111 | and isinstance(vocab_size_or_config_json_file, unicode)):
112 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
113 | json_config = json.loads(reader.read())
114 | for key, value in json_config.items():
115 | self.__dict__[key] = value
116 | elif isinstance(vocab_size_or_config_json_file, int):
117 | self.n_token = vocab_size_or_config_json_file
118 | self.d_model = d_model
119 | self.n_layer = n_layer
120 | self.n_head = n_head
121 | assert d_model % n_head == 0
122 | self.d_head = d_model // n_head
123 | self.ff_activation = ff_activation
124 | self.d_inner = d_inner
125 | self.untie_r = untie_r
126 | self.attn_type = attn_type
127 |
128 | self.initializer_range = initializer_range
129 | self.layer_norm_eps = layer_norm_eps
130 |
131 | self.dropout = dropout
132 | self.mem_len = mem_len
133 | self.reuse_len = reuse_len
134 | self.bi_data = bi_data
135 | self.clamp_len = clamp_len
136 | self.same_length = same_length
137 |
138 | self.finetuning_task = finetuning_task
139 | self.num_labels = num_labels
140 | self.summary_type = summary_type
141 | self.summary_use_proj = summary_use_proj
142 | self.summary_activation = summary_activation
143 | self.summary_last_dropout = summary_last_dropout
144 | self.start_n_top = start_n_top
145 | self.end_n_top = end_n_top
146 | else:
147 | raise ValueError("First argument must be either a vocabulary size (int)"
148 | " or the path to a pretrained model config file (str)")
149 |
150 | @property
151 | def max_position_embeddings(self):
152 | return -1
153 |
154 | @property
155 | def vocab_size(self):
156 | return self.n_token
157 |
158 | @vocab_size.setter
159 | def vocab_size(self, value):
160 | self.n_token = value
161 |
162 | @property
163 | def hidden_size(self):
164 | return self.d_model
165 |
166 | @property
167 | def num_attention_heads(self):
168 | return self.n_head
169 |
170 | @property
171 | def num_hidden_layers(self):
172 | return self.n_layer
173 |
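174 |
175 | if __name__ == "__main__":
176 |     # Minimal usage sketch (values and file name are illustrative): the first
177 |     # argument may be an int vocabulary size or a path to a JSON config file,
178 |     # and a config can be round-tripped through JSON.
179 |     config = XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
180 |                          n_layer=12, n_head=12)
181 |     config.to_json_file("xlnet_config_example.json")
182 |     reloaded = XLNetConfig.from_json_file("xlnet_config_example.json")
183 |     print(reloaded.hidden_size, reloaded.d_head)  # d_head was derived as d_model // n_head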
--------------------------------------------------------------------------------
/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT-2 checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
25 | GPT2Config,
26 | GPT2Model,
27 | load_tf_weights_in_gpt2)
28 |
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 |
32 |
33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
34 | # Construct model
35 | if gpt2_config_file == "":
36 | config = GPT2Config()
37 | else:
38 | config = GPT2Config.from_json_file(gpt2_config_file)
39 | model = GPT2Model(config)
40 |
41 |     # Load weights from the TensorFlow checkpoint
42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
43 |
44 | # Save pytorch-model
45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
48 | torch.save(model.state_dict(), pytorch_weights_dump_path)
49 | print("Save configuration file to {}".format(pytorch_config_dump_path))
50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
51 | f.write(config.to_json_string())
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | ## Required parameters
57 | parser.add_argument("--gpt2_checkpoint_path",
58 | default = None,
59 | type = str,
60 | required = True,
61 |                         help = "Path to the TensorFlow checkpoint.")
62 | parser.add_argument("--pytorch_dump_folder_path",
63 | default = None,
64 | type = str,
65 | required = True,
66 | help = "Path to the output PyTorch model.")
67 | parser.add_argument("--gpt2_config_file",
68 | default = "",
69 | type = str,
70 |                         help = "An optional config json file corresponding to the pre-trained OpenAI GPT-2 model. \n"
71 | "This specifies the model architecture.")
72 | args = parser.parse_args()
73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
74 | args.gpt2_config_file,
75 | args.pytorch_dump_folder_path)
76 |
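77 | # Example invocation (all paths are illustrative); the converted weights and
78 | # config file are written into the dump folder:
79 | #   python convert_gpt2_checkpoint_to_pytorch.py \
80 | #       --gpt2_checkpoint_path /path/to/gpt2/model.ckpt \
81 | #       --pytorch_dump_folder_path /path/to/output_dir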
--------------------------------------------------------------------------------
/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
25 | OpenAIGPTConfig,
26 | OpenAIGPTModel,
27 | load_tf_weights_in_openai_gpt)
28 |
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 |
32 |
33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
34 | # Construct model
35 | if openai_config_file == "":
36 | config = OpenAIGPTConfig()
37 | else:
38 | config = OpenAIGPTConfig.from_json_file(openai_config_file)
39 | model = OpenAIGPTModel(config)
40 |
41 | # Load weights from numpy
42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
43 |
44 | # Save pytorch-model
45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
48 | torch.save(model.state_dict(), pytorch_weights_dump_path)
49 | print("Save configuration file to {}".format(pytorch_config_dump_path))
50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
51 | f.write(config.to_json_string())
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | ## Required parameters
57 | parser.add_argument("--openai_checkpoint_folder_path",
58 | default = None,
59 | type = str,
60 | required = True,
61 |                         help = "Path to the OpenAI GPT checkpoint folder.")
62 | parser.add_argument("--pytorch_dump_folder_path",
63 | default = None,
64 | type = str,
65 | required = True,
66 | help = "Path to the output PyTorch model.")
67 | parser.add_argument("--openai_config_file",
68 | default = "",
69 | type = str,
70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
71 | "This specifies the model architecture.")
72 | args = parser.parse_args()
73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
74 | args.openai_config_file,
75 | args.pytorch_dump_folder_path)
76 |
--------------------------------------------------------------------------------
/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Convert HuggingFace PyTorch checkpoint to TensorFlow checkpoint."""
17 |
18 | import os
19 | import argparse
20 | import torch
21 | import numpy as np
22 | import tensorflow as tf
23 | from pytorch_transformers import BertModel
24 |
25 |
26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
27 |
28 | """
29 |     :param model: BertModel PyTorch model instance to be converted
30 |     :param ckpt_dir: TensorFlow model directory
31 | :param model_name: model name
32 | :return:
33 |
34 | Currently supported HF models:
35 | Y BertModel
36 | N BertForMaskedLM
37 | N BertForPreTraining
38 | N BertForMultipleChoice
39 | N BertForNextSentencePrediction
40 | N BertForSequenceClassification
41 | N BertForQuestionAnswering
42 | """
43 |
44 | tensors_to_transpose = (
45 | "dense.weight",
46 | "attention.self.query",
47 | "attention.self.key",
48 | "attention.self.value"
49 | )
50 |
51 | var_map = (
52 | ('layer.', 'layer_'),
53 | ('word_embeddings.weight', 'word_embeddings'),
54 | ('position_embeddings.weight', 'position_embeddings'),
55 | ('token_type_embeddings.weight', 'token_type_embeddings'),
56 | ('.', '/'),
57 | ('LayerNorm/weight', 'LayerNorm/gamma'),
58 | ('LayerNorm/bias', 'LayerNorm/beta'),
59 | ('weight', 'kernel')
60 | )
61 |
62 | if not os.path.isdir(ckpt_dir):
63 | os.makedirs(ckpt_dir)
64 |
65 | state_dict = model.state_dict()
66 |
67 | def to_tf_var_name(name:str):
68 | for patt, repl in iter(var_map):
69 | name = name.replace(patt, repl)
70 | return 'bert/{}'.format(name)
71 |
72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
75 | session.run(tf.variables_initializer([tf_var]))
76 | session.run(tf_var)
77 | return tf_var
78 |
79 | tf.reset_default_graph()
80 | with tf.Session() as session:
81 | for var_name in state_dict:
82 | tf_name = to_tf_var_name(var_name)
83 | torch_tensor = state_dict[var_name].numpy()
84 | if any([x in var_name for x in tensors_to_transpose]):
85 | torch_tensor = torch_tensor.T
86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
87 | tf.keras.backend.set_value(tf_var, torch_tensor)
88 | tf_weight = session.run(tf_var)
89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
90 |
91 | saver = tf.train.Saver(tf.trainable_variables())
92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
93 |
94 |
95 | def main(raw_args=None):
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument("--model_name",
98 | type=str,
99 | required=True,
100 | help="model name e.g. bert-base-uncased")
101 | parser.add_argument("--cache_dir",
102 | type=str,
103 | default=None,
104 | required=False,
105 | help="Directory containing pytorch model")
106 | parser.add_argument("--pytorch_model_path",
107 | type=str,
108 | required=True,
109 | help="/path/to/.bin")
110 | parser.add_argument("--tf_cache_dir",
111 | type=str,
112 | required=True,
113 | help="Directory in which to save tensorflow model")
114 | args = parser.parse_args(raw_args)
115 |
116 | model = BertModel.from_pretrained(
117 | pretrained_model_name_or_path=args.model_name,
118 | state_dict=torch.load(args.pytorch_model_path),
119 | cache_dir=args.cache_dir
120 | )
121 |
122 | convert_pytorch_checkpoint_to_tf(
123 | model=model,
124 | ckpt_dir=args.tf_cache_dir,
125 | model_name=args.model_name
126 | )
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
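132 | # Example invocation (argument values are illustrative); the resulting TensorFlow
133 | # checkpoint is saved in --tf_cache_dir as <model_name with '-' replaced by '_'>.ckpt:
134 | #   python convert_pytorch_checkpoint_to_tf.py --model_name bert-base-uncased \
135 | #       --pytorch_model_path /path/to/pytorch_model.bin --tf_cache_dir /path/to/tf_out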
--------------------------------------------------------------------------------
/pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert RoBERTa checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import logging
21 | import numpy as np
22 | import torch
23 |
24 | from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
25 | from fairseq.modules import TransformerSentenceEncoderLayer
26 | from pytorch_transformers import (BertConfig, BertEncoder,
27 | BertIntermediate, BertLayer,
28 | BertModel, BertOutput,
29 | BertSelfAttention,
30 | BertSelfOutput)
31 | from pytorch_transformers import (RobertaEmbeddings,
32 | RobertaForMaskedLM,
33 | RobertaForSequenceClassification,
34 | RobertaModel)
35 |
36 | logging.basicConfig(level=logging.INFO)
37 | logger = logging.getLogger(__name__)
38 |
39 | SAMPLE_TEXT = 'Hello world! cécé herlolip'
40 |
41 |
42 | def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
43 | """
44 | Copy/paste/tweak roberta's weights to our BERT structure.
45 | """
46 | roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
47 | roberta.eval() # disable dropout
48 | config = BertConfig(
49 | vocab_size_or_config_json_file=50265,
50 | hidden_size=roberta.args.encoder_embed_dim,
51 | num_hidden_layers=roberta.args.encoder_layers,
52 | num_attention_heads=roberta.args.encoder_attention_heads,
53 | intermediate_size=roberta.args.encoder_ffn_embed_dim,
54 | max_position_embeddings=514,
55 | type_vocab_size=1,
56 | layer_norm_eps=1e-5, # PyTorch default used in fairseq
57 | )
58 | if classification_head:
59 | config.num_labels = roberta.args.num_classes
60 | print("Our BERT config:", config)
61 |
62 | model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
63 | model.eval()
64 |
65 | # Now let's copy all the weights.
66 | # Embeddings
67 | roberta_sent_encoder = roberta.model.decoder.sentence_encoder
68 | model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
69 | model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
70 | model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
71 | model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
72 | model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
73 |
74 | for i in range(config.num_hidden_layers):
75 | # Encoder: start of layer
76 | layer: BertLayer = model.roberta.encoder.layer[i]
77 | roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
78 |
79 | ### self attention
80 | self_attn: BertSelfAttention = layer.attention.self
81 | assert(
82 | roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
83 | )
84 | # we use three distinct linear layers so we split the source layer here.
85 | self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
86 | self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
87 | self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
88 | self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
89 | self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
90 | self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
91 |
92 | ### self-attention output
93 | self_output: BertSelfOutput = layer.attention.output
94 | assert(
95 | self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
96 | )
97 | self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
98 | self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
99 | self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
100 | self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
101 |
102 | ### intermediate
103 | intermediate: BertIntermediate = layer.intermediate
104 | assert(
105 | intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
106 | )
107 | intermediate.dense.weight = roberta_layer.fc1.weight
108 | intermediate.dense.bias = roberta_layer.fc1.bias
109 |
110 | ### output
111 | bert_output: BertOutput = layer.output
112 | assert(
113 | bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
114 | )
115 | bert_output.dense.weight = roberta_layer.fc2.weight
116 | bert_output.dense.bias = roberta_layer.fc2.bias
117 | bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
118 | bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
119 | #### end of layer
120 |
121 | if classification_head:
122 | model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
123 | model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
124 | model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
125 | model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
126 | else:
127 | # LM Head
128 | model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
129 | model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
130 | model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
131 | model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
132 | model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
133 | model.lm_head.bias = roberta.model.decoder.lm_head.bias
134 |
135 | # Let's check that we get the same results.
136 | input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
137 |
138 | our_output = model(input_ids)[0]
139 | if classification_head:
140 | their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
141 | else:
142 | their_output = roberta.model(input_ids)[0]
143 | print(our_output.shape, their_output.shape)
144 | max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
145 | print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
146 | success = torch.allclose(our_output, their_output, atol=1e-3)
147 | print(
148 | "Do both models output the same tensors?",
149 | "🔥" if success else "💩"
150 | )
151 | if not success:
152 | raise Exception("Something went wRoNg")
153 |
154 | print(f"Saving model to {pytorch_dump_folder_path}")
155 | model.save_pretrained(pytorch_dump_folder_path)
156 |
157 |
158 | if __name__ == "__main__":
159 | parser = argparse.ArgumentParser()
160 | ## Required parameters
161 | parser.add_argument("--roberta_checkpoint_path",
162 | default = None,
163 | type = str,
164 | required = True,
165 |                         help = "Path to the official PyTorch dump.")
166 | parser.add_argument("--pytorch_dump_folder_path",
167 | default = None,
168 | type = str,
169 | required = True,
170 | help = "Path to the output PyTorch model.")
171 | parser.add_argument("--classification_head",
172 | action = "store_true",
173 | help = "Whether to convert a final classification head.")
174 | args = parser.parse_args()
175 | convert_roberta_checkpoint_to_pytorch(
176 | args.roberta_checkpoint_path,
177 | args.pytorch_dump_folder_path,
178 | args.classification_head
179 | )
180 |
181 |
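182 | # Example invocation (paths are illustrative); pass --classification_head when the
183 | # fairseq checkpoint carries an 'mnli' classification head that should be converted too:
184 | #   python convert_roberta_checkpoint_to_pytorch.py \
185 | #       --roberta_checkpoint_path /path/to/roberta.large \
186 | #       --pytorch_dump_folder_path /path/to/output_dir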
--------------------------------------------------------------------------------
/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 | import torch
23 |
24 | from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
25 |
26 | import logging
27 | logging.basicConfig(level=logging.INFO)
28 |
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30 | # Initialise PyTorch model
31 | config = BertConfig.from_json_file(bert_config_file)
32 | print("Building PyTorch model from configuration: {}".format(str(config)))
33 | model = BertForPreTraining(config)
34 |
35 | # Load weights from tf checkpoint
36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37 |
38 | # Save pytorch-model
39 | print("Save PyTorch model to {}".format(pytorch_dump_path))
40 | torch.save(model.state_dict(), pytorch_dump_path)
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | ## Required parameters
46 | parser.add_argument("--tf_checkpoint_path",
47 | default = None,
48 | type = str,
49 | required = True,
50 |                         help = "Path to the TensorFlow checkpoint.")
51 | parser.add_argument("--bert_config_file",
52 | default = None,
53 | type = str,
54 | required = True,
55 | help = "The config json file corresponding to the pre-trained BERT model. \n"
56 | "This specifies the model architecture.")
57 | parser.add_argument("--pytorch_dump_path",
58 | default = None,
59 | type = str,
60 | required = True,
61 | help = "Path to the output PyTorch model.")
62 | args = parser.parse_args()
63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
64 | args.bert_config_file,
65 | args.pytorch_dump_path)
66 |
--------------------------------------------------------------------------------
/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import os
21 | import sys
22 | from io import open
23 |
24 | import torch
25 |
26 | import pytorch_transformers.tokenization_transfo_xl as data_utils
27 |
28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
29 | from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
30 | load_tf_weights_in_transfo_xl)
31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
32 |
33 | if sys.version_info[0] == 2:
34 | import cPickle as pickle
35 | else:
36 | import pickle
37 |
38 | import logging
39 | logging.basicConfig(level=logging.INFO)
40 |
41 | # We do this to be able to load python 2 datasets pickles
42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
43 | data_utils.Vocab = data_utils.TransfoXLTokenizer
44 | data_utils.Corpus = data_utils.TransfoXLCorpus
45 | sys.modules['data_utils'] = data_utils
46 | sys.modules['vocabulary'] = data_utils
47 |
48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
49 | transfo_xl_config_file,
50 | pytorch_dump_folder_path,
51 | transfo_xl_dataset_file):
52 | if transfo_xl_dataset_file:
53 | # Convert a pre-processed corpus (see original TensorFlow repo)
54 | with open(transfo_xl_dataset_file, "rb") as fp:
55 | corpus = pickle.load(fp, encoding="latin1")
56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
59 | corpus_vocab_dict = corpus.vocab.__dict__
60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
61 |
62 | corpus_dict_no_vocab = corpus.__dict__
63 | corpus_dict_no_vocab.pop('vocab', None)
64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
65 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
67 |
68 | if tf_checkpoint_path:
69 | # Convert a pre-trained TensorFlow model
70 | config_path = os.path.abspath(transfo_xl_config_file)
71 | tf_path = os.path.abspath(tf_checkpoint_path)
72 |
73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
74 | # Initialise PyTorch model
75 | if transfo_xl_config_file == "":
76 | config = TransfoXLConfig()
77 | else:
78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
79 | print("Building PyTorch model from configuration: {}".format(str(config)))
80 | model = TransfoXLLMHeadModel(config)
81 |
82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
83 | # Save pytorch-model
84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
87 | torch.save(model.state_dict(), pytorch_weights_dump_path)
88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
90 | f.write(config.to_json_string())
91 |
92 |
93 | if __name__ == "__main__":
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument("--pytorch_dump_folder_path",
96 | default = None,
97 | type = str,
98 | required = True,
99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
100 | parser.add_argument("--tf_checkpoint_path",
101 | default = "",
102 | type = str,
103 |                         help = "An optional path to a TensorFlow checkpoint to be converted.")
104 | parser.add_argument("--transfo_xl_config_file",
105 | default = "",
106 | type = str,
107 |                         help = "An optional config json file corresponding to the pre-trained Transformer XL model. \n"
108 | "This specifies the model architecture.")
109 | parser.add_argument("--transfo_xl_dataset_file",
110 | default = "",
111 | type = str,
112 | help = "An optional dataset file to be converted in a vocabulary.")
113 | args = parser.parse_args()
114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
115 | args.transfo_xl_config_file,
116 | args.pytorch_dump_folder_path,
117 | args.transfo_xl_dataset_file)
118 |
--------------------------------------------------------------------------------
/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert XLM checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import json
21 | from io import open
22 |
23 | import torch
24 | import numpy
25 |
26 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
28 |
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 |
32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
33 | # Load checkpoint
34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
35 |
36 | model = chkpt['model']
37 |
38 | config = chkpt['params']
39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
40 |
41 | vocab = chkpt['dico_word2id']
42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
43 |
44 | # Save pytorch-model
45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
48 |
49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
50 | torch.save(model, pytorch_weights_dump_path)
51 |
52 | print("Save configuration file to {}".format(pytorch_config_dump_path))
53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
54 | f.write(json.dumps(config, indent=2) + "\n")
55 |
56 |     print("Save vocab file to {}".format(pytorch_vocab_dump_path))
57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
58 | f.write(json.dumps(vocab, indent=2) + "\n")
59 |
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | ## Required parameters
64 | parser.add_argument("--xlm_checkpoint_path",
65 | default = None,
66 | type = str,
67 | required = True,
68 |                         help = "Path to the official PyTorch dump.")
69 | parser.add_argument("--pytorch_dump_folder_path",
70 | default = None,
71 | type = str,
72 | required = True,
73 | help = "Path to the output PyTorch model.")
74 | args = parser.parse_args()
75 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
76 |
--------------------------------------------------------------------------------
/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert XLNet checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import argparse
23 | import torch
24 |
25 | from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
26 | XLNetConfig,
27 | XLNetLMHeadModel, XLNetForQuestionAnswering,
28 | XLNetForSequenceClassification,
29 | load_tf_weights_in_xlnet)
30 |
31 | GLUE_TASKS_NUM_LABELS = {
32 | "cola": 2,
33 | "mnli": 3,
34 | "mrpc": 2,
35 | "sst-2": 2,
36 | "sts-b": 1,
37 | "qqp": 2,
38 | "qnli": 2,
39 | "rte": 2,
40 | "wnli": 2,
41 | }
42 |
43 | import logging
44 | logging.basicConfig(level=logging.INFO)
45 |
46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, xlnet_config_file, pytorch_dump_folder_path, finetuning_task=None):
47 |     # Initialise PyTorch model
48 |     config = XLNetConfig.from_json_file(xlnet_config_file)
49 |
50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
51 | if finetuning_task in GLUE_TASKS_NUM_LABELS:
52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
53 | config.finetuning_task = finetuning_task
54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
55 | model = XLNetForSequenceClassification(config)
56 | elif 'squad' in finetuning_task:
57 | config.finetuning_task = finetuning_task
58 | model = XLNetForQuestionAnswering(config)
59 | else:
60 | model = XLNetLMHeadModel(config)
61 |
62 | # Load weights from tf checkpoint
63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
64 |
65 | # Save pytorch-model
66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
69 | torch.save(model.state_dict(), pytorch_weights_dump_path)
70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
72 | f.write(config.to_json_string())
73 |
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser()
77 | ## Required parameters
78 | parser.add_argument("--tf_checkpoint_path",
79 | default = None,
80 | type = str,
81 | required = True,
82 |                         help = "Path to the TensorFlow checkpoint.")
83 | parser.add_argument("--xlnet_config_file",
84 | default = None,
85 | type = str,
86 | required = True,
87 | help = "The config json file corresponding to the pre-trained XLNet model. \n"
88 | "This specifies the model architecture.")
89 | parser.add_argument("--pytorch_dump_folder_path",
90 | default = None,
91 | type = str,
92 | required = True,
93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
94 | parser.add_argument("--finetuning_task",
95 | default = None,
96 | type = str,
97 |                         help = "Name of a task on which the XLNet TensorFlow model was fine-tuned")
98 | args = parser.parse_args()
99 | print(args)
100 |
101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
102 | args.xlnet_config_file,
103 | args.pytorch_dump_folder_path,
104 | args.finetuning_task)
105 |
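106 | # Example invocation (paths are illustrative); --finetuning_task selects the head:
107 | # a GLUE task name builds XLNetForSequenceClassification, a name containing 'squad'
108 | # builds XLNetForQuestionAnswering, otherwise XLNetLMHeadModel is used.
109 | #   python convert_xlnet_checkpoint_to_pytorch.py \
110 | #       --tf_checkpoint_path /path/to/xlnet/model.ckpt \
111 | #       --xlnet_config_file /path/to/xlnet_config.json \
112 | #       --pytorch_dump_folder_path /path/to/output_dir --finetuning_task sts-b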
--------------------------------------------------------------------------------
/pytorch_transformers/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.optim.lr_scheduler import LambdaLR
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | class ConstantLRSchedule(LambdaLR):
27 | """ Constant learning rate schedule.
28 | """
29 | def __init__(self, optimizer, last_epoch=-1):
30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
31 |
32 |
33 | class WarmupConstantSchedule(LambdaLR):
34 | """ Linear warmup and then constant.
35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
36 | Keeps learning rate schedule equal to 1. after warmup_steps.
37 | """
38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1):
39 | self.warmup_steps = warmup_steps
40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
41 |
42 | def lr_lambda(self, step):
43 | if step < self.warmup_steps:
44 | return float(step) / float(max(1.0, self.warmup_steps))
45 | return 1.
46 |
47 |
48 | class WarmupLinearSchedule(LambdaLR):
49 | """ Linear warmup and then linear decay.
50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
52 | """
53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
54 | self.warmup_steps = warmup_steps
55 | self.t_total = t_total
56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
57 |
58 | def lr_lambda(self, step):
59 | if step < self.warmup_steps:
60 | return float(step) / float(max(1, self.warmup_steps))
61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
62 |
63 |
64 | class WarmupCosineSchedule(LambdaLR):
65 | """ Linear warmup and then cosine decay.
66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
68 |         `cycles` (default=0.5) sets the fraction of a full cosine period traced after warmup; the default of 0.5 gives a single smooth decay from 1. to 0.
69 | """
70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
71 | self.warmup_steps = warmup_steps
72 | self.t_total = t_total
73 | self.cycles = cycles
74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
75 |
76 | def lr_lambda(self, step):
77 | if step < self.warmup_steps:
78 | return float(step) / float(max(1.0, self.warmup_steps))
79 | # progress after warmup
80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
82 |
83 |
84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR):
85 | """ Linear warmup and then cosine cycles with hard restarts.
86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
87 |         After warmup the learning rate follows `cycles` (default=1.) cosine decays from 1. to 0.,
88 |         restarting hard at 1. at the start of each new cycle.
89 | """
90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
91 | self.warmup_steps = warmup_steps
92 | self.t_total = t_total
93 | self.cycles = cycles
94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
95 |
96 | def lr_lambda(self, step):
97 | if step < self.warmup_steps:
98 | return float(step) / float(max(1, self.warmup_steps))
99 | # progress after warmup
100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
101 | if progress >= 1.0:
102 | return 0.0
103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
104 |
105 |
106 |
107 | class AdamW(Optimizer):
108 | """ Implements Adam algorithm with weight decay fix.
109 |
110 | Parameters:
111 | lr (float): learning rate. Default 1e-3.
112 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
113 |         eps (float): Adam's epsilon. Default: 1e-6
114 | weight_decay (float): Weight decay. Default: 0.0
115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
116 | """
117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
118 | if lr < 0.0:
119 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
120 | if not 0.0 <= betas[0] < 1.0:
121 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
122 | if not 0.0 <= betas[1] < 1.0:
123 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
124 | if not 0.0 <= eps:
125 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
126 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
127 | correct_bias=correct_bias)
128 | super(AdamW, self).__init__(params, defaults)
129 |
130 | def step(self, closure=None):
131 | """Performs a single optimization step.
132 |
133 | Arguments:
134 | closure (callable, optional): A closure that reevaluates the model
135 | and returns the loss.
136 | """
137 | loss = None
138 | if closure is not None:
139 | loss = closure()
140 |
141 | for group in self.param_groups:
142 | for p in group['params']:
143 | if p.grad is None:
144 | continue
145 | grad = p.grad.data
146 | if grad.is_sparse:
147 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
148 |
149 | state = self.state[p]
150 |
151 | # State initialization
152 | if len(state) == 0:
153 | state['step'] = 0
154 | # Exponential moving average of gradient values
155 | state['exp_avg'] = torch.zeros_like(p.data)
156 | # Exponential moving average of squared gradient values
157 | state['exp_avg_sq'] = torch.zeros_like(p.data)
158 |
159 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
160 | beta1, beta2 = group['betas']
161 |
162 | state['step'] += 1
163 |
164 | # Decay the first and second moment running average coefficient
165 | # In-place operations to update the averages at the same time
166 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
167 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
168 | denom = exp_avg_sq.sqrt().add_(group['eps'])
169 |
170 | step_size = group['lr']
171 | if group['correct_bias']: # No bias correction for Bert
172 | bias_correction1 = 1.0 - beta1 ** state['step']
173 | bias_correction2 = 1.0 - beta2 ** state['step']
174 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
175 |
176 | p.data.addcdiv_(-step_size, exp_avg, denom)
177 |
178 | # Just adding the square of the weights to the loss function is *not*
179 | # the correct way of using L2 regularization/weight decay with Adam,
180 | # since that will interact with the m and v parameters in strange ways.
181 | #
182 | # Instead we want to decay the weights in a manner that doesn't interact
183 | # with the m/v parameters. This is equivalent to adding the square
184 | # of the weights to the loss with plain (non-momentum) SGD.
185 | # Add weight decay at the end (fixed version)
186 | if group['weight_decay'] > 0.0:
187 | p.data.add_(-group['lr'] * group['weight_decay'], p.data)
188 |
189 | return loss
190 |
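191 |
192 | if __name__ == "__main__":
193 |     # Minimal usage sketch (the model and hyper-parameters are illustrative):
194 |     # AdamW applies decoupled weight decay, and WarmupLinearSchedule is stepped
195 |     # once per optimizer step to ramp the learning rate up and then decay it.
196 |     model = torch.nn.Linear(10, 2)
197 |     optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
198 |     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10, t_total=100)
199 |     for _ in range(100):
200 |         loss = model(torch.randn(4, 10)).sum()
201 |         loss.backward()
202 |         optimizer.step()
203 |         scheduler.step()
204 |         optimizer.zero_grad()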
--------------------------------------------------------------------------------
/pytorch_transformers/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/pytorch_transformers/tests/__init__.py
--------------------------------------------------------------------------------
/pytorch_transformers/tests/configuration_common_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import copy
20 | import os
21 | import shutil
22 | import json
23 | import random
24 | import uuid
25 |
26 | import unittest
27 | import logging
28 |
29 |
30 | class ConfigTester(object):
31 | def __init__(self, parent, config_class=None, **kwargs):
32 | self.parent = parent
33 | self.config_class = config_class
34 | self.inputs_dict = kwargs
35 |
36 | def create_and_test_config_common_properties(self):
37 | config = self.config_class(**self.inputs_dict)
38 | self.parent.assertTrue(hasattr(config, 'vocab_size'))
39 | self.parent.assertTrue(hasattr(config, 'hidden_size'))
40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))
42 |
43 | def create_and_test_config_to_json_string(self):
44 | config = self.config_class(**self.inputs_dict)
45 | obj = json.loads(config.to_json_string())
46 | for key, value in self.inputs_dict.items():
47 | self.parent.assertEqual(obj[key], value)
48 |
49 | def create_and_test_config_to_json_file(self):
50 | config_first = self.config_class(**self.inputs_dict)
51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
52 | config_first.to_json_file(json_file_path)
53 | config_second = self.config_class.from_json_file(json_file_path)
54 | os.remove(json_file_path)
55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
56 |
57 | def run_common_tests(self):
58 | self.create_and_test_config_common_properties()
59 | self.create_and_test_config_to_json_string()
60 | self.create_and_test_config_to_json_file()
61 |
62 | if __name__ == "__main__":
63 | unittest.main()
--------------------------------------------------------------------------------
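
For reference, a sketch of how ConfigTester is meant to be driven from a model-specific test: the enclosing TestCase is passed as `parent` so the shared assertions report through unittest. The class name and the `hidden_size=37` override are illustrative only:

import unittest

from pytorch_transformers import BertConfig
from pytorch_transformers.tests.configuration_common_test import ConfigTester


class BertConfigCommonTest(unittest.TestCase):
    def test_config(self):
        # Runs the common property / JSON round-trip checks against BertConfig.
        config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
        config_tester.run_common_tests()


if __name__ == "__main__":
    unittest.main()
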
/pytorch_transformers/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # content of conftest.py
2 |
3 | import pytest
4 |
5 |
6 | def pytest_addoption(parser):
7 | parser.addoption(
8 | "--runslow", action="store_true", default=False, help="run slow tests"
9 | )
10 |
11 |
12 | def pytest_collection_modifyitems(config, items):
13 | if config.getoption("--runslow"):
14 | # --runslow given in cli: do not skip slow tests
15 | return
16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run")
17 | for item in items:
18 | if "slow" in item.keywords:
19 | item.add_marker(skip_slow)
20 |
--------------------------------------------------------------------------------
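
With this hook in place, a test opts into the slow path simply by carrying the `slow` marker; it is then skipped unless pytest is invoked with `--runslow`. The module below is a hypothetical example, not part of the repository:

# e.g. tests/test_slow_example.py (hypothetical)
import pytest


@pytest.mark.slow
def test_expensive_pretrained_download():
    # Skipped by default; pytest_collection_modifyitems above attaches a skip
    # marker unless --runslow is given on the command line.
    assert 2 + 2 == 4


def test_cheap_unit():
    # Always runs.
    assert 'bert'.upper() == 'BERT'
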
/pytorch_transformers/tests/fixtures/input.txt:
--------------------------------------------------------------------------------
1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer
2 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/fixtures/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Guttenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/fixtures/test_sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyease/TADAM/683c9e971bef06a93037cf8481e668415f743f04/pytorch_transformers/tests/fixtures/test_sentencepiece.model
--------------------------------------------------------------------------------
/pytorch_transformers/tests/modeling_auto_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import pytest
22 | import logging
23 |
24 | from pytorch_transformers import (AutoConfig, BertConfig,
25 | AutoModel, BertModel,
26 | AutoModelWithLMHead, BertForMaskedLM,
27 | AutoModelForSequenceClassification, BertForSequenceClassification,
28 | AutoModelForQuestionAnswering, BertForQuestionAnswering)
29 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
30 |
31 | from .modeling_common_test import (CommonTestCases, ids_tensor)
32 | from .configuration_common_test import ConfigTester
33 |
34 |
35 | class AutoModelTest(unittest.TestCase):
36 | def test_model_from_pretrained(self):
37 | logging.basicConfig(level=logging.INFO)
38 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
39 | config = AutoConfig.from_pretrained(model_name)
40 | self.assertIsNotNone(config)
41 | self.assertIsInstance(config, BertConfig)
42 |
43 | model = AutoModel.from_pretrained(model_name)
44 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
45 | self.assertIsNotNone(model)
46 | self.assertIsInstance(model, BertModel)
47 | for value in loading_info.values():
48 | self.assertEqual(len(value), 0)
49 |
50 | def test_lmhead_model_from_pretrained(self):
51 | logging.basicConfig(level=logging.INFO)
52 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
53 | config = AutoConfig.from_pretrained(model_name)
54 | self.assertIsNotNone(config)
55 | self.assertIsInstance(config, BertConfig)
56 |
57 | model = AutoModelWithLMHead.from_pretrained(model_name)
58 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True)
59 | self.assertIsNotNone(model)
60 | self.assertIsInstance(model, BertForMaskedLM)
61 |
62 | def test_sequence_classification_model_from_pretrained(self):
63 | logging.basicConfig(level=logging.INFO)
64 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
65 | config = AutoConfig.from_pretrained(model_name)
66 | self.assertIsNotNone(config)
67 | self.assertIsInstance(config, BertConfig)
68 |
69 | model = AutoModelForSequenceClassification.from_pretrained(model_name)
70 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True)
71 | self.assertIsNotNone(model)
72 | self.assertIsInstance(model, BertForSequenceClassification)
73 |
74 | def test_question_answering_model_from_pretrained(self):
75 | logging.basicConfig(level=logging.INFO)
76 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
77 | config = AutoConfig.from_pretrained(model_name)
78 | self.assertIsNotNone(config)
79 | self.assertIsInstance(config, BertConfig)
80 |
81 | model = AutoModelForQuestionAnswering.from_pretrained(model_name)
82 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True)
83 | self.assertIsNotNone(model)
84 | self.assertIsInstance(model, BertForQuestionAnswering)
85 |
86 |
87 | if __name__ == "__main__":
88 | unittest.main()
89 |
--------------------------------------------------------------------------------
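
Outside the test suite, the Auto classes exercised here are used the same way: the concrete config, tokenizer, and model classes are resolved from the checkpoint name. A minimal sketch, assuming network access and the public `bert-base-uncased` checkpoint:

import torch
from pytorch_transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained('bert-base-uncased')        # resolves to BertConfig
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # resolves to BertTokenizer
model = AutoModel.from_pretrained('bert-base-uncased')          # resolves to BertModel

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # (1, sequence_length, hidden_size)
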
/pytorch_transformers/tests/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import os
21 |
22 | import torch
23 |
24 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
25 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
26 |
27 | from .tokenization_tests_commons import TemporaryDirectory
28 |
29 |
30 | def unwrap_schedule(scheduler, num_steps=10):
31 | lrs = []
32 | for _ in range(num_steps):
33 | scheduler.step()
34 | lrs.append(scheduler.get_lr())
35 | return lrs
36 |
37 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
38 | lrs = []
39 | for step in range(num_steps):
40 | scheduler.step()
41 | lrs.append(scheduler.get_lr())
42 | if step == num_steps // 2:
43 | with TemporaryDirectory() as tmpdirname:
44 | file_name = os.path.join(tmpdirname, 'schedule.bin')
45 | torch.save(scheduler.state_dict(), file_name)
46 |
47 | state_dict = torch.load(file_name)
48 | scheduler.load_state_dict(state_dict)
49 | return lrs
50 |
51 | class OptimizationTest(unittest.TestCase):
52 |
53 | def assertListAlmostEqual(self, list1, list2, tol):
54 | self.assertEqual(len(list1), len(list2))
55 | for a, b in zip(list1, list2):
56 | self.assertAlmostEqual(a, b, delta=tol)
57 |
58 | def test_adam_w(self):
59 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
60 | target = torch.tensor([0.4, 0.2, -0.5])
61 | criterion = torch.nn.MSELoss()
62 | # No warmup, constant schedule, no gradient clipping
63 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
64 | for _ in range(100):
65 | loss = criterion(w, target)
66 | loss.backward()
67 | optimizer.step()
68 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
69 | w.grad.zero_()
70 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
71 |
72 |
73 | class ScheduleInitTest(unittest.TestCase):
74 | m = torch.nn.Linear(50, 50)
75 | optimizer = AdamW(m.parameters(), lr=10.)
76 | num_steps = 10
77 |
78 | def assertListAlmostEqual(self, list1, list2, tol):
79 | self.assertEqual(len(list1), len(list2))
80 | for a, b in zip(list1, list2):
81 | self.assertAlmostEqual(a, b, delta=tol)
82 |
83 | def test_constant_scheduler(self):
84 | scheduler = ConstantLRSchedule(self.optimizer)
85 | lrs = unwrap_schedule(scheduler, self.num_steps)
86 | expected_learning_rates = [10.] * self.num_steps
87 | self.assertEqual(len(lrs[0]), 1)
88 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
89 |
90 | scheduler = ConstantLRSchedule(self.optimizer)
91 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
92 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
93 |
94 | def test_warmup_constant_scheduler(self):
95 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
96 | lrs = unwrap_schedule(scheduler, self.num_steps)
97 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
98 | self.assertEqual(len(lrs[0]), 1)
99 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
100 |
101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
102 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
103 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
104 |
105 | def test_warmup_linear_scheduler(self):
106 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
107 | lrs = unwrap_schedule(scheduler, self.num_steps)
108 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
109 | self.assertEqual(len(lrs[0]), 1)
110 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
111 |
112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
113 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
114 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
115 |
116 | def test_warmup_cosine_scheduler(self):
117 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
118 | lrs = unwrap_schedule(scheduler, self.num_steps)
119 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
120 | self.assertEqual(len(lrs[0]), 1)
121 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
122 |
123 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
124 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
125 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
126 |
127 | def test_warmup_cosine_hard_restart_scheduler(self):
128 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
129 | lrs = unwrap_schedule(scheduler, self.num_steps)
130 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
131 | self.assertEqual(len(lrs[0]), 1)
132 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
133 |
134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
135 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
136 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
137 |
138 | if __name__ == "__main__":
139 | unittest.main()
140 |
--------------------------------------------------------------------------------
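
The expected values in test_warmup_linear_scheduler above can be reproduced by hand from the schedule's rule (linear ramp over warmup_steps, then linear decay to zero at t_total). The helper below is a standalone sketch of that rule, not library code:

base_lr, warmup_steps, t_total = 10.0, 2, 10

def warmup_linear_factor(step):
    # Linear warmup for the first `warmup_steps` updates, then linear decay to 0 at `t_total`.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (t_total - step) / max(1.0, t_total - warmup_steps))

expected = [round(base_lr * warmup_linear_factor(step), 2) for step in range(1, t_total + 1)]
print(expected)  # [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
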
/pytorch_transformers/tests/tokenization_auto_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import pytest
22 | import logging
23 |
24 | from pytorch_transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer
25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
26 | from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
27 |
28 |
29 | class AutoTokenizerTest(unittest.TestCase):
30 | def test_tokenizer_from_pretrained(self):
31 | logging.basicConfig(level=logging.INFO)
32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
33 | tokenizer = AutoTokenizer.from_pretrained(model_name)
34 | self.assertIsNotNone(tokenizer)
35 | self.assertIsInstance(tokenizer, BertTokenizer)
36 | self.assertGreater(len(tokenizer), 0)
37 |
38 | for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
39 | tokenizer = AutoTokenizer.from_pretrained(model_name)
40 | self.assertIsNotNone(tokenizer)
41 | self.assertIsInstance(tokenizer, GPT2Tokenizer)
42 | self.assertGreater(len(tokenizer), 0)
43 |
44 |
45 | if __name__ == "__main__":
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_bert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer,
22 | BertTokenizer,
23 | WordpieceTokenizer,
24 | _is_control, _is_punctuation,
25 | _is_whitespace, VOCAB_FILES_NAMES)
26 |
27 | from .tokenization_tests_commons import CommonTestCases
28 |
29 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
30 |
31 | tokenizer_class = BertTokenizer
32 |
33 | def setUp(self):
34 | super(BertTokenizationTest, self).setUp()
35 |
36 | vocab_tokens = [
37 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
38 | "##ing", ",", "low", "lowest",
39 | ]
40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
41 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
42 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
43 |
44 | def get_tokenizer(self, **kwargs):
45 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
46 |
47 | def get_input_output_texts(self):
48 | input_text = u"UNwant\u00E9d,running"
49 | output_text = u"unwanted, running"
50 | return input_text, output_text
51 |
52 | def test_full_tokenizer(self):
53 | tokenizer = self.tokenizer_class(self.vocab_file)
54 |
55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
57 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
58 |
59 | def test_chinese(self):
60 | tokenizer = BasicTokenizer()
61 |
62 | self.assertListEqual(
63 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
64 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
65 |
66 | def test_basic_tokenizer_lower(self):
67 | tokenizer = BasicTokenizer(do_lower_case=True)
68 |
69 | self.assertListEqual(
70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
71 | ["hello", "!", "how", "are", "you", "?"])
72 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
73 |
74 | def test_basic_tokenizer_no_lower(self):
75 | tokenizer = BasicTokenizer(do_lower_case=False)
76 |
77 | self.assertListEqual(
78 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
79 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
80 |
81 | def test_wordpiece_tokenizer(self):
82 | vocab_tokens = [
83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
84 | "##ing"
85 | ]
86 |
87 | vocab = {}
88 | for (i, token) in enumerate(vocab_tokens):
89 | vocab[token] = i
90 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
91 |
92 | self.assertListEqual(tokenizer.tokenize(""), [])
93 |
94 | self.assertListEqual(
95 | tokenizer.tokenize("unwanted running"),
96 | ["un", "##want", "##ed", "runn", "##ing"])
97 |
98 | self.assertListEqual(
99 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
100 |
101 | def test_is_whitespace(self):
102 | self.assertTrue(_is_whitespace(u" "))
103 | self.assertTrue(_is_whitespace(u"\t"))
104 | self.assertTrue(_is_whitespace(u"\r"))
105 | self.assertTrue(_is_whitespace(u"\n"))
106 | self.assertTrue(_is_whitespace(u"\u00A0"))
107 |
108 | self.assertFalse(_is_whitespace(u"A"))
109 | self.assertFalse(_is_whitespace(u"-"))
110 |
111 | def test_is_control(self):
112 | self.assertTrue(_is_control(u"\u0005"))
113 |
114 | self.assertFalse(_is_control(u"A"))
115 | self.assertFalse(_is_control(u" "))
116 | self.assertFalse(_is_control(u"\t"))
117 | self.assertFalse(_is_control(u"\r"))
118 |
119 | def test_is_punctuation(self):
120 | self.assertTrue(_is_punctuation(u"-"))
121 | self.assertTrue(_is_punctuation(u"$"))
122 | self.assertTrue(_is_punctuation(u"`"))
123 | self.assertTrue(_is_punctuation(u"."))
124 |
125 | self.assertFalse(_is_punctuation(u"A"))
126 | self.assertFalse(_is_punctuation(u" "))
127 |
128 | def test_sequence_builders(self):
129 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
130 |
131 | text = tokenizer.encode("sequence builders")
132 | text_2 = tokenizer.encode("multi-sequence build")
133 |
134 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
135 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
136 |
137 | assert encoded_sentence == [101] + text + [102]
138 | assert encoded_pair == [101] + text + [102] + text_2 + [102]
139 |
140 | if __name__ == '__main__':
141 | unittest.main()
142 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_dilbert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 | from .tokenization_bert_test import BertTokenizationTest
25 |
26 | class DistilBertTokenizationTest(BertTokenizationTest):
27 |
28 | tokenizer_class = DistilBertTokenizer
29 |
30 | def get_tokenizer(self, **kwargs):
31 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
32 |
33 | def test_sequence_builders(self):
34 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
35 |
36 | text = tokenizer.encode("sequence builders")
37 | text_2 = tokenizer.encode("multi-sequence build")
38 |
39 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
40 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
41 |
42 | assert encoded_sentence == [101] + text + [102]
43 | assert encoded_pair == [101] + text + [102] + text_2 + [102]
44 |
45 | if __name__ == '__main__':
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_gpt2_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 | from io import open
21 |
22 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
23 |
24 | from .tokenization_tests_commons import CommonTestCases
25 |
26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 | tokenizer_class = GPT2Tokenizer
29 |
30 | def setUp(self):
31 | super(GPT2TokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "\u0120", "\u0120l", "\u0120n",
36 | "\u0120lo", "\u0120low", "er",
37 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
40 | self.special_tokens_map = {"unk_token": "<unk>"}
41 |
42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
44 | with open(self.vocab_file, "w", encoding="utf-8") as fp:
45 | fp.write(json.dumps(vocab_tokens) + "\n")
46 | with open(self.merges_file, "w", encoding="utf-8") as fp:
47 | fp.write("\n".join(merges))
48 |
49 | def get_tokenizer(self, **kwargs):
50 | kwargs.update(self.special_tokens_map)
51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
52 |
53 | def get_input_output_texts(self):
54 | input_text = u"lower newer"
55 | output_text = u" lower newer"
56 | return input_text, output_text
57 |
58 | def test_full_tokenizer(self):
59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
60 | text = "lower newer"
61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
62 | tokens = tokenizer.tokenize(text)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + [tokenizer.unk_token]
66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 |
71 | if __name__ == '__main__':
72 | unittest.main()
73 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_openai_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 |
21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 |
26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 | tokenizer_class = OpenAIGPTTokenizer
29 |
30 | def setUp(self):
31 | super(OpenAIGPTTokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "w</w>", "r</w>", "t</w>",
36 | "lo", "low", "er</w>",
37 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
40 |
41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 | with open(self.vocab_file, "w") as fp:
44 | fp.write(json.dumps(vocab_tokens))
45 | with open(self.merges_file, "w") as fp:
46 | fp.write("\n".join(merges))
47 |
48 | def get_tokenizer(self, **kwargs):
49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 |
51 | def get_input_output_texts(self):
52 | input_text = u"lower newer"
53 | output_text = u"lower newer"
54 | return input_text, output_text
55 |
56 |
57 | def test_full_tokenizer(self):
58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
59 |
60 | text = "lower"
61 | bpe_tokens = ["low", "er</w>"]
62 | tokens = tokenizer.tokenize(text)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + ["<unk>"]
66 | input_bpe_tokens = [14, 15, 20]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 |
71 | if __name__ == '__main__':
72 | unittest.main()
73 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_roberta_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import json
19 | import unittest
20 | from io import open
21 |
22 | from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 |
26 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 | tokenizer_class = RobertaTokenizer
28 |
29 | def setUp(self):
30 | super(RobertaTokenizationTest, self).setUp()
31 |
32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
34 | "\u0120", "\u0120l", "\u0120n",
35 | "\u0120lo", "\u0120low", "er",
36 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
37 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
38 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
39 | self.special_tokens_map = {"unk_token": "<unk>"}
40 |
41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 | with open(self.vocab_file, "w", encoding="utf-8") as fp:
44 | fp.write(json.dumps(vocab_tokens) + "\n")
45 | with open(self.merges_file, "w", encoding="utf-8") as fp:
46 | fp.write("\n".join(merges))
47 |
48 | def get_tokenizer(self, **kwargs):
49 | kwargs.update(self.special_tokens_map)
50 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
51 |
52 | def get_input_output_texts(self):
53 | input_text = u"lower newer"
54 | output_text = u" lower newer"
55 | return input_text, output_text
56 |
57 | def test_full_tokenizer(self):
58 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
59 | text = "lower newer"
60 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
61 | tokens = tokenizer.tokenize(text)
62 | self.assertListEqual(tokens, bpe_tokens)
63 |
64 | input_tokens = tokens + [tokenizer.unk_token]
65 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
66 | self.assertListEqual(
67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
68 |
69 | def roberta_dict_integration_testing(self):
70 | tokenizer = self.get_tokenizer()
71 |
72 | self.assertListEqual(
73 | tokenizer.encode('Hello world!'),
74 | [0, 31414, 232, 328, 2]
75 | )
76 | self.assertListEqual(
77 | tokenizer.encode('Hello world! cécé herlolip 418'),
78 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
79 | )
80 |
81 | def test_sequence_builders(self):
82 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
83 |
84 | text = tokenizer.encode("sequence builders")
85 | text_2 = tokenizer.encode("multi-sequence build")
86 |
87 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
88 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
89 |
90 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
91 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
92 |
93 | assert encoded_sentence == encoded_text_from_decode
94 | assert encoded_pair == encoded_pair_from_decode
95 |
96 |
97 | if __name__ == '__main__':
98 | unittest.main()
99 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_tests_commons.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import sys
19 | from io import open
20 | import tempfile
21 | import shutil
22 | import unittest
23 |
24 | if sys.version_info[0] == 2:
25 | import cPickle as pickle
26 |
27 | class TemporaryDirectory(object):
28 | """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
29 | def __enter__(self):
30 | self.name = tempfile.mkdtemp()
31 | return self.name
32 | def __exit__(self, exc_type, exc_value, traceback):
33 | shutil.rmtree(self.name)
34 | else:
35 | import pickle
36 | TemporaryDirectory = tempfile.TemporaryDirectory
37 | unicode = str
38 |
39 |
40 | class CommonTestCases:
41 |
42 | class CommonTokenizerTester(unittest.TestCase):
43 |
44 | tokenizer_class = None
45 |
46 | def setUp(self):
47 | self.tmpdirname = tempfile.mkdtemp()
48 |
49 | def tearDown(self):
50 | shutil.rmtree(self.tmpdirname)
51 |
52 | def get_tokenizer(self, **kwargs):
53 | raise NotImplementedError
54 |
55 | def get_input_output_texts(self):
56 | raise NotImplementedError
57 |
58 | def test_tokenizers_common_properties(self):
59 | tokenizer = self.get_tokenizer()
60 | attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token",
61 | "pad_token", "cls_token", "mask_token"]
62 | for attr in attributes_list:
63 | self.assertTrue(hasattr(tokenizer, attr))
64 | self.assertTrue(hasattr(tokenizer, attr + "_id"))
65 |
66 | self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
67 | self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids'))
68 |
69 | attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder",
70 | "added_tokens_decoder"]
71 | for attr in attributes_list:
72 | self.assertTrue(hasattr(tokenizer, attr))
73 |
74 | def test_save_and_load_tokenizer(self):
75 | # safety check on max_len default value so we are sure the test works
76 | tokenizer = self.get_tokenizer()
77 | self.assertNotEqual(tokenizer.max_len, 42)
78 |
79 | # Now let's start the test
80 | tokenizer = self.get_tokenizer(max_len=42)
81 |
82 | before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
83 |
84 | with TemporaryDirectory() as tmpdirname:
85 | tokenizer.save_pretrained(tmpdirname)
86 | tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
87 |
88 | after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
89 | self.assertListEqual(before_tokens, after_tokens)
90 |
91 | self.assertEqual(tokenizer.max_len, 42)
92 | tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
93 | self.assertEqual(tokenizer.max_len, 43)
94 |
95 | def test_pickle_tokenizer(self):
96 | tokenizer = self.get_tokenizer()
97 | self.assertIsNotNone(tokenizer)
98 |
99 | text = u"Munich and Berlin are nice cities"
100 | subwords = tokenizer.tokenize(text)
101 |
102 | with TemporaryDirectory() as tmpdirname:
103 |
104 | filename = os.path.join(tmpdirname, u"tokenizer.bin")
105 | pickle.dump(tokenizer, open(filename, "wb"))
106 |
107 | tokenizer_new = pickle.load(open(filename, "rb"))
108 |
109 | subwords_loaded = tokenizer_new.tokenize(text)
110 |
111 | self.assertListEqual(subwords, subwords_loaded)
112 |
113 |
114 | def test_add_tokens_tokenizer(self):
115 | tokenizer = self.get_tokenizer()
116 |
117 | vocab_size = tokenizer.vocab_size
118 | all_size = len(tokenizer)
119 |
120 | self.assertNotEqual(vocab_size, 0)
121 | self.assertEqual(vocab_size, all_size)
122 |
123 | new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
124 | added_toks = tokenizer.add_tokens(new_toks)
125 | vocab_size_2 = tokenizer.vocab_size
126 | all_size_2 = len(tokenizer)
127 |
128 | self.assertNotEqual(vocab_size_2, 0)
129 | self.assertEqual(vocab_size, vocab_size_2)
130 | self.assertEqual(added_toks, len(new_toks))
131 | self.assertEqual(all_size_2, all_size + len(new_toks))
132 |
133 | tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l")
134 | out_string = tokenizer.decode(tokens)
135 |
136 | self.assertGreaterEqual(len(tokens), 4)
137 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
138 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
139 |
140 | new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
141 | 'pad_token': "<<<<<|||>|>>>>|>"}
142 | added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
143 | vocab_size_3 = tokenizer.vocab_size
144 | all_size_3 = len(tokenizer)
145 |
146 | self.assertNotEqual(vocab_size_3, 0)
147 | self.assertEqual(vocab_size, vocab_size_3)
148 | self.assertEqual(added_toks_2, len(new_toks_2))
149 | self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
150 |
151 | tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
152 | out_string = tokenizer.decode(tokens)
153 |
154 | self.assertGreaterEqual(len(tokens), 6)
155 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
156 | self.assertGreater(tokens[0], tokens[1])
157 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
158 | self.assertGreater(tokens[-2], tokens[-3])
159 | self.assertEqual(tokens[0], tokenizer.eos_token_id)
160 | self.assertEqual(tokens[-2], tokenizer.pad_token_id)
161 |
162 |
163 | def test_required_methods_tokenizer(self):
164 | tokenizer = self.get_tokenizer()
165 | input_text, output_text = self.get_input_output_texts()
166 |
167 | tokens = tokenizer.tokenize(input_text)
168 | ids = tokenizer.convert_tokens_to_ids(tokens)
169 | ids_2 = tokenizer.encode(input_text)
170 | self.assertListEqual(ids, ids_2)
171 |
172 | tokens_2 = tokenizer.convert_ids_to_tokens(ids)
173 | text_2 = tokenizer.decode(ids)
174 |
175 | self.assertEqual(text_2, output_text)
176 |
177 | self.assertNotEqual(len(tokens_2), 0)
178 | self.assertIsInstance(text_2, (str, unicode))
179 |
180 |
181 | def test_pretrained_model_lists(self):
182 | weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
183 | weights_lists_2 = []
184 | for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
185 | weights_lists_2.append(list(map_list.keys()))
186 |
187 | for weights_list_2 in weights_lists_2:
188 | self.assertListEqual(weights_list, weights_list_2)
189 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 |
27 | tokenizer_class = TransfoXLTokenizer
28 |
29 | def setUp(self):
30 | super(TransfoXLTokenizationTest, self).setUp()
31 |
32 | vocab_tokens = [
33 | "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
34 | "running", ",", "low", "l",
35 | ]
36 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
37 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
38 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
39 |
40 | def get_tokenizer(self, **kwargs):
41 | kwargs['lower_case'] = True
42 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
43 |
44 | def get_input_output_texts(self):
45 | input_text = u" UNwanted , running"
46 | output_text = u" unwanted, running"
47 | return input_text, output_text
48 |
49 | def test_full_tokenizer(self):
50 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
51 |
52 | tokens = tokenizer.tokenize(u" UNwanted , running")
53 | self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
54 |
55 | self.assertListEqual(
56 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
57 |
58 | def test_full_tokenizer_lower(self):
59 | tokenizer = TransfoXLTokenizer(lower_case=True)
60 |
61 | self.assertListEqual(
62 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
63 | ["hello", "!", "how", "are", "you", "?"])
64 |
65 | def test_full_tokenizer_no_lower(self):
66 | tokenizer = TransfoXLTokenizer(lower_case=False)
67 |
68 | self.assertListEqual(
69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
70 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
71 |
72 |
73 | if __name__ == '__main__':
74 | unittest.main()
75 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 HuggingFace Inc..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import six
21 |
22 | from pytorch_transformers import PreTrainedTokenizer
23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
24 |
25 | class TokenizerUtilsTest(unittest.TestCase):
26 | def check_tokenizer_from_pretrained(self, tokenizer_class):
27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys())
28 | for model_name in s3_models[:1]:
29 | tokenizer = tokenizer_class.from_pretrained(model_name)
30 | self.assertIsNotNone(tokenizer)
31 | self.assertIsInstance(tokenizer, tokenizer_class)
32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer)
33 |
34 | for special_tok in tokenizer.all_special_tokens:
35 | if six.PY2:
36 | self.assertIsInstance(special_tok, unicode)
37 | else:
38 | self.assertIsInstance(special_tok, str)
39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
40 | self.assertIsInstance(special_tok_id, int)
41 |
42 | def test_pretrained_tokenizers(self):
43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer)
44 |
45 | if __name__ == "__main__":
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_xlm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 |
21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 |
27 | tokenizer_class = XLMTokenizer
28 |
29 | def setUp(self):
30 | super(XLMTokenizationTest, self).setUp()
31 |
32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
34 | "w</w>", "r</w>", "t</w>",
35 | "lo", "low", "er</w>",
36 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
37 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
38 | merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
39 |
40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
41 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
42 | with open(self.vocab_file, "w") as fp:
43 | fp.write(json.dumps(vocab_tokens))
44 | with open(self.merges_file, "w") as fp:
45 | fp.write("\n".join(merges))
46 |
47 | def get_tokenizer(self, **kwargs):
48 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
49 |
50 | def get_input_output_texts(self):
51 | input_text = u"lower newer"
52 | output_text = u"lower newer"
53 | return input_text, output_text
54 |
55 | def test_full_tokenizer(self):
56 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
57 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
58 |
59 | text = "lower"
60 | bpe_tokens = ["low", "er</w>"]
61 | tokens = tokenizer.tokenize(text)
62 | self.assertListEqual(tokens, bpe_tokens)
63 |
64 | input_tokens = tokens + ["<unk>"]
65 | input_bpe_tokens = [14, 15, 20]
66 | self.assertListEqual(
67 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
68 |
69 | def test_sequence_builders(self):
70 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
71 |
72 | text = tokenizer.encode("sequence builders")
73 | text_2 = tokenizer.encode("multi-sequence build")
74 |
75 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
76 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
77 |
78 | assert encoded_sentence == [1] + text + [1]
79 | assert encoded_pair == [1] + text + [1] + text_2 + [1]
80 |
81 | if __name__ == '__main__':
82 | unittest.main()
83 |
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_xlnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 |
20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
21 |
22 | from .tokenization_tests_commons import CommonTestCases
23 |
24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
25 | 'fixtures/test_sentencepiece.model')
26 |
27 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
28 |
29 | tokenizer_class = XLNetTokenizer
30 |
31 | def setUp(self):
32 | super(XLNetTokenizationTest, self).setUp()
33 |
34 | # We have a SentencePiece fixture for testing
35 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
36 | tokenizer.save_pretrained(self.tmpdirname)
37 |
38 | def get_tokenizer(self, **kwargs):
39 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
40 |
41 | def get_input_output_texts(self):
42 | input_text = u"This is a test"
43 | output_text = u"This is a test"
44 | return input_text, output_text
45 |
46 |
47 | def test_full_tokenizer(self):
48 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
49 |
50 | tokens = tokenizer.tokenize(u'This is a test')
51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
52 |
53 | self.assertListEqual(
54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
55 |
56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
61 | ids = tokenizer.convert_tokens_to_ids(tokens)
62 | self.assertListEqual(
63 | ids, [8, 21, 84, 55, 24, 19, 7, 0,
64 | 602, 347, 347, 347, 3, 12, 66,
65 | 46, 72, 80, 6, 0, 4])
66 |
67 | back_tokens = tokenizer.convert_ids_to_tokens(ids)
68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
69 | u'or', u'n', SPIECE_UNDERLINE + u'in',
70 | SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
73 | u'<unk>', u'.'])
74 |
75 | def test_tokenizer_lower(self):
76 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
77 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
78 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
79 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
80 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
81 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
83 |
84 | def test_tokenizer_no_lower(self):
85 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
86 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
87 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
88 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
89 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
90 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
91 |
92 | def test_sequence_builders(self):
93 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
94 |
95 | text = tokenizer.encode("sequence builders")
96 | text_2 = tokenizer.encode("multi-sequence build")
97 |
98 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
99 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
100 |
101 | assert encoded_sentence == text + [4, 3]
102 | assert encoded_pair == text + [4] + text_2 + [4, 3]
103 |
104 |
105 | if __name__ == '__main__':
106 | unittest.main()
107 |
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_auto.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Auto Model class. """
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import logging
20 |
21 | from .tokenization_bert import BertTokenizer
22 | from .tokenization_openai import OpenAIGPTTokenizer
23 | from .tokenization_gpt2 import GPT2Tokenizer
24 | from .tokenization_transfo_xl import TransfoXLTokenizer
25 | from .tokenization_xlnet import XLNetTokenizer
26 | from .tokenization_xlm import XLMTokenizer
27 | from .tokenization_roberta import RobertaTokenizer
28 | from .tokenization_distilbert import DistilBertTokenizer
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 | class AutoTokenizer(object):
33 | r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
34 | that will be instantiated as one of the tokenizer classes of the library
35 | when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
36 | class method.
37 |
38 | The `from_pretrained()` method takes care of returning the correct tokenizer class instance
39 | using pattern matching on the `pretrained_model_name_or_path` string.
40 |
41 | The tokenizer class to instantiate is selected as the first pattern matching
42 | in the `pretrained_model_name_or_path` string (in the following order):
43 | - contains `distilbert`: DistilBertTokenizer (DistilBert model)
44 | - contains `roberta`: RobertaTokenizer (RoBERTa model)
45 | - contains `bert`: BertTokenizer (Bert model)
46 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
47 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
48 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
49 | - contains `xlnet`: XLNetTokenizer (XLNet model)
50 | - contains `xlm`: XLMTokenizer (XLM model)
51 |
52 | This class cannot be instantiated using `__init__()` (doing so throws an error).
53 | """
54 | def __init__(self):
55 | raise EnvironmentError("AutoTokenizer is designed to be instantiated "
56 | "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")
57 |
58 | @classmethod
59 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
60 | r""" Instantiate a one of the tokenizer classes of the library
61 | from a pre-trained model vocabulary.
62 |
63 | The tokenizer class to instantiate is selected as the first pattern matching
64 | in the `pretrained_model_name_or_path` string (in the following order):
65 | - contains `distilbert`: DistilBertTokenizer (DistilBert model)
66 | - contains `roberta`: RobertaTokenizer (RoBERTa model)
67 | - contains `bert`: BertTokenizer (Bert model)
68 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
69 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
70 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
71 | - contains `xlnet`: XLNetTokenizer (XLNet model)
72 | - contains `xlm`: XLMTokenizer (XLM model)
73 |
74 | Params:
75 | pretrained_model_name_or_path: either:
76 |
77 | - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
78 | - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
79 | - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
80 |
81 | cache_dir: (`optional`) string:
82 | Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
83 |
84 | force_download: (`optional`) boolean, default False:
85 | Force a (re-)download of the vocabulary files and override the cached versions if they exist.
86 |
87 | proxies: (`optional`) dict, default None:
88 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
89 | The proxies are used on each request.
90 |
91 | inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
92 |
93 | kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
94 |
95 | Examples::
96 |
97 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache.
98 | tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
99 |
100 | """
101 | if 'distilbert' in pretrained_model_name_or_path:
102 | return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
103 | elif 'roberta' in pretrained_model_name_or_path:
104 | return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
105 | elif 'bert' in pretrained_model_name_or_path:
106 | return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
107 | elif 'openai-gpt' in pretrained_model_name_or_path:
108 | return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
109 | elif 'gpt2' in pretrained_model_name_or_path:
110 | return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
111 | elif 'transfo-xl' in pretrained_model_name_or_path:
112 | return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
113 | elif 'xlnet' in pretrained_model_name_or_path:
114 | return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
115 | elif 'xlm' in pretrained_model_name_or_path:
116 | return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
117 |
118 | raise ValueError("Unrecognized model identifier in {}. Should contains one of "
119 | "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
120 | "'xlm', 'roberta'".format(pretrained_model_name_or_path))
121 |
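122 | # Illustrative sketch (assumes the 'xlnet-base-cased' vocabulary can be
123 | # downloaded or is already cached): the name-based dispatch above resolves a
124 | # checkpoint name to a concrete tokenizer class.
125 | # Run with `python -m pytorch_transformers.tokenization_auto`.
126 | if __name__ == '__main__':
127 |     # 'xlnet-base-cased' contains the substring 'xlnet', so the dispatch
128 |     # returns an XLNetTokenizer instance.
129 |     tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased')
130 |     assert isinstance(tokenizer, XLNetTokenizer)
131 |     print(tokenizer.tokenize(u"Hello world"))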
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_distilbert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for DistilBERT."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .tokenization_bert import BertTokenizer
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {
32 | 'vocab_file':
33 | {
34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
36 | }
37 | }
38 |
39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
40 | 'distilbert-base-uncased': 512,
41 | 'distilbert-base-uncased-distilled-squad': 512,
42 | }
43 |
44 |
45 | class DistilBertTokenizer(BertTokenizer):
46 | r"""
47 | Constructs a DistilBertTokenizer.
48 | :class:`~pytorch_transformers.DistilBertTokenizer` is identical to :class:`~pytorch_transformers.BertTokenizer` and runs end-to-end tokenization: punctuation splitting + wordpiece.
49 |
50 | Args:
51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
54 | max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is always the
55 | minimum of this value (if specified) and the underlying BERT model's sequence length.
56 | never_split: List of tokens which will never be split during tokenization. Only has an effect when
57 | do_wordpiece_only=False
58 | """
59 |
60 | vocab_files_names = VOCAB_FILES_NAMES
61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
63 |
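64 | # Illustrative sketch (assumes the 'distilbert-base-uncased' vocabulary can be
65 | # downloaded or is cached): because DistilBertTokenizer reuses the BertTokenizer
66 | # machinery unchanged, the usual tokenize / convert_tokens_to_ids workflow applies.
67 | # Run with `python -m pytorch_transformers.tokenization_distilbert`.
68 | if __name__ == '__main__':
69 |     tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
70 |     tokens = tokenizer.tokenize(u"Hello, how are you?")  # wordpiece tokens
71 |     ids = tokenizer.convert_tokens_to_ids(tokens)         # vocabulary ids
72 |     print(tokens, ids)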
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import re
23 | from io import open
24 |
25 | from .tokenization_utils import PreTrainedTokenizer
26 | from .tokenization_bert import BasicTokenizer
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | VOCAB_FILES_NAMES = {
31 | 'vocab_file': 'vocab.json',
32 | 'merges_file': 'merges.txt',
33 | }
34 |
35 | PRETRAINED_VOCAB_FILES_MAP = {
36 | 'vocab_file':
37 | {
38 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
39 | },
40 | 'merges_file':
41 | {
42 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
43 | },
44 | }
45 |
46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
47 | 'openai-gpt': 512,
48 | }
49 |
50 | def get_pairs(word):
51 | """
52 | Return set of symbol pairs in a word.
53 | word is represented as tuple of symbols (symbols being variable-length strings)
54 | """
55 | pairs = set()
56 | prev_char = word[0]
57 | for char in word[1:]:
58 | pairs.add((prev_char, char))
59 | prev_char = char
60 | return pairs
61 |
62 | def text_standardize(text):
63 | """
64 | Fixes some issues the spacy tokenizer had with the BooksCorpus dataset;
65 | also does some whitespace standardization.
66 | """
67 | text = text.replace('—', '-')
68 | text = text.replace('–', '-')
69 | text = text.replace('―', '-')
70 | text = text.replace('…', '...')
71 | text = text.replace('´', "'")
72 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
73 | text = re.sub(r'\s*\n\s*', ' \n ', text)
74 | text = re.sub(r'[^\S\n]+', ' ', text)
75 | return text.strip()
76 |
77 | class OpenAIGPTTokenizer(PreTrainedTokenizer):
78 | """
79 | BPE tokenizer. Peculiarities:
80 | - lower-cases all inputs
81 | - uses the SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
82 | """
83 | vocab_files_names = VOCAB_FILES_NAMES
84 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
85 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
86 |
87 | def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
88 | super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
89 |
90 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
91 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
92 |
93 | try:
94 | import ftfy
95 | from spacy.lang.en import English
96 | _nlp = English()
97 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
98 | self.fix_text = ftfy.fix_text
99 | except ImportError:
100 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
101 | self.nlp = BasicTokenizer(do_lower_case=True)
102 | self.fix_text = None
103 |
104 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
105 | self.decoder = {v:k for k,v in self.encoder.items()}
106 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
107 | merges = [tuple(merge.split()) for merge in merges]
108 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
109 | self.cache = {}
110 |
111 | @property
112 | def vocab_size(self):
113 | return len(self.encoder)
114 |
115 | def bpe(self, token):
116 | word = tuple(token[:-1]) + (token[-1] + '</w>',)
117 | if token in self.cache:
118 | return self.cache[token]
119 | pairs = get_pairs(word)
120 |
121 | if not pairs:
122 | return token+'</w>'
123 |
124 | while True:
125 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
126 | if bigram not in self.bpe_ranks:
127 | break
128 | first, second = bigram
129 | new_word = []
130 | i = 0
131 | while i < len(word):
132 | try:
133 | j = word.index(first, i)
134 | new_word.extend(word[i:j])
135 | i = j
136 | except ValueError:  # `first` no longer occurs in the remaining symbols
137 | new_word.extend(word[i:])
138 | break
139 |
140 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
141 | new_word.append(first+second)
142 | i += 2
143 | else:
144 | new_word.append(word[i])
145 | i += 1
146 | new_word = tuple(new_word)
147 | word = new_word
148 | if len(word) == 1:
149 | break
150 | else:
151 | pairs = get_pairs(word)
152 | word = ' '.join(word)
153 | if word == '\n  </w>':
154 | word = '\n</w>'
155 | self.cache[token] = word
156 | return word
157 |
158 | def _tokenize(self, text):
159 | """ Tokenize a string. """
160 | split_tokens = []
161 | if self.fix_text is None:
162 | # Using BERT's BasicTokenizer
163 | text = self.nlp.tokenize(text)
164 | for token in text:
165 | split_tokens.extend([t for t in self.bpe(token).split(' ')])
166 | else:
167 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
168 | text = self.nlp(text_standardize(self.fix_text(text)))
169 | for token in text:
170 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
171 | return split_tokens
172 |
173 | def _convert_token_to_id(self, token):
174 | """ Converts a token (str/unicode) in an id using the vocab. """
175 | return self.encoder.get(token, self.encoder.get(self.unk_token))
176 |
177 | def _convert_id_to_token(self, index):
178 | """Converts an id in a token (BPE) using the vocab."""
179 | return self.decoder.get(index, self.unk_token)
180 |
181 | def convert_tokens_to_string(self, tokens):
182 | """ Converts a sequence of tokens (string) in a single string. """
183 | out_string = ''.join(tokens).replace('</w>', ' ').strip()
184 | return out_string
185 |
186 | def save_vocabulary(self, save_directory):
187 | """Save the tokenizer vocabulary and merge files to a directory."""
188 | if not os.path.isdir(save_directory):
189 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
190 | return
191 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
192 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
193 |
194 | with open(vocab_file, 'w', encoding='utf-8') as f:
195 | f.write(json.dumps(self.encoder, ensure_ascii=False))
196 |
197 | index = 0
198 | with open(merge_file, "w", encoding="utf-8") as writer:
199 | writer.write(u'#version: 0.2\n')
200 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
201 | if index != token_index:
202 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
203 | " Please check that the tokenizer is not corrupted!".format(merge_file))
204 | index = token_index
205 | writer.write(' '.join(bpe_tokens) + u'\n')
206 | index += 1
207 |
208 | return vocab_file, merge_file
209 |
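210 | # Self-contained sketch of the helpers above (no pretrained files needed):
211 | # `text_standardize` normalizes dashes/ellipses/whitespace, and `get_pairs`
212 | # enumerates the adjacent symbol pairs that the BPE loop ranks and merges.
213 | # Run with `python -m pytorch_transformers.tokenization_openai`.
214 | if __name__ == '__main__':
215 |     print(text_standardize(u"Hello\u2014world \u2026 ok"))  # -> 'Hello - world ... ok'
216 |     word = tuple("lower"[:-1]) + ("r</w>",)  # symbols of "lower" with the end-of-word marker
217 |     print(get_pairs(word))  # e.g. {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}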
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for RoBERTa."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import sys
20 | import json
21 | import logging
22 | import os
23 | import regex as re
24 | from io import open
25 |
26 | from .tokenization_gpt2 import GPT2Tokenizer
27 |
28 | try:
29 | from functools import lru_cache
30 | except ImportError:
31 | # Just a dummy decorator to get the checks to run on python2
32 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
33 | def lru_cache():
34 | return lambda func: func
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 | VOCAB_FILES_NAMES = {
39 | 'vocab_file': 'vocab.json',
40 | 'merges_file': 'merges.txt',
41 | }
42 |
43 | PRETRAINED_VOCAB_FILES_MAP = {
44 | 'vocab_file':
45 | {
46 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
47 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
48 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
49 | },
50 | 'merges_file':
51 | {
52 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
53 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
54 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
55 | },
56 | }
57 |
58 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
59 | 'roberta-base': 512,
60 | 'roberta-large': 512,
61 | 'roberta-large-mnli': 512,
62 | }
63 |
64 |
65 | class RobertaTokenizer(GPT2Tokenizer):
66 | """
67 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities:
68 | - Byte-level Byte-Pair-Encoding
69 | - Requires a space to start the input string => will add a space if there isn't one.
70 | As a consequence, this tokenizer's `encode` and `decode` methods will not preserve
71 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) == " Hello"`
72 | """
73 | vocab_files_names = VOCAB_FILES_NAMES
74 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
75 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
76 |
77 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
78 | cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs):
79 | super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors,
80 | bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
81 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
82 | mask_token=mask_token, **kwargs)
83 |
84 | def add_special_tokens_single_sentence(self, token_ids):
85 | """
86 | Adds special tokens to a sequence for sequence classification tasks.
87 | A RoBERTa sequence has the following format: <s> X </s>
88 | """
89 | return [self.cls_token_id] + token_ids + [self.sep_token_id]
90 |
91 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
92 | """
93 | Adds special tokens to a sequence pair for sequence classification tasks.
94 | A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
95 | """
96 | sep = [self.sep_token_id]
97 | cls = [self.cls_token_id]
98 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep
99 |
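100 | # Illustrative sketch (assumes the 'roberta-base' vocabulary/merge files can be
101 | # downloaded or are cached) of the sequence formats implemented above:
102 | # single sequences become `<s> X </s>`, pairs become `<s> A </s></s> B </s>`.
103 | # Run with `python -m pytorch_transformers.tokenization_roberta`.
104 | if __name__ == '__main__':
105 |     tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
106 |     ids_a = tokenizer.encode(u"sequence builders")
107 |     ids_b = tokenizer.encode(u"multi-sequence build")
108 |     single = tokenizer.add_special_tokens_single_sentence(ids_a)
109 |     pair = tokenizer.add_special_tokens_sentences_pair(ids_a, ids_b)
110 |     assert single == [tokenizer.cls_token_id] + ids_a + [tokenizer.sep_token_id]
111 |     assert pair == [tokenizer.cls_token_id] + ids_a + [tokenizer.sep_token_id] * 2 + ids_b + [tokenizer.sep_token_id]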
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_xlnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Tokenization classes for XLNet model."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import logging
20 | import os
21 | from shutil import copyfile
22 |
23 | import unicodedata
24 | import six
25 |
26 | from .tokenization_utils import PreTrainedTokenizer
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
31 |
32 | PRETRAINED_VOCAB_FILES_MAP = {
33 | 'vocab_file':
34 | {
35 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
36 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
37 | }
38 | }
39 |
40 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
41 | 'xlnet-base-cased': None,
42 | 'xlnet-large-cased': None,
43 | }
44 |
45 | SPIECE_UNDERLINE = u'▁'
46 |
47 | # Segments (not really needed)
48 | SEG_ID_A = 0
49 | SEG_ID_B = 1
50 | SEG_ID_CLS = 2
51 | SEG_ID_SEP = 3
52 | SEG_ID_PAD = 4
53 |
54 | class XLNetTokenizer(PreTrainedTokenizer):
55 | """
56 | SentencePiece based tokenizer. Peculiarities:
57 |
58 | - requires `SentencePiece <https://github.com/google/sentencepiece>`_
59 | """
60 | vocab_files_names = VOCAB_FILES_NAMES
61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
63 |
64 | def __init__(self, vocab_file,
65 | do_lower_case=False, remove_space=True, keep_accents=False,
66 | bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
67 | pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
68 | additional_special_tokens=["<eop>", "<eod>"], **kwargs):
69 | super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
70 | unk_token=unk_token, sep_token=sep_token,
71 | pad_token=pad_token, cls_token=cls_token,
72 | mask_token=mask_token, additional_special_tokens=
73 | additional_special_tokens, **kwargs)
74 |
75 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
76 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
77 |
78 | try:
79 | import sentencepiece as spm
80 | except ImportError:
81 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
82 | "pip install sentencepiece")
83 |
84 | self.do_lower_case = do_lower_case
85 | self.remove_space = remove_space
86 | self.keep_accents = keep_accents
87 | self.vocab_file = vocab_file
88 |
89 | self.sp_model = spm.SentencePieceProcessor()
90 | self.sp_model.Load(vocab_file)
91 |
92 | @property
93 | def vocab_size(self):
94 | return len(self.sp_model)
95 |
96 | def __getstate__(self):
97 | state = self.__dict__.copy()
98 | state["sp_model"] = None
99 | return state
100 |
101 | def __setstate__(self, d):
102 | self.__dict__ = d
103 | try:
104 | import sentencepiece as spm
105 | except ImportError:
106 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
107 | "pip install sentencepiece")
108 | self.sp_model = spm.SentencePieceProcessor()
109 | self.sp_model.Load(self.vocab_file)
110 |
111 | def preprocess_text(self, inputs):
112 | if self.remove_space:
113 | outputs = ' '.join(inputs.strip().split())
114 | else:
115 | outputs = inputs
116 | outputs = outputs.replace("``", '"').replace("''", '"')
117 |
118 | if six.PY2 and isinstance(outputs, str):
119 | outputs = outputs.decode('utf-8')
120 |
121 | if not self.keep_accents:
122 | outputs = unicodedata.normalize('NFKD', outputs)
123 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
124 | if self.do_lower_case:
125 | outputs = outputs.lower()
126 |
127 | return outputs
128 |
129 | def _tokenize(self, text, return_unicode=True, sample=False):
130 | """ Tokenize a string.
131 | return_unicode is used only for py2
132 | """
133 | text = self.preprocess_text(text)
134 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2
135 | if six.PY2 and isinstance(text, unicode):
136 | text = text.encode('utf-8')
137 |
138 | if not sample:
139 | pieces = self.sp_model.EncodeAsPieces(text)
140 | else:
141 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
142 | new_pieces = []
143 | for piece in pieces:
144 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
145 | cur_pieces = self.sp_model.EncodeAsPieces(
146 | piece[:-1].replace(SPIECE_UNDERLINE, ''))
147 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
148 | if len(cur_pieces[0]) == 1:
149 | cur_pieces = cur_pieces[1:]
150 | else:
151 | cur_pieces[0] = cur_pieces[0][1:]
152 | cur_pieces.append(piece[-1])
153 | new_pieces.extend(cur_pieces)
154 | else:
155 | new_pieces.append(piece)
156 |
157 | # note(zhiliny): convert back to unicode for py2
158 | if six.PY2 and return_unicode:
159 | ret_pieces = []
160 | for piece in new_pieces:
161 | if isinstance(piece, str):
162 | piece = piece.decode('utf-8')
163 | ret_pieces.append(piece)
164 | new_pieces = ret_pieces
165 |
166 | return new_pieces
167 |
168 | def _convert_token_to_id(self, token):
169 | """ Converts a token (str/unicode) in an id using the vocab. """
170 | return self.sp_model.PieceToId(token)
171 |
172 | def _convert_id_to_token(self, index, return_unicode=True):
173 | """Converts an index (integer) in a token (string/unicode) using the vocab."""
174 | token = self.sp_model.IdToPiece(index)
175 | if six.PY2 and return_unicode and isinstance(token, str):
176 | token = token.decode('utf-8')
177 | return token
178 |
179 | def convert_tokens_to_string(self, tokens):
180 | """Converts a sequence of tokens (strings for sub-words) in a single string."""
181 | out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
182 | return out_string
183 |
184 | def add_special_tokens_single_sentence(self, token_ids):
185 | """
186 | Adds special tokens to a sequence for sequence classification tasks.
187 | An XLNet sequence has the following format: X [SEP][CLS]
188 | """
189 | sep = [self.sep_token_id]
190 | cls = [self.cls_token_id]
191 | return token_ids + sep + cls
192 |
193 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
194 | """
195 | Adds special tokens to a sequence pair for sequence classification tasks.
196 | An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
197 | """
198 | sep = [self.sep_token_id]
199 | cls = [self.cls_token_id]
200 | return token_ids_0 + sep + token_ids_1 + sep + cls
201 |
202 | def save_vocabulary(self, save_directory):
203 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file
204 | to a directory.
205 | """
206 | if not os.path.isdir(save_directory):
207 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
208 | return
209 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
210 |
211 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212 | copyfile(self.vocab_file, out_vocab_file)
213 |
214 | return (out_vocab_file,)
215 |
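216 | # Illustrative sketch (assumes the 'xlnet-base-cased' SentencePiece model can be
217 | # downloaded or is cached): tokenize() splits the preprocessed text into
218 | # SentencePiece pieces, and the sequence builders append [SEP]/[CLS] at the end,
219 | # as described in the docstrings above.
220 | # Run with `python -m pytorch_transformers.tokenization_xlnet`.
221 | if __name__ == '__main__':
222 |     tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
223 |     pieces = tokenizer.tokenize(u"Hello, how are you?")
224 |     ids = tokenizer.convert_tokens_to_ids(pieces)
225 |     single = tokenizer.add_special_tokens_single_sentence(ids)
226 |     assert single[-2:] == [tokenizer.sep_token_id, tokenizer.cls_token_id]
227 |     print(pieces)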
--------------------------------------------------------------------------------