├── .idea
│   ├── $PRODUCT_WORKSPACE_FILE$
│   ├── Bruce-Bert-Text-Classification.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── README.md
├── THUCNews
│   └── data
│       ├── class.txt
│       ├── dev.txt
│       ├── test.txt
│       └── train.txt
├── __pycache__
│   ├── train.cpython-37.pyc
│   └── utils.cpython-37.pyc
├── bert_pretrain
│   └── README.md
├── main.py
├── models
│   ├── Bert.py
│   └── __pycache__
│       └── Bert.cpython-37.pyc
├── pytorch_pretrained
│   ├── __init__.py
│   ├── __main__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── file_utils.cpython-37.pyc
│   │   ├── modeling.cpython-37.pyc
│   │   ├── modeling_gpt2.cpython-37.pyc
│   │   ├── modeling_openai.cpython-37.pyc
│   │   ├── modeling_transfo_xl.cpython-37.pyc
│   │   ├── modeling_transfo_xl_utilities.cpython-37.pyc
│   │   ├── optimization.cpython-37.pyc
│   │   ├── optimization_openai.cpython-37.pyc
│   │   ├── tokenization.cpython-37.pyc
│   │   ├── tokenization_gpt2.cpython-37.pyc
│   │   ├── tokenization_openai.cpython-37.pyc
│   │   └── tokenization_transfo_xl.cpython-37.pyc
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── modeling.py
│   ├── modeling_gpt2.py
│   ├── modeling_openai.py
│   ├── modeling_transfo_xl.py
│   ├── modeling_transfo_xl_utilities.py
│   ├── optimization.py
│   ├── optimization_openai.py
│   ├── tokenization.py
│   ├── tokenization_gpt2.py
│   ├── tokenization_openai.py
│   └── tokenization_transfo_xl.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | BERT Chinese text classification, based on PyTorch
2 | PyTorch version 1.1
3 | Python 3.7
4 |
--------------------------------------------------------------------------------
/THUCNews/data/class.txt:
--------------------------------------------------------------------------------
1 | finance
2 | realty
3 | stocks
4 | education
5 | science
6 | society
7 | politics
8 | sports
9 | game
10 | entertainment
--------------------------------------------------------------------------------
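
class.txt holds one label name per line; models/Bert.py reads it with readlines() and the zero-based line index serves as the numeric label. A minimal sketch of that mapping, assuming it is run from the repository root:

    # Map between numeric label ids and the THUCNews class names.
    class_list = [x.strip() for x in open('THUCNews/data/class.txt', encoding='utf-8').readlines()]
    label_to_name = dict(enumerate(class_list))            # e.g. 0 -> 'finance', 9 -> 'entertainment'
    name_to_label = {name: i for i, name in enumerate(class_list)}
    print(label_to_name[3])                                 # 'education'
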
/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_pretrain/README.md:
--------------------------------------------------------------------------------
1 | ## Place the BERT pre-trained model files here:
2 | pytorch_model.bin
3 | bert_config.json
4 | vocab.txt
5 |
6 | ## Download link:
7 | https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
--------------------------------------------------------------------------------
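
A minimal smoke test for this directory, assuming pytorch_model.bin, bert_config.json and vocab.txt have been extracted into ./bert_pretrain exactly as listed above; it mirrors how models/Bert.py loads them with the bundled pytorch_pretrained package:

    import os
    from pytorch_pretrained import BertTokenizer, BertModel

    bert_path = './bert_pretrain'
    for name in ('pytorch_model.bin', 'bert_config.json', 'vocab.txt'):
        assert os.path.exists(os.path.join(bert_path, name)), name + ' is missing'

    tokenizer = BertTokenizer.from_pretrained(bert_path)    # reads vocab.txt
    model = BertModel.from_pretrained(bert_path)            # reads the config and weights
    print(model.config.hidden_size)                         # 768 for bert-base-chinese
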
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import torch
4 | import numpy as np
5 | from importlib import import_module
6 | import argparse
7 | import utils
8 | import train
9 |
10 | parser = argparse.ArgumentParser(description='WH-Bert-Text-Classification')
11 | parser.add_argument('--model', type=str, default='Bert', help='choose a model')
12 | args = parser.parse_args()
13 |
14 |
15 | if __name__ == '__main__':
16 |     dataset = 'THUCNews'  # dataset directory
17 |     model_name = args.model
18 |     x = import_module('models.' + model_name)
19 |     config = x.Config(dataset)
20 |     np.random.seed(1)
21 |     torch.manual_seed(1)
22 |     torch.cuda.manual_seed_all(4)
23 |     torch.backends.cudnn.deterministic = True  # make every run reproducible
24 |
25 |     start_time = time.time()
26 |     print('Loading dataset...')
27 |     train_data, dev_data, test_data = utils.build_dataset(config)
28 |     train_iter = utils.build_iterator(train_data, config)
29 |     dev_iter = utils.build_iterator(dev_data, config)
30 |     test_iter = utils.build_iterator(test_data, config)
31 |
32 |     time_dif = utils.get_time_dif(start_time)
33 |     print("Data preparation time before training:", time_dif)
34 |
35 |     # model training, evaluation and testing
36 |     model = x.Model(config).to(config.device)
37 |     train.train(config, model, train_iter, dev_iter, test_iter)
38 |
--------------------------------------------------------------------------------
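
main.py resolves the network purely by name: the flag --model Bert becomes import_module('models.Bert'), and the chosen module only has to expose a Config(dataset) class and a Model(config) class; Config.save_path in models/Bert.py (THUCNews/saved_dict/Bert.ckpt) is where the fine-tuned weights are expected to land. The dispatch, written out as a sketch that assumes the data files and the ./bert_pretrain weights are in place:

    from importlib import import_module

    model_name = 'Bert'                          # value of the --model flag
    x = import_module('models.' + model_name)    # i.e. models/Bert.py
    config = x.Config('THUCNews')                # every model module must expose Config(dataset) ...
    model = x.Model(config).to(config.device)    # ... and Model(config); main.py relies on nothing else
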
/models/Bert.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import torch.nn as nn
4 | # from pytorch_pretrained_bert import BertModel, BertTokenizer
5 | from pytorch_pretrained import BertModel, BertTokenizer
6 |
7 |
8 | class Config(object):
9 |
10 |     """Configuration parameters"""
11 |     def __init__(self, dataset):
12 |         self.model_name = 'Bert'
13 |         self.train_path = dataset + '/data/train.txt'                          # training set
14 |         self.dev_path = dataset + '/data/dev.txt'                              # validation set
15 |         self.test_path = dataset + '/data/test.txt'                            # test set
16 |         self.class_list = [x.strip() for x in open(
17 |             dataset + '/data/class.txt').readlines()]                          # class names
18 |         self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
19 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device
20 |
21 |         self.require_improvement = 1000           # stop training early if there is no improvement for 1000 batches
22 |         self.num_classes = len(self.class_list)   # number of classes
23 |         self.num_epochs = 3                       # number of epochs
24 |         self.batch_size = 128                     # mini-batch size
25 |         self.pad_size = 32                        # length each sentence is processed to (pad short, truncate long)
26 |         self.learning_rate = 5e-5                 # learning rate
27 |         self.bert_path = './bert_pretrain'
28 |         self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
29 |         self.hidden_size = 768
30 |
31 |
32 | class Model(nn.Module):
33 |
34 |     def __init__(self, config):
35 |         super(Model, self).__init__()
36 |         self.bert = BertModel.from_pretrained(config.bert_path)
37 |         for param in self.bert.parameters():
38 |             param.requires_grad = True
39 |         self.fc = nn.Linear(config.hidden_size, config.num_classes)
40 |
41 |     def forward(self, x):
42 |         context = x[0]  # input sentences (token ids)
43 |         mask = x[2]     # mask over the padding, same size as the sentence; padded positions are 0, e.g. [1, 1, 1, 1, 0, 0]
44 |         _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
45 |         out = self.fc(pooled)
46 |         return out
--------------------------------------------------------------------------------
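
The forward pass unpacks a batch tuple built by utils.build_iterator (whose implementation is not reproduced above): x[0] is read as the padded token ids and x[2] as the attention mask, while x[1] is presumably the real sequence length and is unused by this model; treating it that way below is an assumption. A shape sketch with the Config defaults (pad_size=32, 10 classes), run from the repository root with the ./bert_pretrain weights in place:

    import torch
    from models.Bert import Config, Model

    config = Config('THUCNews')
    model = Model(config).to(config.device)

    batch_size = 4
    token_ids = torch.zeros(batch_size, config.pad_size, dtype=torch.long).to(config.device)
    seq_len = torch.tensor([config.pad_size] * batch_size).to(config.device)   # assumed meaning of x[1]
    mask = torch.ones(batch_size, config.pad_size, dtype=torch.long).to(config.device)

    logits = model((token_ids, seq_len, mask))
    print(logits.shape)   # torch.Size([4, 10]) -- one score per class in class.txt
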
/models/__pycache__/Bert.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/models/__pycache__/Bert.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.2"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .tokenization_openai import OpenAIGPTTokenizer
4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
5 | from .tokenization_gpt2 import GPT2Tokenizer
6 |
7 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
8 | BertForMaskedLM, BertForNextSentencePrediction,
9 | BertForSequenceClassification, BertForMultipleChoice,
10 | BertForTokenClassification, BertForQuestionAnswering,
11 | load_tf_weights_in_bert)
12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
14 | load_tf_weights_in_openai_gpt)
15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
16 | load_tf_weights_in_transfo_xl)
17 | from .modeling_gpt2 import (GPT2Config, GPT2Model,
18 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
19 | load_tf_weights_in_gpt2)
20 |
21 | from .optimization import BertAdam
22 | from .optimization_openai import OpenAIAdam
23 |
24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
25 |
--------------------------------------------------------------------------------
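
Besides the BERT/GPT/GPT-2/Transformer-XL model and tokenizer classes, __init__.py re-exports the BertAdam optimizer with warmup. train.py is not reproduced above, so the snippet below is only the typical way this exported optimizer is wired up for fine-tuning, not this repository's exact training code; the stand-in model, warmup fraction and step count are illustrative:

    import torch
    from pytorch_pretrained import BertAdam

    model = torch.nn.Linear(768, 10)      # stand-in; in this repo it would be models.Bert.Model(config)
    num_train_steps = 1000                # len(train_iter) * num_epochs in a real run

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = BertAdam(grouped, lr=5e-5, warmup=0.05, t_total=num_train_steps)
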
/pytorch_pretrained/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
5 | "convert_tf_checkpoint_to_pytorch",
6 | "convert_openai_checkpoint",
7 | "convert_transfo_xl_checkpoint",
8 | "convert_gpt2_checkpoint",
9 | ]:
10 | print(
11 | "Should be used as one of: \n"
12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
16 | else:
17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
18 | try:
19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
20 | except ImportError:
21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
22 | "In that case, it requires TensorFlow to be installed. Please see "
23 | "https://www.tensorflow.org/install/ for installation instructions.")
24 | raise
25 |
26 | if len(sys.argv) != 5:
27 | # pylint: disable=line-too-long
28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
29 | else:
30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
31 | TF_CONFIG = sys.argv.pop()
32 | TF_CHECKPOINT = sys.argv.pop()
33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
34 | elif sys.argv[1] == "convert_openai_checkpoint":
35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
37 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
38 | if len(sys.argv) == 5:
39 | OPENAI_GPT_CONFIG = sys.argv[4]
40 | else:
41 | OPENAI_GPT_CONFIG = ""
42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
43 | OPENAI_GPT_CONFIG,
44 | PYTORCH_DUMP_OUTPUT)
45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint":
46 | try:
47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
48 | except ImportError:
49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
50 | "In that case, it requires TensorFlow to be installed. Please see "
51 | "https://www.tensorflow.org/install/ for installation instructions.")
52 | raise
53 |
54 | if 'ckpt' in sys.argv[2].lower():
55 | TF_CHECKPOINT = sys.argv[2]
56 | TF_DATASET_FILE = ""
57 | else:
58 | TF_DATASET_FILE = sys.argv[2]
59 | TF_CHECKPOINT = ""
60 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
61 | if len(sys.argv) == 5:
62 | TF_CONFIG = sys.argv[4]
63 | else:
64 | TF_CONFIG = ""
65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
66 | else:
67 | try:
68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
69 | except ImportError:
70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
71 | "In that case, it requires TensorFlow to be installed. Please see "
72 | "https://www.tensorflow.org/install/ for installation instructions.")
73 | raise
74 |
75 | TF_CHECKPOINT = sys.argv[2]
76 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
77 | if len(sys.argv) == 5:
78 | TF_CONFIG = sys.argv[4]
79 | else:
80 | TF_CONFIG = ""
81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
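
Because the package is vendored here as pytorch_pretrained rather than installed with the pytorch_pretrained_bert console script, the dispatcher above is reached with python -m pytorch_pretrained <command> ...; with no arguments it simply prints the usage text. A tiny sketch of that, assuming the repository's dependencies are installed and it is run from the repository root:

    import sys
    import runpy

    # Equivalent to running: python -m pytorch_pretrained
    # (a real conversion would append e.g. convert_tf_checkpoint_to_pytorch CKPT CONFIG OUT)
    sys.argv = ['pytorch_pretrained']
    runpy.run_module('pytorch_pretrained', run_name='__main__')   # prints the usage help shown above
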
/pytorch_pretrained/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/file_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/file_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/modeling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/modeling.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/modeling_gpt2.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/modeling_openai.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/modeling_openai.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/modeling_transfo_xl.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/modeling_transfo_xl_utilities.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/optimization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/optimization.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/optimization_openai.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/optimization_openai.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/tokenization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/tokenization.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/tokenization_gpt2.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/tokenization_openai.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/tokenization_openai.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LePetitPrinceWh/Bert-Pytorch-TextClassification/962085ff1501ae75d9b784a14bd95c52c34fb0f7/pytorch_pretrained/__pycache__/tokenization_transfo_xl.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
25 | GPT2Config,
26 | GPT2Model,
27 | load_tf_weights_in_gpt2)
28 |
29 |
30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if gpt2_config_file == "":
33 | config = GPT2Config()
34 | else:
35 | config = GPT2Config(gpt2_config_file)
36 | model = GPT2Model(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--gpt2_checkpoint_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 | help = "Path the TensorFlow checkpoint path.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--gpt2_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
71 | args.gpt2_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/pytorch_pretrained/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
25 | OpenAIGPTConfig,
26 | OpenAIGPTModel,
27 | load_tf_weights_in_openai_gpt)
28 |
29 |
30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if openai_config_file == "":
33 | config = OpenAIGPTConfig()
34 | else:
35 | config = OpenAIGPTConfig(openai_config_file)
36 | model = OpenAIGPTModel(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--openai_checkpoint_folder_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 | help = "Path the TensorFlow checkpoint path.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--openai_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
71 | args.openai_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/pytorch_pretrained/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | # Initialise PyTorch model
32 | config = BertConfig.from_json_file(bert_config_file)
33 | print("Building PyTorch model from configuration: {}".format(str(config)))
34 | model = BertForPreTraining(config)
35 |
36 | # Load weights from tf checkpoint
37 | load_tf_weights_in_bert(model, tf_checkpoint_path)
38 |
39 | # Save pytorch-model
40 | print("Save PyTorch model to {}".format(pytorch_dump_path))
41 | torch.save(model.state_dict(), pytorch_dump_path)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | ## Required parameters
47 | parser.add_argument("--tf_checkpoint_path",
48 | default = None,
49 | type = str,
50 | required = True,
51 | help = "Path the TensorFlow checkpoint path.")
52 | parser.add_argument("--bert_config_file",
53 | default = None,
54 | type = str,
55 | required = True,
56 | help = "The config json file corresponding to the pre-trained BERT model. \n"
57 | "This specifies the model architecture.")
58 | parser.add_argument("--pytorch_dump_path",
59 | default = None,
60 | type = str,
61 | required = True,
62 | help = "Path to the output PyTorch model.")
63 | args = parser.parse_args()
64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
65 | args.bert_config_file,
66 | args.pytorch_dump_path)
67 |
--------------------------------------------------------------------------------
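
The pre-converted pytorch_model.bin linked from bert_pretrain/README.md is enough for this project, but an original Google TensorFlow BERT checkpoint could also be converted with the script above. A hedged sketch of the call it wraps: the paths are placeholders, TensorFlow must be installed, and note that this vendored copy still imports from the upstream pytorch_pretrained_bert package name, so that package is needed too unless the import is adjusted:

    from pytorch_pretrained.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

    # Mirrors --tf_checkpoint_path, --bert_config_file and --pytorch_dump_path; paths are placeholders.
    convert_tf_checkpoint_to_pytorch('chinese_L-12_H-768_A-12/bert_model.ckpt',
                                     'chinese_L-12_H-768_A-12/bert_config.json',
                                     'bert_pretrain/pytorch_model.bin')
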
/pytorch_pretrained/convert_transfo_xl_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import os
21 | import sys
22 | from io import open
23 |
24 | import torch
25 |
26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
28 | WEIGHTS_NAME,
29 | TransfoXLConfig,
30 | TransfoXLLMHeadModel,
31 | load_tf_weights_in_transfo_xl)
32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
33 | VOCAB_NAME)
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 | # We do this to be able to load python 2 datasets pickles
41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
42 | data_utils.Vocab = data_utils.TransfoXLTokenizer
43 | data_utils.Corpus = data_utils.TransfoXLCorpus
44 | sys.modules['data_utils'] = data_utils
45 | sys.modules['vocabulary'] = data_utils
46 |
47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
48 | transfo_xl_config_file,
49 | pytorch_dump_folder_path,
50 | transfo_xl_dataset_file):
51 | if transfo_xl_dataset_file:
52 | # Convert a pre-processed corpus (see original TensorFlow repo)
53 | with open(transfo_xl_dataset_file, "rb") as fp:
54 | corpus = pickle.load(fp, encoding="latin1")
55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
58 | corpus_vocab_dict = corpus.vocab.__dict__
59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
60 |
61 | corpus_dict_no_vocab = corpus.__dict__
62 | corpus_dict_no_vocab.pop('vocab', None)
63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
64 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
66 |
67 | if tf_checkpoint_path:
68 | # Convert a pre-trained TensorFlow model
69 | config_path = os.path.abspath(transfo_xl_config_file)
70 | tf_path = os.path.abspath(tf_checkpoint_path)
71 |
72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
73 | # Initialise PyTorch model
74 | if transfo_xl_config_file == "":
75 | config = TransfoXLConfig()
76 | else:
77 | config = TransfoXLConfig(transfo_xl_config_file)
78 | print("Building PyTorch model from configuration: {}".format(str(config)))
79 | model = TransfoXLLMHeadModel(config)
80 |
81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
82 | # Save pytorch-model
83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
86 | torch.save(model.state_dict(), pytorch_weights_dump_path)
87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
89 | f.write(config.to_json_string())
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--pytorch_dump_folder_path",
95 | default = None,
96 | type = str,
97 | required = True,
98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
99 | parser.add_argument("--tf_checkpoint_path",
100 | default = "",
101 | type = str,
102 | help = "An optional path to a TensorFlow checkpoint path to be converted.")
103 | parser.add_argument("--transfo_xl_config_file",
104 | default = "",
105 | type = str,
106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n"
107 | "This specifies the model architecture.")
108 | parser.add_argument("--transfo_xl_dataset_file",
109 | default = "",
110 | type = str,
111 | help = "An optional dataset file to be converted in a vocabulary.")
112 | args = parser.parse_args()
113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
114 | args.transfo_xl_config_file,
115 | args.pytorch_dump_folder_path,
116 | args.transfo_xl_dataset_file)
117 |
--------------------------------------------------------------------------------
/pytorch_pretrained/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import shutil
13 | import tempfile
14 | import fnmatch
15 | from functools import wraps
16 | from hashlib import sha256
17 | import sys
18 | from io import open
19 |
20 | import boto3
21 | import requests
22 | from botocore.exceptions import ClientError
23 | from tqdm import tqdm
24 |
25 | try:
26 | from urllib.parse import urlparse
27 | except ImportError:
28 | pass
29 |
30 | try:
31 | from pathlib import Path
32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
33 | Path.home() / '.pytorch_pretrained_bert'))
34 | except (AttributeError, ImportError):
35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
37 |
38 | CONFIG_NAME = "config.json"
39 | WEIGHTS_NAME = "pytorch_model.bin"
40 |
41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
42 |
43 |
44 | def url_to_filename(url, etag=None):
45 | """
46 | Convert `url` into a hashed filename in a repeatable way.
47 | If `etag` is specified, append its hash to the url's, delimited
48 | by a period.
49 | """
50 | url_bytes = url.encode('utf-8')
51 | url_hash = sha256(url_bytes)
52 | filename = url_hash.hexdigest()
53 |
54 | if etag:
55 | etag_bytes = etag.encode('utf-8')
56 | etag_hash = sha256(etag_bytes)
57 | filename += '.' + etag_hash.hexdigest()
58 |
59 | return filename
60 |
61 |
62 | def filename_to_url(filename, cache_dir=None):
63 | """
64 | Return the url and etag (which may be ``None``) stored for `filename`.
65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
66 | """
67 | if cache_dir is None:
68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
70 | cache_dir = str(cache_dir)
71 |
72 | cache_path = os.path.join(cache_dir, filename)
73 | if not os.path.exists(cache_path):
74 | raise EnvironmentError("file {} not found".format(cache_path))
75 |
76 | meta_path = cache_path + '.json'
77 | if not os.path.exists(meta_path):
78 | raise EnvironmentError("file {} not found".format(meta_path))
79 |
80 | with open(meta_path, encoding="utf-8") as meta_file:
81 | metadata = json.load(meta_file)
82 | url = metadata['url']
83 | etag = metadata['etag']
84 |
85 | return url, etag
86 |
87 |
88 | def cached_path(url_or_filename, cache_dir=None):
89 | """
90 | Given something that might be a URL (or might be a local path),
91 | determine which. If it's a URL, download the file and cache it, and
92 | return the path to the cached file. If it's already a local path,
93 | make sure the file exists and then return the path.
94 | """
95 | if cache_dir is None:
96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
98 | url_or_filename = str(url_or_filename)
99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
100 | cache_dir = str(cache_dir)
101 |
102 | parsed = urlparse(url_or_filename)
103 |
104 | if parsed.scheme in ('http', 'https', 's3'):
105 | # URL, so get it from the cache (downloading if necessary)
106 | return get_from_cache(url_or_filename, cache_dir)
107 | elif os.path.exists(url_or_filename):
108 | # File, and it exists.
109 | return url_or_filename
110 | elif parsed.scheme == '':
111 | # File, but it doesn't exist.
112 | raise EnvironmentError("file {} not found".format(url_or_filename))
113 | else:
114 | # Something unknown
115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
116 |
117 |
118 | def split_s3_path(url):
119 | """Split a full s3 path into the bucket name and path."""
120 | parsed = urlparse(url)
121 | if not parsed.netloc or not parsed.path:
122 | raise ValueError("bad s3 path {}".format(url))
123 | bucket_name = parsed.netloc
124 | s3_path = parsed.path
125 | # Remove '/' at beginning of path.
126 | if s3_path.startswith("/"):
127 | s3_path = s3_path[1:]
128 | return bucket_name, s3_path
129 |
130 |
131 | def s3_request(func):
132 | """
133 | Wrapper function for s3 requests in order to create more helpful error
134 | messages.
135 | """
136 |
137 | @wraps(func)
138 | def wrapper(url, *args, **kwargs):
139 | try:
140 | return func(url, *args, **kwargs)
141 | except ClientError as exc:
142 | if int(exc.response["Error"]["Code"]) == 404:
143 | raise EnvironmentError("file {} not found".format(url))
144 | else:
145 | raise
146 |
147 | return wrapper
148 |
149 |
150 | @s3_request
151 | def s3_etag(url):
152 | """Check ETag on S3 object."""
153 | s3_resource = boto3.resource("s3")
154 | bucket_name, s3_path = split_s3_path(url)
155 | s3_object = s3_resource.Object(bucket_name, s3_path)
156 | return s3_object.e_tag
157 |
158 |
159 | @s3_request
160 | def s3_get(url, temp_file):
161 | """Pull a file directly from S3."""
162 | s3_resource = boto3.resource("s3")
163 | bucket_name, s3_path = split_s3_path(url)
164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
165 |
166 |
167 | def http_get(url, temp_file):
168 | req = requests.get(url, stream=True)
169 | content_length = req.headers.get('Content-Length')
170 | total = int(content_length) if content_length is not None else None
171 | progress = tqdm(unit="B", total=total)
172 | for chunk in req.iter_content(chunk_size=1024):
173 | if chunk: # filter out keep-alive new chunks
174 | progress.update(len(chunk))
175 | temp_file.write(chunk)
176 | progress.close()
177 |
178 |
179 | def get_from_cache(url, cache_dir=None):
180 | """
181 | Given a URL, look for the corresponding dataset in the local cache.
182 | If it's not there, download it. Then return the path to the cached file.
183 | """
184 | if cache_dir is None:
185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
187 | cache_dir = str(cache_dir)
188 |
189 | if not os.path.exists(cache_dir):
190 | os.makedirs(cache_dir)
191 |
192 | # Get eTag to add to filename, if it exists.
193 | if url.startswith("s3://"):
194 | etag = s3_etag(url)
195 | else:
196 | try:
197 | response = requests.head(url, allow_redirects=True)
198 | if response.status_code != 200:
199 | etag = None
200 | else:
201 | etag = response.headers.get("ETag")
202 | except EnvironmentError:
203 | etag = None
204 |
205 | if sys.version_info[0] == 2 and etag is not None:
206 | etag = etag.decode('utf-8')
207 | filename = url_to_filename(url, etag)
208 |
209 | # get cache path to put the file
210 | cache_path = os.path.join(cache_dir, filename)
211 |
212 | # If we don't have a connection (etag is None) and can't identify the file
213 | # try to get the last downloaded one
214 | if not os.path.exists(cache_path) and etag is None:
215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
217 | if matching_files:
218 | cache_path = os.path.join(cache_dir, matching_files[-1])
219 |
220 | if not os.path.exists(cache_path):
221 | # Download to temporary file, then copy to cache dir once finished.
222 | # Otherwise you get corrupt cache entries if the download gets interrupted.
223 | with tempfile.NamedTemporaryFile() as temp_file:
224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
225 |
226 | # GET file object
227 | if url.startswith("s3://"):
228 | s3_get(url, temp_file)
229 | else:
230 | http_get(url, temp_file)
231 |
232 | # we are copying the file before closing it, so flush to avoid truncation
233 | temp_file.flush()
234 | # shutil.copyfileobj() starts at the current position, so go to the start
235 | temp_file.seek(0)
236 |
237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
238 | with open(cache_path, 'wb') as cache_file:
239 | shutil.copyfileobj(temp_file, cache_file)
240 |
241 | logger.info("creating metadata file for %s", cache_path)
242 | meta = {'url': url, 'etag': etag}
243 | meta_path = cache_path + '.json'
244 | with open(meta_path, 'w') as meta_file:
245 | output_string = json.dumps(meta)
246 | if sys.version_info[0] == 2 and isinstance(output_string, str):
247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2
248 | meta_file.write(output_string)
249 |
250 | logger.info("removing temp file %s", temp_file.name)
251 |
252 | return cache_path
253 |
254 |
255 | def read_set_from_file(filename):
256 | '''
257 | Extract a de-duped collection (set) of text from a file.
258 | Expected file format is one item per line.
259 | '''
260 | collection = set()
261 | with open(filename, 'r', encoding='utf-8') as file_:
262 | for line in file_:
263 | collection.add(line.rstrip())
264 | return collection
265 |
266 |
267 | def get_file_extension(path, dot=True, lower=True):
268 | ext = os.path.splitext(path)[1]
269 | ext = ext if dot else ext[1:]
270 | return ext.lower() if lower else ext
271 |
--------------------------------------------------------------------------------
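
file_utils.py is the download-and-cache layer: cached_path() hashes the URL (and its ETag) with sha256 to build a cache filename under ~/.pytorch_pretrained_bert (or $PYTORCH_PRETRAINED_BERT_CACHE) and returns the local path, while existing local paths are passed through unchanged. A small sketch using the bert-base-chinese archive URL from bert_pretrain/README.md (the archive is several hundred megabytes):

    from pytorch_pretrained import cached_path

    url = 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz'
    local_path = cached_path(url)     # downloads on the first call, then reuses the cached copy
    print(local_path)                 # e.g. ~/.pytorch_pretrained_bert/<sha256(url)>.<sha256(etag)>

    # Existing local files are returned unchanged (run from the repository root).
    print(cached_path('THUCNews/data/class.txt'))
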
/pytorch_pretrained/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch OpenAI GPT-2 model."""
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import collections
21 | import copy
22 | import json
23 | import logging
24 | import math
25 | import os
26 | import shutil
27 | import tarfile
28 | import tempfile
29 | import sys
30 | from io import open
31 |
32 | import torch
33 | import torch.nn as nn
34 | from torch.nn import CrossEntropyLoss
35 | from torch.nn.parameter import Parameter
36 |
37 | from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
38 | from .modeling import BertLayerNorm as LayerNorm
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 | PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
43 | PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
44 |
45 | def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
46 | """ Load tf checkpoints in a pytorch model
47 | """
48 | try:
49 | import re
50 | import numpy as np
51 | import tensorflow as tf
52 | except ImportError:
53 | print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
54 | "https://www.tensorflow.org/install/ for installation instructions.")
55 | raise
56 | tf_path = os.path.abspath(gpt2_checkpoint_path)
57 | print("Converting TensorFlow checkpoint from {}".format(tf_path))
58 | # Load weights from TF model
59 | init_vars = tf.train.list_variables(tf_path)
60 | names = []
61 | arrays = []
62 | for name, shape in init_vars:
63 | print("Loading TF weight {} with shape {}".format(name, shape))
64 | array = tf.train.load_variable(tf_path, name)
65 | names.append(name)
66 | arrays.append(array.squeeze())
67 |
68 | for name, array in zip(names, arrays):
69 | name = name[6:] # skip "model/"
70 | name = name.split('/')
71 | pointer = model
72 | for m_name in name:
73 | if re.fullmatch(r'[A-Za-z]+\d+', m_name):
74 | l = re.split(r'(\d+)', m_name)
75 | else:
76 | l = [m_name]
77 | if l[0] == 'w' or l[0] == 'g':
78 | pointer = getattr(pointer, 'weight')
79 | elif l[0] == 'b':
80 | pointer = getattr(pointer, 'bias')
81 | elif l[0] == 'wpe' or l[0] == 'wte':
82 | pointer = getattr(pointer, l[0])
83 | pointer = getattr(pointer, 'weight')
84 | else:
85 | pointer = getattr(pointer, l[0])
86 | if len(l) >= 2:
87 | num = int(l[1])
88 | pointer = pointer[num]
89 | try:
90 | assert pointer.shape == array.shape
91 | except AssertionError as e:
92 | e.args += (pointer.shape, array.shape)
93 | raise
94 | print("Initialize PyTorch weight {}".format(name))
95 | pointer.data = torch.from_numpy(array)
96 | return model
97 |
98 |
99 | def gelu(x):
100 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
101 |
102 |
103 | class GPT2Config(object):
104 | """Configuration class to store the configuration of a `GPT2Model`.
105 | """
106 |
107 | def __init__(
108 | self,
109 | vocab_size_or_config_json_file=50257,
110 | n_positions=1024,
111 | n_ctx=1024,
112 | n_embd=768,
113 | n_layer=12,
114 | n_head=12,
115 | layer_norm_epsilon=1e-5,
116 | initializer_range=0.02,
117 | ):
118 | """Constructs GPT2Config.
119 |
120 | Args:
121 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
122 | n_positions: Number of positional embeddings.
123 | n_ctx: Size of the causal mask (usually same as n_positions).
124 | n_embd: Dimensionality of the embeddings and hidden states.
125 | n_layer: Number of hidden layers in the Transformer encoder.
126 | n_head: Number of attention heads for each attention layer in
127 | the Transformer encoder.
128 | layer_norm_epsilon: epsilon to use in the layer norm layers
129 | initializer_range: The stddev of the truncated_normal_initializer for
130 | initializing all weight matrices.
131 | """
132 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
133 | and isinstance(vocab_size_or_config_json_file, unicode)):
134 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
135 | json_config = json.loads(reader.read())
136 | for key, value in json_config.items():
137 | self.__dict__[key] = value
138 | elif isinstance(vocab_size_or_config_json_file, int):
139 | self.vocab_size = vocab_size_or_config_json_file
140 | self.n_ctx = n_ctx
141 | self.n_positions = n_positions
142 | self.n_embd = n_embd
143 | self.n_layer = n_layer
144 | self.n_head = n_head
145 | self.layer_norm_epsilon = layer_norm_epsilon
146 | self.initializer_range = initializer_range
147 | else:
148 | raise ValueError(
149 | "First argument must be either a vocabulary size (int)"
150 | "or the path to a pretrained model config file (str)"
151 | )
152 |
153 | @classmethod
154 | def from_dict(cls, json_object):
155 | """Constructs a `GPT2Config` from a Python dictionary of parameters."""
156 | config = GPT2Config(vocab_size_or_config_json_file=-1)
157 | for key, value in json_object.items():
158 | config.__dict__[key] = value
159 | return config
160 |
161 | @classmethod
162 | def from_json_file(cls, json_file):
163 | """Constructs a `GPT2Config` from a json file of parameters."""
164 | with open(json_file, "r", encoding="utf-8") as reader:
165 | text = reader.read()
166 | return cls.from_dict(json.loads(text))
167 |
168 | def __repr__(self):
169 | return str(self.to_json_string())
170 |
171 | def to_dict(self):
172 | """Serializes this instance to a Python dictionary."""
173 | output = copy.deepcopy(self.__dict__)
174 | return output
175 |
176 | def to_json_string(self):
177 | """Serializes this instance to a JSON string."""
178 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
179 |
180 | def to_json_file(self, json_file_path):
181 | """ Save this instance to a json file."""
182 | with open(json_file_path, "w", encoding='utf-8') as writer:
183 | writer.write(self.to_json_string())
184 |
185 |
186 | class Conv1D(nn.Module):
187 | def __init__(self, nf, nx):
188 | super(Conv1D, self).__init__()
189 | self.nf = nf
190 | w = torch.empty(nx, nf)
191 | nn.init.normal_(w, std=0.02)
192 | self.weight = Parameter(w)
193 | self.bias = Parameter(torch.zeros(nf))
194 |
195 | def forward(self, x):
196 | size_out = x.size()[:-1] + (self.nf,)
197 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
198 | x = x.view(*size_out)
199 | return x
200 |
201 |
202 | class Attention(nn.Module):
203 | def __init__(self, nx, n_ctx, config, scale=False):
204 | super(Attention, self).__init__()
205 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
206 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
207 | assert n_state % config.n_head == 0
208 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
209 | self.n_head = config.n_head
210 | self.split_size = n_state
211 | self.scale = scale
212 | self.c_attn = Conv1D(n_state * 3, nx)
213 | self.c_proj = Conv1D(n_state, nx)
214 |
215 | def _attn(self, q, k, v):
216 | w = torch.matmul(q, k)
217 | if self.scale:
218 | w = w / math.sqrt(v.size(-1))
219 | nd, ns = w.size(-2), w.size(-1)
220 | b = self.bias[:, :, ns-nd:ns, :ns]
221 | w = w * b - 1e4 * (1 - b)
222 |
223 | w = nn.Softmax(dim=-1)(w)
224 | return torch.matmul(w, v)
225 |
226 | def merge_heads(self, x):
227 | x = x.permute(0, 2, 1, 3).contiguous()
228 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
229 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
230 |
231 | def split_heads(self, x, k=False):
232 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
233 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
234 | if k:
235 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
236 | else:
237 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
238 |
239 | def forward(self, x, layer_past=None):
240 | x = self.c_attn(x)
241 | query, key, value = x.split(self.split_size, dim=2)
242 | query = self.split_heads(query)
243 | key = self.split_heads(key, k=True)
244 | value = self.split_heads(value)
245 | if layer_past is not None:
246 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
247 | key = torch.cat((past_key, key), dim=-1)
248 | value = torch.cat((past_value, value), dim=-2)
249 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
250 | a = self._attn(query, key, value)
251 | a = self.merge_heads(a)
252 | a = self.c_proj(a)
253 | return a, present
254 |
255 |
256 | class MLP(nn.Module):
257 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
258 | super(MLP, self).__init__()
259 | nx = config.n_embd
260 | self.c_fc = Conv1D(n_state, nx)
261 | self.c_proj = Conv1D(nx, n_state)
262 | self.act = gelu
263 |
264 | def forward(self, x):
265 | h = self.act(self.c_fc(x))
266 | h2 = self.c_proj(h)
267 | return h2
268 |
269 |
270 | class Block(nn.Module):
271 | def __init__(self, n_ctx, config, scale=False):
272 | super(Block, self).__init__()
273 | nx = config.n_embd
274 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
275 | self.attn = Attention(nx, n_ctx, config, scale)
276 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
277 | self.mlp = MLP(4 * nx, config)
278 |
279 | def forward(self, x, layer_past=None):
280 | a, present = self.attn(self.ln_1(x), layer_past=layer_past)
281 | x = x + a
282 | m = self.mlp(self.ln_2(x))
283 | x = x + m
284 | return x, present
285 |
286 |
287 | class GPT2LMHead(nn.Module):
288 | """ Language Model Head for the transformer """
289 |
290 | def __init__(self, model_embeddings_weights, config):
291 | super(GPT2LMHead, self).__init__()
292 | self.n_embd = config.n_embd
293 | self.set_embeddings_weights(model_embeddings_weights)
294 |
295 | def set_embeddings_weights(self, model_embeddings_weights):
296 | embed_shape = model_embeddings_weights.shape
297 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
298 | self.decoder.weight = model_embeddings_weights # Tied weights
299 |
300 | def forward(self, hidden_state):
301 | # Truncated Language modeling logits (we remove the last token)
302 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
303 | lm_logits = self.decoder(hidden_state)
304 | return lm_logits
305 |
306 |
307 | class GPT2MultipleChoiceHead(nn.Module):
308 | """ Classifier Head for the transformer """
309 |
310 | def __init__(self, config):
311 | super(GPT2MultipleChoiceHead, self).__init__()
312 | self.n_embd = config.n_embd
313 | self.linear = nn.Linear(config.n_embd, 1)
314 |
315 | nn.init.normal_(self.linear.weight, std=0.02)
316 | nn.init.normal_(self.linear.bias, 0)
317 |
318 | def forward(self, hidden_states, mc_token_ids):
319 | # Classification logits
320 | # hidden_state (bsz, num_choices, seq_length, hidden_size)
321 | # mc_token_ids (bsz, num_choices)
322 | mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
323 | # (bsz, num_choices, 1, hidden_size)
324 | multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
325 | # (bsz, num_choices, hidden_size)
326 | multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
327 | # (bsz, num_choices)
328 | return multiple_choice_logits
329 |
330 |
331 | class GPT2PreTrainedModel(nn.Module):
332 | """ An abstract class to handle weights initialization and
333 | a simple interface for downloading and loading pretrained models.
334 | """
335 |
336 | def __init__(self, config, *inputs, **kwargs):
337 | super(GPT2PreTrainedModel, self).__init__()
338 | if not isinstance(config, GPT2Config):
339 | raise ValueError(
340 | "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
341 | "To create a model from a pretrained model use "
342 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
343 | self.__class__.__name__, self.__class__.__name__
344 | )
345 | )
346 | self.config = config
347 |
348 | def set_tied(self):
349 | pass
350 |
351 | def init_weights(self, module):
352 | """ Initialize the weights.
353 | """
354 | if isinstance(module, (nn.Linear, nn.Embedding)):
355 | # Slightly different from the TF version which uses truncated_normal for initialization
356 | # cf https://github.com/pytorch/pytorch/pull/5617
357 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
358 | elif isinstance(module, LayerNorm):
359 | module.bias.data.zero_()
360 | module.weight.data.fill_(1.0)
361 | if isinstance(module, nn.Linear) and module.bias is not None:
362 | module.bias.data.zero_()
363 |
364 | @classmethod
365 | def from_pretrained(
366 | cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
367 | ):
368 | """
369 | Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
370 | Download and cache the pre-trained model file if needed.
371 |
372 | Params:
373 | pretrained_model_name_or_path: either:
374 | - a str with the name of a pre-trained model to load selected in the list of:
375 | . `gpt2`
376 | - a path or url to a pretrained model archive containing:
377 | . `gpt2_config.json` a configuration file for the model
378 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
379 | - a path or url to a pretrained model archive containing:
380 | . `gpt2_config.json` a configuration file for the model
381 | . a TensorFlow checkpoint with trained weights
382 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint
383 | cache_dir: an optional path to a folder in which the pre-trained models will be cached.
384 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
385 | *inputs, **kwargs: additional input for the specific GPT class
386 | """
387 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
388 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
389 | config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
390 | else:
391 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
392 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
393 | # redirect to the cache, if necessary
394 | try:
395 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
396 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
397 | except EnvironmentError:
398 | logger.error(
399 | "Model name '{}' was not found in model name list ({}). "
400 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
401 | "at this path or url.".format(
402 | pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
403 | archive_file, config_file
404 | )
405 | )
406 | return None
407 | if resolved_archive_file == archive_file and resolved_config_file == config_file:
408 | logger.info("loading weights file {}".format(archive_file))
409 | logger.info("loading configuration file {}".format(config_file))
410 | else:
411 | logger.info("loading weights file {} from cache at {}".format(
412 | archive_file, resolved_archive_file))
413 | logger.info("loading configuration file {} from cache at {}".format(
414 | config_file, resolved_config_file))
415 | # Load config
416 | config = GPT2Config.from_json_file(resolved_config_file)
417 | logger.info("Model config {}".format(config))
418 | # Instantiate model.
419 | model = cls(config, *inputs, **kwargs)
420 | if state_dict is None and not from_tf:
421 | state_dict = torch.load(resolved_archive_file, map_location='cpu')
422 | if from_tf:
423 | # Directly load from a TensorFlow checkpoint (stored as NumPy array)
424 | return load_tf_weights_in_gpt2(model, resolved_archive_file)
425 |
426 | old_keys = []
427 | new_keys = []
428 | for key in state_dict.keys():
429 | new_key = None
430 | if key.endswith(".g"):
431 | new_key = key[:-2] + ".weight"
432 | elif key.endswith(".b"):
433 | new_key = key[:-2] + ".bias"
434 | elif key.endswith(".w"):
435 | new_key = key[:-2] + ".weight"
436 | if new_key:
437 | old_keys.append(key)
438 | new_keys.append(new_key)
439 | for old_key, new_key in zip(old_keys, new_keys):
440 | state_dict[new_key] = state_dict.pop(old_key)
441 |
442 | missing_keys = []
443 | unexpected_keys = []
444 | error_msgs = []
445 | # copy state_dict so _load_from_state_dict can modify it
446 | metadata = getattr(state_dict, "_metadata", None)
447 | state_dict = state_dict.copy()
448 | if metadata is not None:
449 | state_dict._metadata = metadata
450 |
451 | def load(module, prefix=""):
452 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
453 | module._load_from_state_dict(
454 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
455 | )
456 | for name, child in module._modules.items():
457 | if child is not None:
458 | load(child, prefix + name + ".")
459 |
460 | start_model = model
461 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
462 | start_model = model.transformer
463 | load(start_model, prefix="")
464 |
465 | if len(missing_keys) > 0:
466 | logger.info(
467 | "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
468 | )
469 | if len(unexpected_keys) > 0:
470 | logger.info(
471 | "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
472 | )
473 | if len(error_msgs) > 0:
474 | raise RuntimeError(
475 | "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
476 | )
477 |
478 | # Make sure we are still sharing the output and input embeddings after loading weights
479 | model.set_tied()
480 | return model
481 |
482 |
483 | class GPT2Model(GPT2PreTrainedModel):
484 | """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
485 |
486 | Params:
487 | config: a GPT2Config class instance with the configuration to build a new model
488 |
489 | Inputs:
490 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
491 | where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
492 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids
493 | with the position indices (selected in the range [0, config.n_positions - 1[.
494 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
495 | You can use it to add a third type of embedding to each input token in the sequence
496 | (the previous two being the word and position embeddings).
497 | The input, position and token_type embeddings are summed inside the Transformer before the first
498 | self-attention block.
499 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
500 | (key and values in the attention blocks) to speed up sequential decoding
501 | (this is the presents output of the model, cf. below).
502 |
503 | Outputs a tuple consisting of:
504 | `hidden_states`: the encoded-hidden-states at the top of the model
505 | as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
506 | (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
507 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
508 | torch.FloatTensors. They can be reused to speed up sequential decoding.
509 |
510 | Example usage:
511 | ```python
512 | # Already been converted into BPE token ids
513 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
514 |
515 | config = modeling_gpt2.GPT2Config()
516 |
517 | model = modeling_gpt2.GPT2Model(config)
518 | hidden_states, presents = model(input_ids)
519 | ```
520 | """
521 |
522 | def __init__(self, config):
523 | super(GPT2Model, self).__init__(config)
524 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
525 | self.wpe = nn.Embedding(config.n_positions, config.n_embd)
526 | block = Block(config.n_ctx, config, scale=True)
527 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
528 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
529 |
530 | self.apply(self.init_weights)
531 |
532 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
533 | if past is None:
534 | past_length = 0
535 | past = [None] * len(self.h)
536 | else:
537 | past_length = past[0][0].size(-2)
538 | if position_ids is None:
539 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
540 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
541 |
542 | input_shape = input_ids.size()
543 | input_ids = input_ids.view(-1, input_ids.size(-1))
544 | position_ids = position_ids.view(-1, position_ids.size(-1))
545 |
546 | inputs_embeds = self.wte(input_ids)
547 | position_embeds = self.wpe(position_ids)
548 | if token_type_ids is not None:
549 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
550 | token_type_embeds = self.wte(token_type_ids)
551 | else:
552 | token_type_embeds = 0
553 | hidden_states = inputs_embeds + position_embeds + token_type_embeds
554 | presents = []
555 | for block, layer_past in zip(self.h, past):
556 | hidden_states, present = block(hidden_states, layer_past)
557 | presents.append(present)
558 | hidden_states = self.ln_f(hidden_states)
559 | output_shape = input_shape + (hidden_states.size(-1),)
560 | return hidden_states.view(*output_shape), presents
561 |
562 |
563 | class GPT2LMHeadModel(GPT2PreTrainedModel):
564 | """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
565 |
566 | Params:
567 | config: a GPT2Config class instance with the configuration to build a new model
568 |
569 | Inputs:
570 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
571 | where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
572 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids
573 | with the position indices (selected in the range [0, config.n_positions - 1[.
574 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
575 | You can use it to add a third type of embedding to each input token in the sequence
576 | (the previous two being the word and position embeddings).
577 | The input, position and token_type embeddings are summed inside the Transformer before the first
578 | self-attention block.
579 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
580 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
581 | is only computed for the labels set in [0, ..., vocab_size]
582 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
583 | (key and values in the attention blocks) to speed up sequential decoding
584 | (this is the presents output of the model, cf. below).
585 |
586 | Outputs:
587 | if `lm_labels` is not `None`:
588 | Outputs the language modeling loss.
589 | else a tuple:
590 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
591 | (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ... d_n are the dimensions of input_ids)
592 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
593 | torch.FloatTensors. They can be reused to speed up sequential decoding.
594 |
595 | Example usage:
596 | ```python
597 | # Already been converted into BPE token ids
598 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
599 |
600 | config = modeling_gpt2.GPT2Config()
601 |
602 | model = modeling_gpt2.GPT2LMHeadModel(config)
603 | lm_logits, presents = model(input_ids)
604 | ```
605 | """
606 |
607 | def __init__(self, config):
608 | super(GPT2LMHeadModel, self).__init__(config)
609 | self.transformer = GPT2Model(config)
610 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
611 | self.apply(self.init_weights)
612 |
613 | def set_tied(self):
614 | """ Make sure we are sharing the embeddings
615 | """
616 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
617 |
618 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
619 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
620 | lm_logits = self.lm_head(hidden_states)
621 | if lm_labels is not None:
622 | # Shift so that tokens < n predict n
623 | shift_logits = lm_logits[:, :-1].contiguous()
624 | shift_labels = lm_labels[:, 1:].contiguous()
625 |
626 | # Flatten the tokens
627 | loss_fct = CrossEntropyLoss(ignore_index=-1)
628 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
629 | shift_labels.view(-1))
630 | return loss
631 | return lm_logits, presents
632 |
633 |
634 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
635 | """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
636 |
637 | Params:
638 | config: a GPT2Config class instance with the configuration to build a new model
639 |
640 | Inputs:
641 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
642 | indices selected in the range [0, config.vocab_size[
643 | `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
644 | which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
645 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids
646 | with the position indices (selected in the range [0, config.n_positions - 1[.
647 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
648 | You can use it to add a third type of embedding to each input token in the sequence
649 | (the previous two being the word and position embeddings).
650 | The input, position and token_type embeddings are summed inside the Transformer before the first
651 | self-attention block.
652 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
653 | with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
654 | is only computed for the labels set in [0, ..., config.vocab_size]
655 | `mc_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
656 | with indices selected in [0, ..., num_choices - 1].
657 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states
658 | (key and values in the attention blocks) to speed up sequential decoding
659 | (this is the presents output of the model, cf. below).
660 |
661 | Outputs:
662 | if `lm_labels` and `mc_labels` are not `None`:
663 | Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
664 | else: a tuple with
665 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
666 | `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
667 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
668 | torch.FloatTensors. They can be reused to speed up sequential decoding.
669 |
670 | Example usage:
671 | ```python
672 | # Already been converted into BPE token ids
673 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
674 | mc_token_ids = torch.LongTensor([[2, 1]]) # (bsz, number of choice)
675 |
676 | config = modeling_gpt2.GPT2Config()
677 |
678 | model = modeling_gpt2.GPT2DoubleHeadsModel(config)
679 | lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
680 | ```
681 | """
682 |
683 | def __init__(self, config):
684 | super(GPT2DoubleHeadsModel, self).__init__(config)
685 | self.transformer = GPT2Model(config)
686 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
687 | self.multiple_choice_head = GPT2MultipleChoiceHead(config)
688 | self.apply(self.init_weights)
689 |
690 | def set_tied(self):
691 | """ Make sure we are sharing the embeddings
692 | """
693 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
694 |
695 | def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
696 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
697 | lm_logits = self.lm_head(hidden_states)
698 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
699 | losses = []
700 | if lm_labels is not None:
701 | shift_logits = lm_logits[:, :-1].contiguous()
702 | shift_labels = lm_labels[:, 1:].contiguous()
703 | loss_fct = CrossEntropyLoss(ignore_index=-1)
704 | losses.append(loss_fct(shift_logits.view(-1,
705 | shift_logits.size(-1)), shift_labels.view(-1)))
706 | if mc_labels is not None:
707 | loss_fct = CrossEntropyLoss()
708 | losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
709 | if losses:
710 | return losses
711 | return lm_logits, mc_logits, presents
712 |
--------------------------------------------------------------------------------
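For orientation, a minimal usage sketch of the GPT-2 classes above (not part of modeling_gpt2.py). It assumes the vendored package is importable as `pytorch_pretrained` from the repository root, reuses the illustrative BPE token ids from the docstrings, and shows both the loss-returning call with `lm_labels` and the `past`/`presents` cache used to speed up sequential decoding.

```python
import torch
from pytorch_pretrained.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

# Randomly initialized model from the default config, as in the docstring examples.
config = GPT2Config()
model = GPT2LMHeadModel(config)
model.eval()

# Token ids already converted to BPE ids (same illustrative values as in the docstrings).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

# Training-style call: passing lm_labels returns the shifted language modeling loss.
loss = model(input_ids, lm_labels=input_ids)

# Decoding-style call: keep `presents` and feed it back as `past` so the next step
# only attends over the single new token instead of re-encoding the whole prefix.
with torch.no_grad():
    lm_logits, presents = model(input_ids)
    next_token = lm_logits[:, -1].argmax(dim=-1, keepdim=True)
    lm_logits, presents = model(next_token, past=presents)
```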
/pytorch_pretrained/modeling_transfo_xl_utilities.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Utilities for PyTorch Transformer XL model.
17 | Directly adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 |
20 | from collections import defaultdict
21 |
22 | import numpy as np
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1])
30 |
31 | class ProjectedAdaptiveLogSoftmax(nn.Module):
32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
33 | keep_order=False):
34 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
35 |
36 | self.n_token = n_token
37 | self.d_embed = d_embed
38 | self.d_proj = d_proj
39 |
40 | self.cutoffs = cutoffs + [n_token]
41 | self.cutoff_ends = [0] + self.cutoffs
42 | self.div_val = div_val
43 |
44 | self.shortlist_size = self.cutoffs[0]
45 | self.n_clusters = len(self.cutoffs) - 1
46 | self.head_size = self.shortlist_size + self.n_clusters
47 |
48 | if self.n_clusters > 0:
49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
51 |
52 | self.out_layers = nn.ModuleList()
53 | self.out_projs = nn.ParameterList()
54 |
55 | if div_val == 1:
56 | for i in range(len(self.cutoffs)):
57 | if d_proj != d_embed:
58 | self.out_projs.append(
59 | nn.Parameter(torch.Tensor(d_proj, d_embed))
60 | )
61 | else:
62 | self.out_projs.append(None)
63 |
64 | self.out_layers.append(nn.Linear(d_embed, n_token))
65 | else:
66 | for i in range(len(self.cutoffs)):
67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
68 | d_emb_i = d_embed // (div_val ** i)
69 |
70 | self.out_projs.append(
71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
72 | )
73 |
74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
75 |
76 | self.keep_order = keep_order
77 |
78 | def _compute_logit(self, hidden, weight, bias, proj):
79 | if proj is None:
80 | logit = F.linear(hidden, weight, bias=bias)
81 | else:
82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
83 | proj_hid = F.linear(hidden, proj.t().contiguous())
84 | logit = F.linear(proj_hid, weight, bias=bias)
85 | # else:
86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
87 | # if bias is not None:
88 | # logit = logit + bias
89 |
90 | return logit
91 |
92 | def forward(self, hidden, target=None, keep_order=False):
93 | '''
94 | Params:
95 | hidden :: [len*bsz x d_proj]
96 | target :: [len*bsz]
97 | Return:
98 | if target is not None:
99 | out :: [len*bsz] Negative log likelihood
100 | else:
101 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
102 | We could replace this implementation with the native PyTorch one
103 | if it had an option to set bias on all clusters.
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
105 | '''
106 |
107 | if target is not None:
108 | target = target.view(-1)
109 | if hidden.size(0) != target.size(0):
110 | raise RuntimeError('Input and target should have the same size '
111 | 'in the batch dimension.')
112 |
113 | if self.n_clusters == 0:
114 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
115 | self.out_layers[0].bias, self.out_projs[0])
116 | if target is not None:
117 | output = -F.log_softmax(logit, dim=-1) \
118 | .gather(1, target.unsqueeze(1)).squeeze(1)
119 | else:
120 | output = F.log_softmax(logit, dim=-1)
121 | else:
122 | # construct weights and biases
123 | weights, biases = [], []
124 | for i in range(len(self.cutoffs)):
125 | if self.div_val == 1:
126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
127 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
128 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
129 | else:
130 | weight_i = self.out_layers[i].weight
131 | bias_i = self.out_layers[i].bias
132 |
133 | if i == 0:
134 | weight_i = torch.cat(
135 | [weight_i, self.cluster_weight], dim=0)
136 | bias_i = torch.cat(
137 | [bias_i, self.cluster_bias], dim=0)
138 |
139 | weights.append(weight_i)
140 | biases.append(bias_i)
141 |
142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
143 |
144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
145 | head_logprob = F.log_softmax(head_logit, dim=1)
146 |
147 | if target is None:
148 | out = hidden.new_empty((head_logit.size(0), self.n_token))
149 | else:
150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
151 |
152 | offset = 0
153 | cutoff_values = [0] + self.cutoffs
154 | for i in range(len(cutoff_values) - 1):
155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
156 |
157 | if target is not None:
158 | mask_i = (target >= l_idx) & (target < r_idx)
159 | indices_i = mask_i.nonzero().squeeze()
160 |
161 | if indices_i.numel() == 0:
162 | continue
163 |
164 | target_i = target.index_select(0, indices_i) - l_idx
165 | head_logprob_i = head_logprob.index_select(0, indices_i)
166 | hidden_i = hidden.index_select(0, indices_i)
167 | else:
168 | hidden_i = hidden
169 |
170 | if i == 0:
171 | if target is not None:
172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
173 | else:
174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
175 | else:
176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
177 |
178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
181 | if target is not None:
182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \
183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
184 | else:
185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
186 | out[:, l_idx:r_idx] = logprob_i
187 |
188 | if target is not None:
189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
190 | out.index_copy_(0, indices_i, -logprob_i)
191 | else:
192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
193 | offset += logprob_i.size(0)
194 |
195 | return out
196 |
197 |
198 | def log_prob(self, hidden):
199 | r""" Computes log probabilities for all :math:`n\_classes`
200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
201 | Args:
202 | hidden (Tensor): a minibatch of examples
203 | Returns:
204 | log-probabilities for each class :math:`c`
205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a
206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
207 | Shape:
208 | - Input: :math:`(N, in\_features)`
209 | - Output: :math:`(N, n\_classes)`
210 | """
211 | if self.n_clusters == 0:
212 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
213 | self.out_layers[0].bias, self.out_projs[0])
214 | return F.log_softmax(logit, dim=-1)
215 | else:
216 | # construct weights and biases
217 | weights, biases = [], []
218 | for i in range(len(self.cutoffs)):
219 | if self.div_val == 1:
220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
221 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
222 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
223 | else:
224 | weight_i = self.out_layers[i].weight
225 | bias_i = self.out_layers[i].bias
226 |
227 | if i == 0:
228 | weight_i = torch.cat(
229 | [weight_i, self.cluster_weight], dim=0)
230 | bias_i = torch.cat(
231 | [bias_i, self.cluster_bias], dim=0)
232 |
233 | weights.append(weight_i)
234 | biases.append(bias_i)
235 |
236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
238 |
239 | out = hidden.new_empty((head_logit.size(0), self.n_token))
240 | head_logprob = F.log_softmax(head_logit, dim=1)
241 |
242 | cutoff_values = [0] + self.cutoffs
243 | for i in range(len(cutoff_values) - 1):
244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
245 |
246 | if i == 0:
247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
248 | else:
249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
250 |
251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
253 |
254 | logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i  # same cluster indexing as in forward()
255 | out[:, start_idx:stop_idx] = logprob_i
256 |
257 | return out
258 |
259 |
260 | class LogUniformSampler(object):
261 | def __init__(self, range_max, n_sample):
262 | """
263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
265 |
266 | expected count can be approximated by 1 - (1 - p)^n
267 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
268 |
269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
270 | """
271 | with torch.no_grad():
272 | self.range_max = range_max
273 | log_indices = torch.arange(1., range_max+2., 1.).log_()
274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
275 | # print('P', self.dist.numpy().tolist()[-30:])
276 |
277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
278 |
279 | self.n_sample = n_sample
280 |
281 | def sample(self, labels):
282 | """
283 | labels: [b1, b2]
284 | Return
285 | true_log_probs: [b1, b2]
286 | samp_log_probs: [n_sample]
287 | neg_samples: [n_sample]
288 | """
289 |
290 | # neg_samples = torch.empty(0).long()
291 | n_sample = self.n_sample
292 | n_tries = 2 * n_sample
293 |
294 | with torch.no_grad():
295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
296 | device = labels.device
297 | neg_samples = neg_samples.to(device)
298 | true_log_probs = self.log_q[labels].to(device)
299 | samp_log_probs = self.log_q[neg_samples].to(device)
300 | return true_log_probs, samp_log_probs, neg_samples
301 |
302 | def sample_logits(embedding, bias, labels, inputs, sampler):
303 | """
304 | embedding: an nn.Embedding layer
305 | bias: [n_vocab]
306 | labels: [b1, b2]
307 | inputs: [b1, b2, n_emb]
308 | sampler: you may use a LogUniformSampler
309 | Return
310 | logits: [b1, b2, 1 + n_sample]
311 | """
312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
313 | n_sample = neg_samples.size(0)
314 | b1, b2 = labels.size(0), labels.size(1)
315 | all_ids = torch.cat([labels.view(-1), neg_samples])
316 | all_w = embedding(all_ids)
317 | true_w = all_w[: -n_sample].view(b1, b2, -1)
318 | sample_w = all_w[- n_sample:].view(n_sample, -1)
319 |
320 | all_b = bias[all_ids]
321 | true_b = all_b[: -n_sample].view(b1, b2)
322 | sample_b = all_b[- n_sample:]
323 |
324 | hit = (labels[:, :, None] == neg_samples).detach()
325 |
326 | true_logits = torch.einsum('ijk,ijk->ij',
327 | [true_w, inputs]) + true_b - true_log_probs
328 | sample_logits = torch.einsum('lk,ijk->ijl',
329 | [sample_w, inputs]) + sample_b - samp_log_probs
330 | sample_logits.masked_fill_(hit, -1e30)
331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
332 |
333 | return logits
334 |
335 |
336 | # class LogUniformSampler(object):
337 | # def __init__(self, range_max, unique=False):
338 | # """
339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
341 | # """
342 | # self.range_max = range_max
343 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
345 |
346 | # self.unique = unique
347 |
348 | # if self.unique:
349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
350 |
351 | # def sample(self, n_sample, labels):
352 | # pos_sample, new_labels = labels.unique(return_inverse=True)
353 | # n_pos_sample = pos_sample.size(0)
354 | # n_neg_sample = n_sample - n_pos_sample
355 |
356 | # if self.unique:
357 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
359 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
360 | # else:
361 | # sample_dist = self.dist
362 |
363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
364 |
365 | # sample = torch.cat([pos_sample, neg_sample])
366 | # sample_prob = self.dist[sample]
367 |
368 | # return new_labels, sample, sample_prob
369 |
370 |
371 | if __name__ == '__main__':
372 | S, B = 3, 4
373 | n_vocab = 10000
374 | n_sample = 5
375 | H = 32
376 |
377 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
378 |
379 | # sampler = LogUniformSampler(n_vocab, unique=False)
380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
381 |
382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
384 |
385 | # print('true_probs', true_probs.numpy().tolist())
386 | # print('samp_probs', samp_probs.numpy().tolist())
387 | # print('neg_samples', neg_samples.numpy().tolist())
388 |
389 | # print('sum', torch.sum(sampler.dist).item())
390 |
391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
392 |
393 | embedding = nn.Embedding(n_vocab, H)
394 | bias = torch.zeros(n_vocab)
395 | inputs = torch.Tensor(S, B, H).normal_()
396 |
397 | # sample_logits takes (embedding, bias, labels, inputs, sampler) and returns only the logits
398 | logits = sample_logits(embedding, bias, labels, inputs, sampler)
399 | print('logits', logits.detach().numpy().tolist())
400 | print('logits shape', logits.size())
401 |
402 |
403 |
--------------------------------------------------------------------------------
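A short sketch of how ProjectedAdaptiveLogSoftmax is called in the two modes described by its forward docstring. The vocabulary size, cutoffs, and dimensions below are illustrative values, not taken from any configuration in this repository.

```python
import torch
from pytorch_pretrained.modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax

n_token, d_embed, d_proj = 1000, 64, 64
cutoffs = [100, 500]  # shortlist of 100 tokens plus two tail clusters
crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs, div_val=1)

hidden = torch.randn(16 * 8, d_proj)            # [len*bsz, d_proj]
target = torch.randint(0, n_token, (16 * 8,))   # [len*bsz]

nll = crit(hidden, target)   # per-token negative log likelihood, shape [len*bsz]
logprobs = crit(hidden)      # full log probabilities, shape [len*bsz, n_token]
```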
/pytorch_pretrained/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 | import abc
24 | import sys
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | if sys.version_info >= (3, 4):
30 | ABC = abc.ABC
31 | else:
32 | ABC = abc.ABCMeta('ABC', (), {})
33 |
34 |
35 | class _LRSchedule(ABC):
36 | """ Parent of all LRSchedules here. """
37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
38 | def __init__(self, warmup=0.002, t_total=-1, **kw):
39 | """
40 | :param warmup: what fraction of t_total steps will be used for linear warmup
41 | :param t_total: how many training steps (updates) are planned
42 | :param kw:
43 | """
44 | super(_LRSchedule, self).__init__(**kw)
45 | if t_total < 0:
46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
47 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
49 | warmup = max(warmup, 0.)
50 | self.warmup, self.t_total = float(warmup), float(t_total)
51 | self.warned_for_t_total_at_progress = -1
52 |
53 | def get_lr(self, step, nowarn=False):
54 | """
55 | :param step: which of t_total steps we're on
56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
57 | :return: learning rate multiplier for current update
58 | """
59 | if self.t_total < 0:
60 | return 1.
61 | progress = float(step) / self.t_total
62 | ret = self.get_lr_(progress)
63 | # warning for exceeding t_total (only active with warmup_linear)
64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
65 | logger.warning(
66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
67 | .format(ret, self.__class__.__name__))
68 | self.warned_for_t_total_at_progress = progress
69 | # end warning
70 | return ret
71 |
72 | @abc.abstractmethod
73 | def get_lr_(self, progress):
74 | """
75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
76 | :return: learning rate multiplier for current update
77 | """
78 | return 1.
79 |
80 |
81 | class ConstantLR(_LRSchedule):
82 | def get_lr_(self, progress):
83 | return 1.
84 |
85 |
86 | class WarmupCosineSchedule(_LRSchedule):
87 | """
88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
91 | """
92 | warn_t_total = True
93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
94 | """
95 | :param warmup: see LRSchedule
96 | :param t_total: see LRSchedule
97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
98 | :param kw:
99 | """
100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
101 | self.cycles = cycles
102 |
103 | def get_lr_(self, progress):
104 | if progress < self.warmup:
105 | return progress / self.warmup
106 | else:
107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
109 |
110 |
111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
112 | """
113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
115 | learning rate (with hard restarts).
116 | """
117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
119 | assert(cycles >= 1.)
120 |
121 | def get_lr_(self, progress):
122 | if progress < self.warmup:
123 | return progress / self.warmup
124 | else:
125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
127 | return ret
128 |
129 |
130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
131 | """
132 | All training progress is divided into `cycles` (default=1.) parts of equal length.
133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve.
135 | """
136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
137 | assert(warmup * cycles < 1.)
138 | warmup = warmup * cycles if warmup >= 0 else warmup
139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
140 |
141 | def get_lr_(self, progress):
142 | progress = progress * self.cycles % 1.
143 | if progress < self.warmup:
144 | return progress / self.warmup
145 | else:
146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
147 | ret = 0.5 * (1. + math.cos(math.pi * progress))
148 | return ret
149 |
150 |
151 | class WarmupConstantSchedule(_LRSchedule):
152 | """
153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
154 | Keeps learning rate equal to 1. after warmup.
155 | """
156 | def get_lr_(self, progress):
157 | if progress < self.warmup:
158 | return progress / self.warmup
159 | return 1.
160 |
161 |
162 | class WarmupLinearSchedule(_LRSchedule):
163 | """
164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
166 | """
167 | warn_t_total = True
168 | def get_lr_(self, progress):
169 | if progress < self.warmup:
170 | return progress / self.warmup
171 | return max((progress - 1.) / (self.warmup - 1.), 0.)
172 |
173 |
174 | SCHEDULES = {
175 | None: ConstantLR,
176 | "none": ConstantLR,
177 | "warmup_cosine": WarmupCosineSchedule,
178 | "warmup_constant": WarmupConstantSchedule,
179 | "warmup_linear": WarmupLinearSchedule
180 | }
181 |
182 |
183 | class BertAdam(Optimizer):
184 | """Implements BERT version of Adam algorithm with weight decay fix.
185 | Params:
186 | lr: learning rate
187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
188 | t_total: total number of training steps for the learning
189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
190 | schedule: schedule to use for the warmup (see above).
191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
192 | If `None` or `'none'`, learning rate is always kept constant.
193 | Default : `'warmup_linear'`
194 | b1: Adam's b1 (first-moment decay rate). Default: 0.9
195 | b2: Adam's b2 (second-moment decay rate). Default: 0.999
196 | e: Adam's epsilon. Default: 1e-6
197 | weight_decay: Weight decay. Default: 0.01
198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
199 | """
200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
202 | if lr is not required and lr < 0.0:
203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
205 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
206 | if not 0.0 <= b1 < 1.0:
207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
208 | if not 0.0 <= b2 < 1.0:
209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
210 | if not e >= 0.0:
211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
212 | # initialize schedule object
213 | if not isinstance(schedule, _LRSchedule):
214 | schedule_type = SCHEDULES[schedule]
215 | schedule = schedule_type(warmup=warmup, t_total=t_total)
216 | else:
217 | if warmup != -1 or t_total != -1:
218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
219 | "Please specify custom warmup and t_total in _LRSchedule object.")
220 | defaults = dict(lr=lr, schedule=schedule,
221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
222 | max_grad_norm=max_grad_norm)
223 | super(BertAdam, self).__init__(params, defaults)
224 |
225 | def get_lr(self):
226 | lr = []
227 | for group in self.param_groups:
228 | for p in group['params']:
229 | state = self.state[p]
230 | if len(state) == 0:
231 | return [0]
232 | lr_scheduled = group['lr']
233 | lr_scheduled *= group['schedule'].get_lr(state['step'])
234 | lr.append(lr_scheduled)
235 | return lr
236 |
237 | def step(self, closure=None):
238 | """Performs a single optimization step.
239 |
240 | Arguments:
241 | closure (callable, optional): A closure that reevaluates the model
242 | and returns the loss.
243 | """
244 | loss = None
245 | if closure is not None:
246 | loss = closure()
247 |
248 | for group in self.param_groups:
249 | for p in group['params']:
250 | if p.grad is None:
251 | continue
252 | grad = p.grad.data
253 | if grad.is_sparse:
254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
255 |
256 | state = self.state[p]
257 |
258 | # State initialization
259 | if len(state) == 0:
260 | state['step'] = 0
261 | # Exponential moving average of gradient values
262 | state['next_m'] = torch.zeros_like(p.data)
263 | # Exponential moving average of squared gradient values
264 | state['next_v'] = torch.zeros_like(p.data)
265 |
266 | next_m, next_v = state['next_m'], state['next_v']
267 | beta1, beta2 = group['b1'], group['b2']
268 |
269 | # Add grad clipping
270 | if group['max_grad_norm'] > 0:
271 | clip_grad_norm_(p, group['max_grad_norm'])
272 |
273 | # Decay the first and second moment running average coefficient
274 | # In-place operations to update the averages at the same time
275 | next_m.mul_(beta1).add_(1 - beta1, grad)
276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
277 | update = next_m / (next_v.sqrt() + group['e'])
278 |
279 | # Just adding the square of the weights to the loss function is *not*
280 | # the correct way of using L2 regularization/weight decay with Adam,
281 | # since that will interact with the m and v parameters in strange ways.
282 | #
283 | # Instead we want to decay the weights in a manner that doesn't interact
284 | # with the m/v parameters. This is equivalent to adding the square
285 | # of the weights to the loss with plain (non-momentum) SGD.
286 | if group['weight_decay'] > 0.0:
287 | update += group['weight_decay'] * p.data
288 |
289 | lr_scheduled = group['lr']
290 | lr_scheduled *= group['schedule'].get_lr(state['step'])
291 |
292 | update_with_lr = lr_scheduled * update
293 | p.data.add_(-update_with_lr)
294 |
295 | state['step'] += 1
296 |
297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
298 | # No bias correction
299 | # bias_correction1 = 1 - beta1 ** state['step']
300 | # bias_correction2 = 1 - beta2 ** state['step']
301 |
302 | return loss
303 |
--------------------------------------------------------------------------------
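A hedged sketch of how BertAdam is typically configured for fine-tuning: parameters grouped so bias/LayerNorm weights are excluded from weight decay, a warmup fraction, and the 'warmup_linear' schedule. The `no_decay` name patterns, learning rate, and step count are illustrative common practice, not values taken from this repository's training script.

```python
from pytorch_pretrained.optimization import BertAdam

def build_optimizer(model, lr=5e-5, total_steps=1000):
    # Exclude bias and LayerNorm parameters from weight decay (common BERT fine-tuning setup).
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_params = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    # warmup=0.05: linear warmup over the first 5% of total_steps, then linear decay to 0.
    return BertAdam(grouped_params, lr=lr, warmup=0.05,
                    t_total=total_steps, schedule='warmup_linear')
```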
/pytorch_pretrained/optimization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for OpenAI GPT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class OpenAIAdam(Optimizer):
30 | """Implements the OpenAI version of the Adam algorithm with weight decay fix.
31 | """
32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
34 | vector_l2=False, max_grad_norm=-1, **kwargs):
35 | if lr is not required and lr < 0.0:
36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
38 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
39 | if not 0.0 <= b1 < 1.0:
40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
41 | if not 0.0 <= b2 < 1.0:
42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
43 | if not e >= 0.0:
44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
45 | # initialize schedule object
46 | if not isinstance(schedule, _LRSchedule):
47 | schedule_type = SCHEDULES[schedule]
48 | schedule = schedule_type(warmup=warmup, t_total=t_total)
49 | else:
50 | if warmup != -1 or t_total != -1:
51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
52 | "Please specify custom warmup and t_total in _LRSchedule object.")
53 | defaults = dict(lr=lr, schedule=schedule,
54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
55 | max_grad_norm=max_grad_norm)
56 | super(OpenAIAdam, self).__init__(params, defaults)
57 |
58 | def get_lr(self):
59 | lr = []
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | state = self.state[p]
63 | if len(state) == 0:
64 | return [0]
65 | lr_scheduled = group['lr']
66 | lr_scheduled *= group['schedule'].get_lr(state['step'])
67 | lr.append(lr_scheduled)
68 | return lr
69 |
70 | def step(self, closure=None):
71 | """Performs a single optimization step.
72 |
73 | Arguments:
74 | closure (callable, optional): A closure that reevaluates the model
75 | and returns the loss.
76 | """
77 | loss = None
78 | if closure is not None:
79 | loss = closure()
80 |
81 | for group in self.param_groups:
82 | for p in group['params']:
83 | if p.grad is None:
84 | continue
85 | grad = p.grad.data
86 | if grad.is_sparse:
87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
88 |
89 | state = self.state[p]
90 |
91 | # State initialization
92 | if len(state) == 0:
93 | state['step'] = 0
94 | # Exponential moving average of gradient values
95 | state['exp_avg'] = torch.zeros_like(p.data)
96 | # Exponential moving average of squared gradient values
97 | state['exp_avg_sq'] = torch.zeros_like(p.data)
98 |
99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
100 | beta1, beta2 = group['b1'], group['b2']
101 |
102 | state['step'] += 1
103 |
104 | # Add grad clipping
105 | if group['max_grad_norm'] > 0:
106 | clip_grad_norm_(p, group['max_grad_norm'])
107 |
108 | # Decay the first and second moment running average coefficient
109 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
111 | denom = exp_avg_sq.sqrt().add_(group['e'])
112 |
113 | bias_correction1 = 1 - beta1 ** state['step']
114 | bias_correction2 = 1 - beta2 ** state['step']
115 |
116 | lr_scheduled = group['lr']
117 | lr_scheduled *= group['schedule'].get_lr(state['step'])
118 |
119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
120 |
121 | p.data.addcdiv_(-step_size, exp_avg, denom)
122 |
123 | # Add weight decay at the end (fixed version)
124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
126 |
127 | return loss
128 |
--------------------------------------------------------------------------------
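Both optimizers above only apply a learning-rate multiplier produced by the `_LRSchedule` classes defined in optimization.py. A quick sketch of inspecting those multipliers directly; the warmup fraction and step count are arbitrary illustration values.

```python
from pytorch_pretrained.optimization import WarmupLinearSchedule, WarmupCosineSchedule

# warmup=0.1, t_total=100: multiplier ramps 0 -> 1 over the first 10 steps,
# then decays back toward 0 at step 100 (linearly or along a cosine curve).
linear = WarmupLinearSchedule(warmup=0.1, t_total=100)
cosine = WarmupCosineSchedule(warmup=0.1, t_total=100)

for step in (0, 5, 10, 50, 100):
    print(step, round(linear.get_lr(step), 3), round(cosine.get_lr(step), 3))
```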
/pytorch_pretrained/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .file_utils import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 | 'bert-base-uncased': 512,
40 | 'bert-large-uncased': 512,
41 | 'bert-base-cased': 512,
42 | 'bert-large-cased': 512,
43 | 'bert-base-multilingual-uncased': 512,
44 | 'bert-base-multilingual-cased': 512,
45 | 'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r", encoding="utf-8") as reader:
55 | while True:
56 | token = reader.readline()
57 | if not token:
58 | break
59 | token = token.strip()
60 | vocab[token] = index
61 | index += 1
62 | return vocab
63 |
64 |
65 | def whitespace_tokenize(text):
66 | """Runs basic whitespace cleaning and splitting on a piece of text."""
67 | text = text.strip()
68 | if not text:
69 | return []
70 | tokens = text.split()
71 | return tokens
72 |
73 |
74 | class BertTokenizer(object):
75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
76 |
77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
79 | """Constructs a BertTokenizer.
80 |
81 | Args:
82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
83 | do_lower_case: Whether to lower case the input
84 | Only has an effect when do_wordpiece_only=False
85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
86 | max_len: An artificial maximum length to truncate tokenized sequences to;
87 | Effective maximum length is always the minimum of this
88 | value (if specified) and the underlying BERT model's
89 | sequence length.
90 | never_split: List of tokens which will never be split during tokenization.
91 | Only has an effect when do_wordpiece_only=False
92 | """
93 | if not os.path.isfile(vocab_file):
94 | raise ValueError(
95 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
97 | self.vocab = load_vocab(vocab_file)
98 | self.ids_to_tokens = collections.OrderedDict(
99 | [(ids, tok) for tok, ids in self.vocab.items()])
100 | self.do_basic_tokenize = do_basic_tokenize
101 | if do_basic_tokenize:
102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
103 | never_split=never_split)
104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
105 | self.max_len = max_len if max_len is not None else int(1e12)
106 |
107 | def tokenize(self, text):
108 | split_tokens = []
109 | if self.do_basic_tokenize:
110 | for token in self.basic_tokenizer.tokenize(text):
111 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
112 | split_tokens.append(sub_token)
113 | else:
114 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
115 | return split_tokens
116 |
117 | def convert_tokens_to_ids(self, tokens):
118 | """Converts a sequence of tokens into ids using the vocab."""
119 | ids = []
120 | for token in tokens:
121 | ids.append(self.vocab[token])
122 | if len(ids) > self.max_len:
123 | logger.warning(
124 | "Token indices sequence length is longer than the specified maximum "
125 | " sequence length for this BERT model ({} > {}). Running this"
126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
127 | )
128 | return ids
129 |
130 | def convert_ids_to_tokens(self, ids):
131 | """Converts a sequence of ids into wordpiece tokens using the vocab."""
132 | tokens = []
133 | for i in ids:
134 | tokens.append(self.ids_to_tokens[i])
135 | return tokens
136 |
137 | def save_vocabulary(self, vocab_path):
138 | """Save the tokenizer vocabulary to a directory or file."""
139 | index = 0
140 | if os.path.isdir(vocab_path):
141 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
142 | with open(vocab_file, "w", encoding="utf-8") as writer:
143 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
144 | if index != token_index:
145 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
146 | " Please check that the vocabulary is not corrupted!".format(vocab_file))
147 | index = token_index
148 | writer.write(token + u'\n')
149 | index += 1
150 | return vocab_file
151 |
152 | @classmethod
153 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
154 | """
155 | Instantiate a PreTrainedBertModel from a pre-trained model file.
156 | Download and cache the pre-trained model file if needed.
157 | """
158 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
159 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
160 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
161 | logger.warning("The pre-trained model you are loading is a cased model but you have not set "
162 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
163 | "you may want to check this behavior.")
164 | kwargs['do_lower_case'] = False
165 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
166 | logger.warning("The pre-trained model you are loading is an uncased model but you have set "
167 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
168 | "but you may want to check this behavior.")
169 | kwargs['do_lower_case'] = True
170 | else:
171 | vocab_file = pretrained_model_name_or_path
172 | if os.path.isdir(vocab_file):
173 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
174 | # redirect to the cache, if necessary
175 | try:
176 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
177 | except EnvironmentError:
178 | logger.error(
179 | "Model name '{}' was not found in model name list ({}). "
180 | "We assumed '{}' was a path or url but couldn't find any file "
181 | "associated to this path or url.".format(
182 | pretrained_model_name_or_path,
183 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
184 | vocab_file))
185 | return None
186 | if resolved_vocab_file == vocab_file:
187 | logger.info("loading vocabulary file {}".format(vocab_file))
188 | else:
189 | logger.info("loading vocabulary file {} from cache at {}".format(
190 | vocab_file, resolved_vocab_file))
191 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
192 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
193 | # than the number of positional embeddings
194 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
195 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
196 | # Instantiate tokenizer.
197 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
198 | return tokenizer
199 |
200 |
201 | class BasicTokenizer(object):
202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
203 |
204 | def __init__(self,
205 | do_lower_case=True,
206 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
207 | """Constructs a BasicTokenizer.
208 |
209 | Args:
210 | do_lower_case: Whether to lower case the input.
211 | """
212 | self.do_lower_case = do_lower_case
213 | self.never_split = never_split
214 |
215 | def tokenize(self, text):
216 | """Tokenizes a piece of text."""
217 | text = self._clean_text(text)
218 | # This was added on November 1st, 2018 for the multilingual and Chinese
219 | # models. This is also applied to the English models now, but it doesn't
220 | # matter since the English models were not trained on any Chinese data
221 | # and generally don't have any Chinese data in them (there are Chinese
222 | # characters in the vocabulary because Wikipedia does have some Chinese
223 | # words in the English Wikipedia.).
224 | text = self._tokenize_chinese_chars(text)
225 | orig_tokens = whitespace_tokenize(text)
226 | split_tokens = []
227 | for token in orig_tokens:
228 | if self.do_lower_case and token not in self.never_split:
229 | token = token.lower()
230 | token = self._run_strip_accents(token)
231 | split_tokens.extend(self._run_split_on_punc(token))
232 |
233 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
234 | return output_tokens
235 |
236 | def _run_strip_accents(self, text):
237 | """Strips accents from a piece of text."""
238 | text = unicodedata.normalize("NFD", text)
239 | output = []
240 | for char in text:
241 | cat = unicodedata.category(char)
242 | if cat == "Mn":
243 | continue
244 | output.append(char)
245 | return "".join(output)
246 |
247 | def _run_split_on_punc(self, text):
248 | """Splits punctuation on a piece of text."""
249 | if text in self.never_split:
250 | return [text]
251 | chars = list(text)
252 | i = 0
253 | start_new_word = True
254 | output = []
255 | while i < len(chars):
256 | char = chars[i]
257 | if _is_punctuation(char):
258 | output.append([char])
259 | start_new_word = True
260 | else:
261 | if start_new_word:
262 | output.append([])
263 | start_new_word = False
264 | output[-1].append(char)
265 | i += 1
266 |
267 | return ["".join(x) for x in output]
268 |
269 | def _tokenize_chinese_chars(self, text):
270 | """Adds whitespace around any CJK character."""
271 | output = []
272 | for char in text:
273 | cp = ord(char)
274 | if self._is_chinese_char(cp):
275 | output.append(" ")
276 | output.append(char)
277 | output.append(" ")
278 | else:
279 | output.append(char)
280 | return "".join(output)
281 |
282 | def _is_chinese_char(self, cp):
283 | """Checks whether CP is the codepoint of a CJK character."""
284 | # This defines a "chinese character" as anything in the CJK Unicode block:
285 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
286 | #
287 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
288 | # despite its name. The modern Korean Hangul alphabet is a different block,
289 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
290 | # space-separated words, so they are not treated specially and handled
291 | # like all of the other languages.
292 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
293 | (cp >= 0x3400 and cp <= 0x4DBF) or #
294 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
295 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
296 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
297 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
298 | (cp >= 0xF900 and cp <= 0xFAFF) or #
299 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
300 | return True
301 |
302 | return False
303 |
304 | def _clean_text(self, text):
305 | """Performs invalid character removal and whitespace cleanup on text."""
306 | output = []
307 | for char in text:
308 | cp = ord(char)
309 | if cp == 0 or cp == 0xfffd or _is_control(char):
310 | continue
311 | if _is_whitespace(char):
312 | output.append(" ")
313 | else:
314 | output.append(char)
315 | return "".join(output)
316 |
317 |
318 | class WordpieceTokenizer(object):
319 | """Runs WordPiece tokenization."""
320 |
321 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
322 | self.vocab = vocab
323 | self.unk_token = unk_token
324 | self.max_input_chars_per_word = max_input_chars_per_word
325 |
326 | def tokenize(self, text):
327 | """Tokenizes a piece of text into its word pieces.
328 |
329 | This uses a greedy longest-match-first algorithm to perform tokenization
330 | using the given vocabulary.
331 |
332 | For example:
333 | input = "unaffable"
334 | output = ["un", "##aff", "##able"]
335 |
336 | Args:
337 | text: A single token or whitespace separated tokens. This should have
338 | already been passed through `BasicTokenizer`.
339 |
340 | Returns:
341 | A list of wordpiece tokens.
342 | """
343 |
344 | output_tokens = []
345 | for token in whitespace_tokenize(text):
346 | chars = list(token)
347 | if len(chars) > self.max_input_chars_per_word:
348 | output_tokens.append(self.unk_token)
349 | continue
350 |
351 | is_bad = False
352 | start = 0
353 | sub_tokens = []
354 | while start < len(chars):
355 | end = len(chars)
356 | cur_substr = None
357 | while start < end:
358 | substr = "".join(chars[start:end])
359 | if start > 0:
360 | substr = "##" + substr
361 | if substr in self.vocab:
362 | cur_substr = substr
363 | break
364 | end -= 1
365 | if cur_substr is None:
366 | is_bad = True
367 | break
368 | sub_tokens.append(cur_substr)
369 | start = end
370 |
371 | if is_bad:
372 | output_tokens.append(self.unk_token)
373 | else:
374 | output_tokens.extend(sub_tokens)
375 | return output_tokens
376 |
377 |
378 | def _is_whitespace(char):
379 | """Checks whether `chars` is a whitespace character."""
380 | # \t, \n, and \r are technically control characters but we treat them
381 | # as whitespace since they are generally considered as such.
382 | if char == " " or char == "\t" or char == "\n" or char == "\r":
383 | return True
384 | cat = unicodedata.category(char)
385 | if cat == "Zs":
386 | return True
387 | return False
388 |
389 |
390 | def _is_control(char):
391 | """Checks whether `chars` is a control character."""
392 | # These are technically control characters but we count them as whitespace
393 | # characters.
394 | if char == "\t" or char == "\n" or char == "\r":
395 | return False
396 | cat = unicodedata.category(char)
397 | if cat.startswith("C"):
398 | return True
399 | return False
400 |
401 |
402 | def _is_punctuation(char):
403 | """Checks whether `chars` is a punctuation character."""
404 | cp = ord(char)
405 | # We treat all non-letter/number ASCII as punctuation.
406 | # Characters such as "^", "$", and "`" are not in the Unicode
407 | # Punctuation class but we treat them as punctuation anyways, for
408 | # consistency.
409 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
410 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
411 | return True
412 | cat = unicodedata.category(char)
413 | if cat.startswith("P"):
414 | return True
415 | return False
416 |
--------------------------------------------------------------------------------
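
The BertTokenizer above simply chains BasicTokenizer (cleanup, lower casing, CJK and punctuation splitting) with WordpieceTokenizer (greedy longest-match-first sub-word lookup). Below is a minimal sketch of that pipeline on a toy vocabulary, assuming the repository root is on PYTHONPATH; the vocabulary entries and sample strings are invented purely for illustration.

    import collections
    from pytorch_pretrained.tokenization import BasicTokenizer, WordpieceTokenizer

    # Toy vocabulary; a real BERT vocabulary is read from vocab.txt by load_vocab().
    vocab = collections.OrderedDict(
        (tok, i) for i, tok in enumerate(
            ["[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "今", "天", "不", "错"]))

    basic = BasicTokenizer(do_lower_case=True)   # cleanup, lower casing, CJK/punctuation splitting
    wordpiece = WordpieceTokenizer(vocab=vocab)  # greedy longest-match-first sub-word lookup

    def tokenize(text):
        # Same two-stage loop as BertTokenizer.tokenize above.
        pieces = []
        for token in basic.tokenize(text):
            pieces.extend(wordpiece.tokenize(token))
        return pieces

    print(tokenize("unaffable"))  # ['un', '##aff', '##able']
    print(tokenize("今天不错"))    # ['今', '天', '不', '错'] -- each CJK character becomes its own token
    print(tokenize("xyz"))        # ['[UNK]'] -- no matching word pieces

Because _tokenize_chinese_chars() surrounds every CJK character with spaces, Chinese text is effectively handled character by character before WordPiece runs, so no separate word segmentation step is needed.
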
/pytorch_pretrained/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import sys
20 | import json
21 | import logging
22 | import os
23 | import regex as re
24 | from io import open
25 |
26 | try:
27 | from functools import lru_cache
28 | except ImportError:
29 | # Just a dummy decorator to get the checks to run on python2
30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
31 | def lru_cache():
32 | return lambda func: func
33 |
34 | from .file_utils import cached_path
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
40 | }
41 | PRETRAINED_MERGES_ARCHIVE_MAP = {
42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
43 | }
44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
45 | 'gpt2': 1024,
46 | }
47 | VOCAB_NAME = 'vocab.json'
48 | MERGES_NAME = 'merges.txt'
49 | SPECIAL_TOKENS_NAME = 'special_tokens.txt'
50 |
51 | @lru_cache()
52 | def bytes_to_unicode():
53 | """
54 | Returns a mapping from utf-8 bytes to unicode strings.
55 | The reversible bpe codes work on unicode strings.
56 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
57 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
58 | This is a significant percentage of your normal, say, 32K bpe vocab.
59 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
60 | This also avoids mapping to whitespace/control characters that the bpe code barfs on.
61 | """
62 | _chr = unichr if sys.version_info[0] == 2 else chr
63 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
64 | cs = bs[:]
65 | n = 0
66 | for b in range(2**8):
67 | if b not in bs:
68 | bs.append(b)
69 | cs.append(2**8+n)
70 | n += 1
71 | cs = [_chr(n) for n in cs]
72 | return dict(zip(bs, cs))
73 |
74 | def get_pairs(word):
75 | """Return set of symbol pairs in a word.
76 |
77 | Word is represented as tuple of symbols (symbols being variable-length strings).
78 | """
79 | pairs = set()
80 | prev_char = word[0]
81 | for char in word[1:]:
82 | pairs.add((prev_char, char))
83 | prev_char = char
84 | return pairs
85 |
86 | class GPT2Tokenizer(object):
87 | """
88 | GPT-2 BPE tokenizer. Peculiarities:
89 | - Byte-level BPE
90 | """
91 | @classmethod
92 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
93 | """
94 | Instantiate a GPT2Tokenizer from a pre-trained model file.
95 | Download and cache the pre-trained model file if needed.
96 | """
97 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
98 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
99 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
100 | special_tokens_file = None
101 | else:
102 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
103 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
104 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
105 | if not os.path.exists(special_tokens_file):
106 | special_tokens_file = None
107 | else:
108 | logger.info("loading special tokens file {}".format(special_tokens_file))
109 | # redirect to the cache, if necessary
110 | try:
111 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
112 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
113 | except EnvironmentError:
114 | logger.error(
115 | "Model name '{}' was not found in model name list ({}). "
116 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
117 | "at this path or url.".format(
118 | pretrained_model_name_or_path,
119 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
120 | pretrained_model_name_or_path,
121 | vocab_file, merges_file))
122 | return None
123 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
124 | logger.info("loading vocabulary file {}".format(vocab_file))
125 | logger.info("loading merges file {}".format(merges_file))
126 | else:
127 | logger.info("loading vocabulary file {} from cache at {}".format(
128 | vocab_file, resolved_vocab_file))
129 | logger.info("loading merges file {} from cache at {}".format(
130 | merges_file, resolved_merges_file))
131 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
132 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
133 | # than the number of positional embeddings
134 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
135 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
136 | # Instantiate tokenizer.
137 | if special_tokens_file and 'special_tokens' not in kwargs:
138 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
139 | else:
140 | special_tokens = kwargs.pop('special_tokens', [])
141 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
142 | return tokenizer
143 |
144 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
145 | self.max_len = max_len if max_len is not None else int(1e12)
146 | self.encoder = json.load(open(vocab_file))
147 | self.decoder = {v:k for k,v in self.encoder.items()}
148 | self.errors = errors # how to handle errors in decoding
149 | self.byte_encoder = bytes_to_unicode()
150 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
151 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
152 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
153 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
154 | self.cache = {}
155 |
156 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
157 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
158 |
159 | self.special_tokens = {}
160 | self.special_tokens_decoder = {}
161 | self.set_special_tokens(special_tokens)
162 |
163 | def __len__(self):
164 | return len(self.encoder) + len(self.special_tokens)
165 |
166 | def set_special_tokens(self, special_tokens):
167 | """ Add a list of additional tokens to the encoder.
168 | The additional tokens are indexed starting from the last index of the
169 | current vocabulary in the order of the `special_tokens` list.
170 | """
171 | if not special_tokens:
172 | self.special_tokens = {}
173 | self.special_tokens_decoder = {}
174 | return
175 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
176 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
177 | logger.info("Special tokens {}".format(self.special_tokens))
178 |
179 | def bpe(self, token):
180 | if token in self.cache:
181 | return self.cache[token]
182 | word = tuple(token)
183 | pairs = get_pairs(word)
184 |
185 | if not pairs:
186 | return token
187 |
188 | while True:
189 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
190 | if bigram not in self.bpe_ranks:
191 | break
192 | first, second = bigram
193 | new_word = []
194 | i = 0
195 | while i < len(word):
196 | try:
197 | j = word.index(first, i)
198 | new_word.extend(word[i:j])
199 | i = j
200 | except ValueError:
201 | new_word.extend(word[i:])
202 | break
203 |
204 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
205 | new_word.append(first+second)
206 | i += 2
207 | else:
208 | new_word.append(word[i])
209 | i += 1
210 | new_word = tuple(new_word)
211 | word = new_word
212 | if len(word) == 1:
213 | break
214 | else:
215 | pairs = get_pairs(word)
216 | word = ' '.join(word)
217 | self.cache[token] = word
218 | return word
219 |
220 | def tokenize(self, text):
221 | """ Tokenize a string. """
222 | bpe_tokens = []
223 | for token in re.findall(self.pat, text):
224 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))  # map raw utf-8 bytes to their unicode aliases
225 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
226 | return bpe_tokens
227 |
228 | def convert_tokens_to_ids(self, tokens):
229 | """ Converts a sequence of tokens into ids using the vocab. """
230 | ids = []
231 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
232 | if tokens in self.special_tokens:
233 | return self.special_tokens[tokens]
234 | else:
235 | return self.encoder.get(tokens, 0)
236 | for token in tokens:
237 | if token in self.special_tokens:
238 | ids.append(self.special_tokens[token])
239 | else:
240 | ids.append(self.encoder.get(token, 0))
241 | if len(ids) > self.max_len:
242 | logger.warning(
243 | "Token indices sequence length is longer than the specified maximum "
244 | " sequence length for this OpenAI GPT model ({} > {}). Running this"
245 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
246 | )
247 | return ids
248 |
249 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
250 | """Converts a sequence of ids in BPE tokens using the vocab."""
251 | tokens = []
252 | for i in ids:
253 | if i in self.special_tokens_decoder:
254 | if not skip_special_tokens:
255 | tokens.append(self.special_tokens_decoder[i])
256 | else:
257 | tokens.append(self.decoder[i])
258 | return tokens
259 |
260 | def encode(self, text):
261 | return self.convert_tokens_to_ids(self.tokenize(text))
262 |
263 | def decode(self, tokens):
264 | text = ''.join([self.decoder[token] for token in tokens])
265 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
266 | return text
267 |
268 | def save_vocabulary(self, vocab_path):
269 | """Save the tokenizer vocabulary and merge files to a directory."""
270 | if not os.path.isdir(vocab_path):
271 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
272 | return
273 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
274 | merge_file = os.path.join(vocab_path, MERGES_NAME)
275 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
276 |
277 | with open(vocab_file, 'w', encoding='utf-8') as f:
278 | f.write(json.dumps(self.encoder, ensure_ascii=False))
279 |
280 | index = 0
281 | with open(merge_file, "w", encoding="utf-8") as writer:
282 | writer.write(u'#version: 0.2\n')
283 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
284 | if index != token_index:
285 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
286 | " Please check that the tokenizer is not corrupted!".format(merge_file))
287 | index = token_index
288 | writer.write(' '.join(bpe_tokens) + u'\n')
289 | index += 1
290 |
291 | index = len(self.encoder)
292 | with open(special_tokens_file, 'w', encoding='utf-8') as writer:
293 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
294 | if index != token_index:
295 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
296 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
297 | index = token_index
298 | writer.write(token + u'\n')
299 | index += 1
300 |
301 | return vocab_file, merge_file, special_tokens_file
302 |
--------------------------------------------------------------------------------
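
Two helpers above carry most of the byte-level BPE machinery: bytes_to_unicode() gives every possible byte a printable unicode stand-in (so the vocabulary never has to contain raw control or whitespace bytes), and get_pairs() enumerates the adjacent symbol pairs that GPT2Tokenizer.bpe() collapses one at a time, lowest merge rank first. A small sketch, assuming the repository root is on PYTHONPATH:

    from pytorch_pretrained.tokenization_gpt2 import bytes_to_unicode, get_pairs

    byte_encoder = bytes_to_unicode()
    print(len(byte_encoder))       # 256 -- every byte value gets a unicode alias
    print(byte_encoder[ord("A")])  # 'A' -- printable ASCII maps to itself
    print(byte_encoder[ord(" ")])  # a substitute character, so spaces never appear inside BPE symbols

    # bpe() repeatedly merges the lowest-ranked of these adjacent pairs.
    print(get_pairs(("h", "e", "l", "l", "o")))
    # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}
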
/pytorch_pretrained/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import re
23 | import sys
24 | from io import open
25 |
26 | from tqdm import tqdm
27 |
28 | from .file_utils import cached_path
29 | from .tokenization import BasicTokenizer
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
35 | }
36 | PRETRAINED_MERGES_ARCHIVE_MAP = {
37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'openai-gpt': 512,
41 | }
42 | VOCAB_NAME = 'vocab.json'
43 | MERGES_NAME = 'merges.txt'
44 | SPECIAL_TOKENS_NAME = 'special_tokens.txt'
45 |
46 | def get_pairs(word):
47 | """
48 | Return set of symbol pairs in a word.
49 | word is represented as tuple of symbols (symbols being variable-length strings)
50 | """
51 | pairs = set()
52 | prev_char = word[0]
53 | for char in word[1:]:
54 | pairs.add((prev_char, char))
55 | prev_char = char
56 | return pairs
57 |
58 | def text_standardize(text):
59 | """
60 | fixes some issues the spacy tokenizer had on books corpus
61 | also does some whitespace standardization
62 | """
63 | text = text.replace('—', '-')
64 | text = text.replace('–', '-')
65 | text = text.replace('―', '-')
66 | text = text.replace('…', '...')
67 | text = text.replace('´', "'")
68 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
69 | text = re.sub(r'\s*\n\s*', ' \n ', text)
70 | text = re.sub(r'[^\S\n]+', ' ', text)
71 | return text.strip()
72 |
73 | class OpenAIGPTTokenizer(object):
74 | """
75 | BPE tokenizer. Peculiarities:
76 | - lower case all inputs
77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
78 | - argument special_tokens and function set_special_tokens:
79 | can be used to add additional symbols (ex: "__classify__") to a vocabulary.
80 | """
81 | @classmethod
82 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
83 | """
84 | Instantiate an OpenAIGPTTokenizer from a pre-trained model file.
85 | Download and cache the pre-trained model file if needed.
86 | """
87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
90 | special_tokens_file = None
91 | else:
92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
95 | if not os.path.exists(special_tokens_file):
96 | special_tokens_file = None
97 | else:
98 | logger.info("loading special tokens file {}".format(special_tokens_file))
99 | # redirect to the cache, if necessary
100 | try:
101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
103 | except EnvironmentError:
104 | logger.error(
105 | "Model name '{}' was not found in model name list ({}). "
106 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
107 | "at this path or url.".format(
108 | pretrained_model_name_or_path,
109 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
110 | pretrained_model_name_or_path,
111 | vocab_file, merges_file))
112 | return None
113 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
114 | logger.info("loading vocabulary file {}".format(vocab_file))
115 | logger.info("loading merges file {}".format(merges_file))
116 | else:
117 | logger.info("loading vocabulary file {} from cache at {}".format(
118 | vocab_file, resolved_vocab_file))
119 | logger.info("loading merges file {} from cache at {}".format(
120 | merges_file, resolved_merges_file))
121 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
122 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
123 | # than the number of positional embeddings
124 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
125 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
126 | # Instantiate tokenizer.
127 | if special_tokens_file and 'special_tokens' not in kwargs:
128 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
129 | else:
130 | special_tokens = kwargs.pop('special_tokens', [])
131 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
132 | return tokenizer
133 |
134 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
135 | try:
136 | import ftfy
137 | import spacy
138 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
139 | self.fix_text = ftfy.fix_text
140 | except ImportError:
141 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
142 | self.nlp = BasicTokenizer(do_lower_case=True,
143 | never_split=special_tokens if special_tokens is not None else [])
144 | self.fix_text = None
145 |
146 | self.max_len = max_len if max_len is not None else int(1e12)
147 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
148 | self.decoder = {v:k for k,v in self.encoder.items()}
149 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
150 | merges = [tuple(merge.split()) for merge in merges]
151 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
152 | self.cache = {}
153 | self.special_tokens = {}
154 | self.special_tokens_decoder = {}
155 | self.set_special_tokens(special_tokens)
156 |
157 | def __len__(self):
158 | return len(self.encoder) + len(self.special_tokens)
159 |
160 | def set_special_tokens(self, special_tokens):
161 | """ Add a list of additional tokens to the encoder.
162 | The additional tokens are indexed starting from the last index of the
163 | current vocabulary in the order of the `special_tokens` list.
164 | """
165 | if not special_tokens:
166 | self.special_tokens = {}
167 | self.special_tokens_decoder = {}
168 | return
169 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
170 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
171 | if self.fix_text is None:
172 | # Using BERT's BasicTokenizer: we can update the tokenizer
173 | self.nlp.never_split = special_tokens
174 | logger.info("Special tokens {}".format(self.special_tokens))
175 |
176 | def bpe(self, token):
177 | word = tuple(token[:-1]) + (token[-1] + '</w>',)
178 | if token in self.cache:
179 | return self.cache[token]
180 | pairs = get_pairs(word)
181 |
182 | if not pairs:
183 | return token+'</w>'
184 |
185 | while True:
186 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
187 | if bigram not in self.bpe_ranks:
188 | break
189 | first, second = bigram
190 | new_word = []
191 | i = 0
192 | while i < len(word):
193 | try:
194 | j = word.index(first, i)
195 | new_word.extend(word[i:j])
196 | i = j
197 | except ValueError:
198 | new_word.extend(word[i:])
199 | break
200 |
201 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
202 | new_word.append(first+second)
203 | i += 2
204 | else:
205 | new_word.append(word[i])
206 | i += 1
207 | new_word = tuple(new_word)
208 | word = new_word
209 | if len(word) == 1:
210 | break
211 | else:
212 | pairs = get_pairs(word)
213 | word = ' '.join(word)
214 | if word == '\n  </w>':
215 | word = '\n</w>'
216 | self.cache[token] = word
217 | return word
218 |
219 | def tokenize(self, text):
220 | """ Tokenize a string. """
221 | split_tokens = []
222 | if self.fix_text is None:
223 | # Using BERT's BasicTokenizer
224 | text = self.nlp.tokenize(text)
225 | for token in text:
226 | split_tokens.extend([t for t in self.bpe(token).split(' ')])
227 | else:
228 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
229 | text = self.nlp(text_standardize(self.fix_text(text)))
230 | for token in text:
231 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
232 | return split_tokens
233 |
234 | def convert_tokens_to_ids(self, tokens):
235 | """ Converts a sequence of tokens into ids using the vocab. """
236 | ids = []
237 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
238 | if tokens in self.special_tokens:
239 | return self.special_tokens[tokens]
240 | else:
241 | return self.encoder.get(tokens, 0)
242 | for token in tokens:
243 | if token in self.special_tokens:
244 | ids.append(self.special_tokens[token])
245 | else:
246 | ids.append(self.encoder.get(token, 0))
247 | if len(ids) > self.max_len:
248 | logger.warning(
249 | "Token indices sequence length is longer than the specified maximum "
250 | " sequence length for this OpenAI GPT model ({} > {}). Running this"
251 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
252 | )
253 | return ids
254 |
255 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
256 | """Converts a sequence of ids in BPE tokens using the vocab."""
257 | tokens = []
258 | for i in ids:
259 | if i in self.special_tokens_decoder:
260 | if not skip_special_tokens:
261 | tokens.append(self.special_tokens_decoder[i])
262 | else:
263 | tokens.append(self.decoder[i])
264 | return tokens
265 |
266 | def encode(self, text):
267 | return self.convert_tokens_to_ids(self.tokenize(text))
268 |
269 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
270 | """Converts a sequence of ids in a string."""
271 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
272 | out_string = ''.join(tokens).replace('</w>', ' ').strip()
273 | if clean_up_tokenization_spaces:
274 | out_string = out_string.replace('<unk>', '')
275 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
276 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
277 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
278 | return out_string
279 |
280 | def save_vocabulary(self, vocab_path):
281 | """Save the tokenizer vocabulary and merge files to a directory."""
282 | if not os.path.isdir(vocab_path):
283 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
284 | return
285 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
286 | merge_file = os.path.join(vocab_path, MERGES_NAME)
287 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
288 |
289 | with open(vocab_file, 'w', encoding='utf-8') as f:
290 | f.write(json.dumps(self.encoder, ensure_ascii=False))
291 |
292 | index = 0
293 | with open(merge_file, "w", encoding="utf-8") as writer:
294 | writer.write(u'#version: 0.2\n')
295 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
296 | if index != token_index:
297 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
298 | " Please check that the tokenizer is not corrupted!".format(merge_file))
299 | index = token_index
300 | writer.write(' '.join(bpe_tokens) + u'\n')
301 | index += 1
302 |
303 | index = len(self.encoder)
304 | with open(special_tokens_file, 'w', encoding='utf-8') as writer:
305 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
306 | if index != token_index:
307 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
308 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
309 | index = token_index
310 | writer.write(token + u'\n')
311 | index += 1
312 |
313 | return vocab_file, merge_file, special_tokens_file
314 |
--------------------------------------------------------------------------------
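
A detail of OpenAIGPTTokenizer worth making explicit: bpe() appends a literal '</w>' marker to the final symbol of every word, and decode() turns those markers back into spaces when it reassembles text. A tiny sketch of that reassembly step, using an invented token list rather than pieces from a real vocabulary:

    # Word-final BPE pieces carry a '</w>' suffix (e.g. 'g</w>'); word-internal pieces do not.
    tokens = ["my</w>", "do", "g</w>", "is</w>", "cute</w>"]
    text = "".join(tokens).replace("</w>", " ").strip()
    print(text)  # my dog is cute

The clean_up_tokenization_spaces branch of decode() then re-attaches punctuation and contractions ('.', "n't", "'s", ...) to the preceding word.
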
/pytorch_pretrained/tokenization_transfo_xl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Tokenization classes for Transformer XL model.
17 | Adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 | from __future__ import (absolute_import, division, print_function,
20 | unicode_literals)
21 |
22 | import glob
23 | import logging
24 | import os
25 | import sys
26 | from collections import Counter, OrderedDict
27 | from io import open
28 | import unicodedata
29 |
30 | import torch
31 | import numpy as np
32 |
33 | from .file_utils import cached_path
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 |
41 | logger = logging.getLogger(__name__)
42 |
43 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
45 | }
46 | VOCAB_NAME = 'vocab.bin'
47 |
48 | PRETRAINED_CORPUS_ARCHIVE_MAP = {
49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
50 | }
51 | CORPUS_NAME = 'corpus.bin'
52 |
53 | class TransfoXLTokenizer(object):
54 | """
55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
56 | """
57 | @classmethod
58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
59 | """
60 | Instantiate a TransfoXLTokenizer from a pre-trained vocabulary file.
61 | Download and cache the vocabulary file if needed.
62 | """
63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
65 | else:
66 | if os.path.isdir(pretrained_model_name_or_path):
67 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
68 | else:
69 | vocab_file = pretrained_model_name_or_path
70 | # redirect to the cache, if necessary
71 | try:
72 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
73 | except EnvironmentError:
74 | logger.error(
75 | "Model name '{}' was not found in model name list ({}). "
76 | "We assumed '{}' was a path or url but couldn't find files {} "
77 | "at this path or url.".format(
78 | pretrained_model_name_or_path,
79 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
80 | pretrained_model_name_or_path,
81 | vocab_file))
82 | return None
83 | if resolved_vocab_file == vocab_file:
84 | logger.info("loading vocabulary file {}".format(vocab_file))
85 | else:
86 | logger.info("loading vocabulary file {} from cache at {}".format(
87 | vocab_file, resolved_vocab_file))
88 |
89 | # Instantiate tokenizer.
90 | tokenizer = cls(*inputs, **kwargs)
91 | vocab_dict = torch.load(resolved_vocab_file)
92 | for key, value in vocab_dict.items():
93 | tokenizer.__dict__[key] = value
94 | return tokenizer
95 |
96 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
97 | delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
98 | self.counter = Counter()
99 | self.special = special
100 | self.min_freq = min_freq
101 | self.max_size = max_size
102 | self.lower_case = lower_case
103 | self.delimiter = delimiter
104 | self.vocab_file = vocab_file
105 | self.never_split = never_split
106 |
107 | def count_file(self, path, verbose=False, add_eos=False):
108 | if verbose: print('counting file {} ...'.format(path))
109 | assert os.path.exists(path)
110 |
111 | sents = []
112 | with open(path, 'r', encoding='utf-8') as f:
113 | for idx, line in enumerate(f):
114 | if verbose and idx > 0 and idx % 500000 == 0:
115 | print(' line {}'.format(idx))
116 | symbols = self.tokenize(line, add_eos=add_eos)
117 | self.counter.update(symbols)
118 | sents.append(symbols)
119 |
120 | return sents
121 |
122 | def count_sents(self, sents, verbose=False):
123 | """
124 | sents : a list of sentences, each a list of tokenized symbols
125 | """
126 | if verbose: print('counting {} sents ...'.format(len(sents)))
127 | for idx, symbols in enumerate(sents):
128 | if verbose and idx > 0 and idx % 500000 == 0:
129 | print(' line {}'.format(idx))
130 | self.counter.update(symbols)
131 |
132 | def _build_from_file(self, vocab_file):
133 | self.idx2sym = []
134 | self.sym2idx = OrderedDict()
135 |
136 | with open(vocab_file, 'r', encoding='utf-8') as f:
137 | for line in f:
138 | symb = line.strip().split()[0]
139 | self.add_symbol(symb)
140 | if '<UNK>' in self.sym2idx:
141 | self.unk_idx = self.sym2idx['<UNK>']
142 | elif '<unk>' in self.sym2idx:
143 | self.unk_idx = self.sym2idx['<unk>']
144 | else:
145 | raise ValueError('No <unk> token in vocabulary')
146 |
147 | def save_vocabulary(self, vocab_path):
148 | """Save the tokenizer vocabulary to a directory or file."""
149 | index = 0
150 | if os.path.isdir(vocab_path):
151 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
152 | torch.save(self.__dict__, vocab_file)
153 | return vocab_file
154 |
155 | def build_vocab(self):
156 | if self.vocab_file:
157 | print('building vocab from {}'.format(self.vocab_file))
158 | self._build_from_file(self.vocab_file)
159 | print('final vocab size {}'.format(len(self)))
160 | else:
161 | print('building vocab with min_freq={}, max_size={}'.format(
162 | self.min_freq, self.max_size))
163 | self.idx2sym = []
164 | self.sym2idx = OrderedDict()
165 |
166 | for sym in self.special:
167 | self.add_special(sym)
168 |
169 | for sym, cnt in self.counter.most_common(self.max_size):
170 | if cnt < self.min_freq: break
171 | self.add_symbol(sym)
172 |
173 | print('final vocab size {} from {} unique tokens'.format(
174 | len(self), len(self.counter)))
175 |
176 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
177 | add_double_eos=False):
178 | if verbose: print('encoding file {} ...'.format(path))
179 | assert os.path.exists(path)
180 | encoded = []
181 | with open(path, 'r', encoding='utf-8') as f:
182 | for idx, line in enumerate(f):
183 | if verbose and idx > 0 and idx % 500000 == 0:
184 | print(' line {}'.format(idx))
185 | symbols = self.tokenize(line, add_eos=add_eos,
186 | add_double_eos=add_double_eos)
187 | encoded.append(self.convert_to_tensor(symbols))
188 |
189 | if ordered:
190 | encoded = torch.cat(encoded)
191 |
192 | return encoded
193 |
194 | def encode_sents(self, sents, ordered=False, verbose=False):
195 | if verbose: print('encoding {} sents ...'.format(len(sents)))
196 | encoded = []
197 | for idx, symbols in enumerate(sents):
198 | if verbose and idx > 0 and idx % 500000 == 0:
199 | print(' line {}'.format(idx))
200 | encoded.append(self.convert_to_tensor(symbols))
201 |
202 | if ordered:
203 | encoded = torch.cat(encoded)
204 |
205 | return encoded
206 |
207 | def add_special(self, sym):
208 | if sym not in self.sym2idx:
209 | self.idx2sym.append(sym)
210 | self.sym2idx[sym] = len(self.idx2sym) - 1
211 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
212 |
213 | def add_symbol(self, sym):
214 | if sym not in self.sym2idx:
215 | self.idx2sym.append(sym)
216 | self.sym2idx[sym] = len(self.idx2sym) - 1
217 |
218 | def get_sym(self, idx):
219 | assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
220 | return self.idx2sym[idx]
221 |
222 | def get_idx(self, sym):
223 | if sym in self.sym2idx:
224 | return self.sym2idx[sym]
225 | else:
226 | # print('encounter unk {}'.format(sym))
227 | # assert '<eos>' not in sym
228 | if hasattr(self, 'unk_idx'):
229 | return self.sym2idx.get(sym, self.unk_idx)
230 | # Backward compatibility with pre-trained models
231 | elif '<UNK>' in self.sym2idx:
232 | return self.sym2idx['<UNK>']
233 | elif '<unk>' in self.sym2idx:
234 | return self.sym2idx['<unk>']
235 | else:
236 | raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
237 |
238 | def convert_ids_to_tokens(self, indices):
239 | """Converts a sequence of indices in symbols using the vocab."""
240 | return [self.get_sym(idx) for idx in indices]
241 |
242 | def convert_tokens_to_ids(self, symbols):
243 | """Converts a sequence of symbols into ids using the vocab."""
244 | return [self.get_idx(sym) for sym in symbols]
245 |
246 | def convert_to_tensor(self, symbols):
247 | return torch.LongTensor(self.convert_tokens_to_ids(symbols))
248 |
249 | def decode(self, indices, exclude=None):
250 | """Converts a sequence of indices in a string."""
251 | if exclude is None:
252 | return ' '.join([self.get_sym(idx) for idx in indices])
253 | else:
254 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
255 |
256 | def __len__(self):
257 | return len(self.idx2sym)
258 |
259 | def tokenize(self, line, add_eos=False, add_double_eos=False):
260 | line = line.strip()
261 | # convert to lower case
262 | if self.lower_case:
263 | line = line.lower()
264 |
265 | # empty delimiter '' will evaluate False
266 | if self.delimiter == '':
267 | symbols = line
268 | else:
269 | symbols = line.split(self.delimiter)
270 |
271 | if add_double_eos: # lm1b
272 | return ['<S>'] + symbols + ['<S>']
273 | elif add_eos:
274 | return symbols + ['<eos>']
275 | else:
276 | return symbols
277 |
278 |
279 | class LMOrderedIterator(object):
280 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
281 | """
282 | data -- LongTensor -- the LongTensor is strictly ordered
283 | """
284 | self.bsz = bsz
285 | self.bptt = bptt
286 | self.ext_len = ext_len if ext_len is not None else 0
287 |
288 | self.device = device
289 |
290 | # Work out how cleanly we can divide the dataset into bsz parts.
291 | self.n_step = data.size(0) // bsz
292 |
293 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
294 | data = data.narrow(0, 0, self.n_step * bsz)
295 |
296 | # Evenly divide the data across the bsz batches.
297 | self.data = data.view(bsz, -1).t().contiguous().to(device)
298 |
299 | # Number of mini-batches
300 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
301 |
302 | def get_batch(self, i, bptt=None):
303 | if bptt is None: bptt = self.bptt
304 | seq_len = min(bptt, self.data.size(0) - 1 - i)
305 |
306 | end_idx = i + seq_len
307 | beg_idx = max(0, i - self.ext_len)
308 |
309 | data = self.data[beg_idx:end_idx]
310 | target = self.data[i+1:i+1+seq_len]
311 |
312 | data_out = data.transpose(0, 1).contiguous().to(self.device)
313 | target_out = target.transpose(0, 1).contiguous().to(self.device)
314 |
315 | return data_out, target_out, seq_len
316 |
317 | def get_fixlen_iter(self, start=0):
318 | for i in range(start, self.data.size(0) - 1, self.bptt):
319 | yield self.get_batch(i)
320 |
321 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
322 | max_len = self.bptt + max_deviation * std
323 | i = start
324 | while True:
325 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
326 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
327 | data, target, seq_len = self.get_batch(i, bptt)
328 | i += seq_len
329 | yield data, target, seq_len
330 | if i >= self.data.size(0) - 2:
331 | break
332 |
333 | def __iter__(self):
334 | return self.get_fixlen_iter()
335 |
336 |
337 | class LMShuffledIterator(object):
338 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
339 | """
340 | data -- list[LongTensor] -- there is no order among the LongTensors
341 | """
342 | self.data = data
343 |
344 | self.bsz = bsz
345 | self.bptt = bptt
346 | self.ext_len = ext_len if ext_len is not None else 0
347 |
348 | self.device = device
349 | self.shuffle = shuffle
350 |
351 | def get_sent_stream(self):
352 | # index iterator
353 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
354 | else np.array(range(len(self.data)))
355 |
356 | # sentence iterator
357 | for idx in epoch_indices:
358 | yield self.data[idx]
359 |
360 | def stream_iterator(self, sent_stream):
361 | # streams for each data in the batch
362 | streams = [None] * self.bsz
363 |
364 | data = torch.LongTensor(self.bptt, self.bsz)
365 | target = torch.LongTensor(self.bptt, self.bsz)
366 |
367 | n_retain = 0
368 |
369 | while True:
370 | # data : [n_retain+bptt x bsz]
371 | # target : [bptt x bsz]
372 | data[n_retain:].fill_(-1)
373 | target.fill_(-1)
374 |
375 | valid_batch = True
376 |
377 | for i in range(self.bsz):
378 | n_filled = 0
379 | try:
380 | while n_filled < self.bptt:
381 | if streams[i] is None or len(streams[i]) <= 1:
382 | streams[i] = next(sent_stream)
383 | # number of new tokens to fill in
384 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
385 | # first n_retain tokens are retained from last batch
386 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
387 | streams[i][:n_new]
388 | target[n_filled:n_filled+n_new, i] = \
389 | streams[i][1:n_new+1]
390 | streams[i] = streams[i][n_new:]
391 | n_filled += n_new
392 | except StopIteration:
393 | valid_batch = False
394 | break
395 |
396 | if not valid_batch:
397 | return
398 |
399 | data_out = data.transpose(0, 1).contiguous().to(self.device)
400 | target_out = target.transpose(0, 1).contiguous().to(self.device)
401 |
402 | yield data_out, target_out, self.bptt
403 |
404 | n_retain = min(data.size(0), self.ext_len)
405 | if n_retain > 0:
406 | data[:n_retain] = data[-n_retain:]
407 | data.resize_(n_retain + self.bptt, data.size(1))
408 |
409 | def __iter__(self):
410 | # sent_stream is an iterator
411 | sent_stream = self.get_sent_stream()
412 |
413 | for batch in self.stream_iterator(sent_stream):
414 | yield batch
415 |
416 |
417 | class LMMultiFileIterator(LMShuffledIterator):
418 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
419 | shuffle=False):
420 |
421 | self.paths = paths
422 | self.vocab = vocab
423 |
424 | self.bsz = bsz
425 | self.bptt = bptt
426 | self.ext_len = ext_len if ext_len is not None else 0
427 |
428 | self.device = device
429 | self.shuffle = shuffle
430 |
431 | def get_sent_stream(self, path):
432 | sents = self.vocab.encode_file(path, add_double_eos=True)
433 | if self.shuffle:
434 | np.random.shuffle(sents)
435 | sent_stream = iter(sents)
436 |
437 | return sent_stream
438 |
439 | def __iter__(self):
440 | if self.shuffle:
441 | np.random.shuffle(self.paths)
442 |
443 | for path in self.paths:
444 | # sent_stream is an iterator
445 | sent_stream = self.get_sent_stream(path)
446 | for batch in self.stream_iterator(sent_stream):
447 | yield batch
448 |
449 |
450 | class TransfoXLCorpus(object):
451 | @classmethod
452 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
453 | """
454 | Instantiate a pre-processed corpus.
455 | """
456 | vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
457 | if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
458 | corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
459 | else:
460 | corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
461 | # redirect to the cache, if necessary
462 | try:
463 | resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
464 | except EnvironmentError:
465 | logger.error(
466 | "Corpus '{}' was not found in corpus list ({}). "
467 | "We assumed '{}' was a path or url but couldn't find files {} "
468 | "at this path or url.".format(
469 | pretrained_model_name_or_path,
470 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
471 | pretrained_model_name_or_path,
472 | corpus_file))
473 | return None
474 | if resolved_corpus_file == corpus_file:
475 | logger.info("loading corpus file {}".format(corpus_file))
476 | else:
477 | logger.info("loading corpus file {} from cache at {}".format(
478 | corpus_file, resolved_corpus_file))
479 |
480 | # Instantiate the corpus and attach the vocabulary.
481 | corpus = cls(*inputs, **kwargs)
482 | corpus_dict = torch.load(resolved_corpus_file)
483 | for key, value in corpus_dict.items():
484 | corpus.__dict__[key] = value
485 | corpus.vocab = vocab
486 | if corpus.train is not None:
487 | corpus.train = torch.tensor(corpus.train, dtype=torch.long)
488 | if corpus.valid is not None:
489 | corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
490 | if corpus.test is not None:
491 | corpus.test = torch.tensor(corpus.test, dtype=torch.long)
492 | return corpus
493 |
494 | def __init__(self, *args, **kwargs):
495 | self.vocab = TransfoXLTokenizer(*args, **kwargs)
496 | self.dataset = None
497 | self.train = None
498 | self.valid = None
499 | self.test = None
500 |
501 | def build_corpus(self, path, dataset):
502 | self.dataset = dataset
503 |
504 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
505 | self.vocab.count_file(os.path.join(path, 'train.txt'))
506 | self.vocab.count_file(os.path.join(path, 'valid.txt'))
507 | self.vocab.count_file(os.path.join(path, 'test.txt'))
508 | elif self.dataset == 'wt103':
509 | self.vocab.count_file(os.path.join(path, 'train.txt'))
510 | elif self.dataset == 'lm1b':
511 | train_path_pattern = os.path.join(
512 | path, '1-billion-word-language-modeling-benchmark-r13output',
513 | 'training-monolingual.tokenized.shuffled', 'news.en-*')
514 | train_paths = glob.glob(train_path_pattern)
515 | # the vocab will load from file when build_vocab() is called
516 |
517 | self.vocab.build_vocab()
518 |
519 | if self.dataset in ['ptb', 'wt2', 'wt103']:
520 | self.train = self.vocab.encode_file(
521 | os.path.join(path, 'train.txt'), ordered=True)
522 | self.valid = self.vocab.encode_file(
523 | os.path.join(path, 'valid.txt'), ordered=True)
524 | self.test = self.vocab.encode_file(
525 | os.path.join(path, 'test.txt'), ordered=True)
526 | elif self.dataset in ['enwik8', 'text8']:
527 | self.train = self.vocab.encode_file(
528 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
529 | self.valid = self.vocab.encode_file(
530 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
531 | self.test = self.vocab.encode_file(
532 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
533 | elif self.dataset == 'lm1b':
534 | self.train = train_paths
535 | self.valid = self.vocab.encode_file(
536 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
537 | self.test = self.vocab.encode_file(
538 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
539 |
540 | def get_iterator(self, split, *args, **kwargs):
541 | if split == 'train':
542 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
543 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
544 | elif self.dataset == 'lm1b':
545 | kwargs['shuffle'] = True
546 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
547 | elif split in ['valid', 'test']:
548 | data = self.valid if split == 'valid' else self.test
549 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
550 | data_iter = LMOrderedIterator(data, *args, **kwargs)
551 | elif self.dataset == 'lm1b':
552 | data_iter = LMShuffledIterator(data, *args, **kwargs)
553 |
554 | return data_iter
555 |
556 |
557 | def get_lm_corpus(datadir, dataset):
558 | fn = os.path.join(datadir, 'cache.pt')
559 | fn_pickle = os.path.join(datadir, 'cache.pkl')
560 | if os.path.exists(fn):
561 | print('Loading cached dataset...')
562 | corpus = torch.load(fn)
563 | elif os.path.exists(fn_pickle):
564 | print('Loading cached dataset from pickle...')
565 | with open(fn_pickle, "rb") as fp:
566 | corpus = pickle.load(fp)
567 | else:
568 | print('Producing dataset {}...'.format(dataset))
569 | kwargs = {}
570 | if dataset in ['wt103', 'wt2']:
571 | kwargs['special'] = ['<eos>']
572 | kwargs['lower_case'] = False
573 | elif dataset == 'ptb':
574 | kwargs['special'] = ['<eos>']
575 | kwargs['lower_case'] = True
576 | elif dataset == 'lm1b':
577 | kwargs['special'] = []
578 | kwargs['lower_case'] = False
579 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
580 | elif dataset in ['enwik8', 'text8']:
581 | pass
582 |
583 | corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
584 | torch.save(corpus, fn)
585 |
586 | return corpus
587 |
--------------------------------------------------------------------------------
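
LMOrderedIterator slices one long, strictly ordered id stream into bsz parallel rows and serves (input, target) windows of length bptt, with the targets shifted one position ahead of the inputs. A minimal sketch with synthetic ids (the numbers are arbitrary and only chosen to make the layout visible), assuming torch is installed and the repository root is on PYTHONPATH:

    import torch
    from pytorch_pretrained.tokenization_transfo_xl import LMOrderedIterator

    data = torch.arange(20)                      # pretend token ids 0..19
    it = LMOrderedIterator(data, bsz=2, bptt=4)  # row 0 holds ids 0..9, row 1 holds ids 10..19

    batch_data, batch_target, seq_len = it.get_batch(0)
    print(batch_data)    # tensor([[ 0,  1,  2,  3], [10, 11, 12, 13]])
    print(batch_target)  # tensor([[ 1,  2,  3,  4], [11, 12, 13, 14]]) -- shifted by one
    print(seq_len)       # 4

LMShuffledIterator and LMMultiFileIterator yield the same (data, target, seq_len) triples but fill each row from a stream of independent sentences instead of one contiguous corpus.
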
/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from sklearn import metrics
7 | import time
8 | from utils import get_time_dif
9 | from pytorch_pretrained.optimization import BertAdam
10 |
11 |
12 | # Weight initialization; defaults to Xavier.
13 | def init_network(model, method='xavier', exclude='embedding', seed=123):
14 | for name, w in model.named_parameters():
15 | if exclude not in name:
16 | if len(w.size()) < 2:
17 | continue
18 | if 'weight' in name:
19 | if method == 'xavier':
20 | nn.init.xavier_normal_(w)
21 | elif method == 'kaiming':
22 | nn.init.kaiming_normal_(w)
23 | else:
24 | nn.init.normal_(w)
25 | elif 'bias' in name:
26 | nn.init.constant_(w, 0)
27 | else:
28 | pass
29 |
30 |
31 | def train(config, model, train_iter, dev_iter, test_iter):
32 | start_time = time.time()
33 | model.train()
34 | param_optimizer = list(model.named_parameters())
35 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
36 | optimizer_grouped_parameters = [
37 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
38 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
39 | # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
40 | optimizer = BertAdam(optimizer_grouped_parameters,
41 | lr=config.learning_rate,
42 | warmup=0.05,
43 | t_total=len(train_iter) * config.num_epochs)
44 | total_batch = 0 # number of batches processed so far
45 | dev_best_loss = float('inf')
46 | last_improve = 0 # batch index at which the dev loss last improved
47 | flag = False # set when training has gone a long time without improvement
48 | model.train()
49 | for epoch in range(config.num_epochs):
50 | print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
51 | for i, (trains, labels) in enumerate(train_iter):
52 | outputs = model(trains)
53 | model.zero_grad()
54 | loss = F.cross_entropy(outputs, labels)
55 | loss.backward()
56 | optimizer.step()
57 | if total_batch % 100 == 0:
58 | # every 100 batches, report performance on the training and dev sets
59 | true = labels.data.cpu()
60 | predic = torch.max(outputs.data, 1)[1].cpu()
61 | train_acc = metrics.accuracy_score(true, predic)
62 | dev_acc, dev_loss = evaluate(config, model, dev_iter)
63 | if dev_loss < dev_best_loss:
64 | dev_best_loss = dev_loss
65 | torch.save(model.state_dict(), config.save_path)
66 | improve = '*'
67 | last_improve = total_batch
68 | else:
69 | improve = ''
70 | time_dif = get_time_dif(start_time)
71 | msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
72 | print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
73 | model.train()
74 | total_batch += 1
75 | if total_batch - last_improve > config.require_improvement:
76 | # stop training if the dev loss has not improved for config.require_improvement batches
77 | print("No optimization for a long time, auto-stopping...")
78 | flag = True
79 | break
80 | if flag:
81 | break
82 | test(config, model, test_iter)
83 |
84 |
85 | def test(config, model, test_iter):
86 | # test
87 | model.load_state_dict(torch.load(config.save_path))
88 | model.eval()
89 | start_time = time.time()
90 | test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
91 | msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
92 | print(msg.format(test_loss, test_acc))
93 | print("Precision, Recall and F1-Score...")
94 | print(test_report)
95 | print("Confusion Matrix...")
96 | print(test_confusion)
97 | time_dif = get_time_dif(start_time)
98 | print("Time usage:", time_dif)
99 |
100 |
101 | def evaluate(config, model, data_iter, test=False):
102 | model.eval()
103 | loss_total = 0
104 | predict_all = np.array([], dtype=int)
105 | labels_all = np.array([], dtype=int)
106 | with torch.no_grad():
107 | for texts, labels in data_iter:
108 | outputs = model(texts)
109 | loss = F.cross_entropy(outputs, labels)
110 | loss_total += loss
111 | labels = labels.data.cpu().numpy()
112 | predic = torch.max(outputs.data, 1)[1].cpu().numpy()
113 | labels_all = np.append(labels_all, labels)
114 | predict_all = np.append(predict_all, predic)
115 |
116 | acc = metrics.accuracy_score(labels_all, predict_all)
117 | if test:
118 | report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
119 | confusion = metrics.confusion_matrix(labels_all, predict_all)
120 | return acc, loss_total / len(data_iter), report, confusion
121 | return acc, loss_total / len(data_iter)
--------------------------------------------------------------------------------
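To show how the pieces above are typically wired together, here is a minimal driver sketch. It follows the repository layout (models/Bert.py, the THUCNews data directory), but the Config/Model class names and the seed values are assumptions rather than code copied from main.py:

    # coding: UTF-8
    import time
    from importlib import import_module

    import numpy as np
    import torch

    from utils import build_dataset, build_iterator, get_time_dif
    from train import train

    if __name__ == '__main__':
        bert_module = import_module('models.Bert')   # assumed to expose Config and Model
        config = bert_module.Config('THUCNews')       # dataset root, as in the repo tree

        np.random.seed(1)                             # fix seeds so runs are comparable
        torch.manual_seed(1)
        torch.cuda.manual_seed_all(1)

        start_time = time.time()
        train_data, dev_data, test_data = build_dataset(config)
        train_iter = build_iterator(train_data, config)
        dev_iter = build_iterator(dev_data, config)
        test_iter = build_iterator(test_data, config)
        print("Data loading time:", get_time_dif(start_time))

        model = bert_module.Model(config).to(config.device)
        train(config, model, train_iter, dev_iter, test_iter)
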
/utils.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | from datetime import timedelta
6 |
7 | PAD, CLS = '[PAD]', '[CLS]' # padding token and BERT's sequence-level [CLS] token
8 |
9 |
10 | def build_dataset(config):
11 |
12 | def load_dataset(path, pad_size=32):
13 | contents = []
14 | with open(path, 'r', encoding='UTF-8') as f:
15 | for line in tqdm(f):
16 | lin = line.strip()
17 | if not lin:
18 | continue
19 | content, label = lin.split('\t')
20 | token = config.tokenizer.tokenize(content)
21 | token = [CLS] + token
22 | seq_len = len(token)
23 | mask = []
24 | token_ids = config.tokenizer.convert_tokens_to_ids(token)
25 |
26 | if pad_size:
27 | if len(token) < pad_size:
28 | mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
29 | token_ids += ([0] * (pad_size - len(token)))
30 | else:
31 | mask = [1] * pad_size
32 | token_ids = token_ids[:pad_size]
33 | seq_len = pad_size
34 | contents.append((token_ids, int(label), seq_len, mask))
35 | return contents
36 | train = load_dataset(config.train_path, config.pad_size)
37 | dev = load_dataset(config.dev_path, config.pad_size)
38 | test = load_dataset(config.test_path, config.pad_size)
39 | return train, dev, test
40 |
41 |
42 | class DatasetIterater(object):
43 | def __init__(self, batches, batch_size, device):
44 | self.batch_size = batch_size
45 | self.batches = batches
46 | self.n_batches = len(batches) // batch_size
47 | self.residue = False # whether there is a leftover partial batch at the end
48 | if len(batches) % self.batch_size != 0:
49 | self.residue = True
50 | self.index = 0
51 | self.device = device
52 |
53 | def _to_tensor(self, datas):
54 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
55 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
56 |
57 | # sequence length before padding (capped at pad_size for longer inputs)
58 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
59 | mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
60 | return (x, seq_len, mask), y
61 |
62 | def __next__(self):
63 | if self.residue and self.index == self.n_batches:
64 | batches = self.batches[self.index * self.batch_size: len(self.batches)]
65 | self.index += 1
66 | batches = self._to_tensor(batches)
67 | return batches
68 |
69 | elif self.index >= self.n_batches:
70 | self.index = 0
71 | raise StopIteration
72 | else:
73 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
74 | self.index += 1
75 | batches = self._to_tensor(batches)
76 | return batches
77 |
78 | def __iter__(self):
79 | return self
80 |
81 | def __len__(self):
82 | if self.residue:
83 | return self.n_batches + 1
84 | else:
85 | return self.n_batches
86 |
87 |
88 | def build_iterator(dataset, config):
89 | iter = DatasetIterater(dataset, config.batch_size, config.device)
90 | return iter
91 |
92 |
93 | def get_time_dif(start_time):
94 | """获取已使用时间"""
95 | end_time = time.time()
96 | time_dif = end_time - start_time
97 | return timedelta(seconds=int(round(time_dif)))
--------------------------------------------------------------------------------
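To make the padding and mask convention in load_dataset concrete, here is a standalone illustration (not part of the repository) that applies the same logic to a made-up token id list instead of real BERT tokenizer output:

    pad_size = 8
    token_ids = [101, 2769, 4263, 1266, 776]   # hypothetical ids; 101 stands in for [CLS]
    seq_len = len(token_ids)

    if len(token_ids) < pad_size:
        # 1 marks real tokens, 0 marks padding; the pad id is 0, as in load_dataset above
        mask = [1] * len(token_ids) + [0] * (pad_size - len(token_ids))
        token_ids = token_ids + [0] * (pad_size - len(token_ids))
    else:
        mask = [1] * pad_size
        token_ids = token_ids[:pad_size]
        seq_len = pad_size

    print(token_ids)  # [101, 2769, 4263, 1266, 776, 0, 0, 0]
    print(mask)       # [1, 1, 1, 1, 1, 0, 0, 0]
    print(seq_len)    # 5

DatasetIterater then batches these (token_ids, label, seq_len, mask) tuples; when the dataset size is not a multiple of batch_size, the final batch is smaller and __len__ reports n_batches + 1.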