├── ERNIE_pretrain
│   └── README.md
├── LICENSE
├── README.md
├── THUCNews
│   └── data
│       ├── class.txt
│       ├── dev.txt
│       ├── test.txt
│       └── train.txt
├── bert_pretrain
│   └── README.md
├── models
│   ├── ERNIE.py
│   ├── bert.py
│   ├── bert_CNN.py
│   ├── bert_DPCNN.py
│   ├── bert_RCNN.py
│   └── bert_RNN.py
├── pytorch_pretrained
│   ├── __init__.py
│   ├── __main__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __main__.cpython-37.pyc
│   │   ├── convert_gpt2_checkpoint_to_pytorch.cpython-37.pyc
│   │   ├── convert_openai_checkpoint_to_pytorch.cpython-37.pyc
│   │   ├── convert_tf_checkpoint_to_pytorch.cpython-37.pyc
│   │   ├── convert_transfo_xl_checkpoint_to_pytorch.cpython-37.pyc
│   │   ├── file_utils.cpython-37.pyc
│   │   ├── modeling.cpython-37.pyc
│   │   ├── modeling_gpt2.cpython-37.pyc
│   │   ├── modeling_openai.cpython-37.pyc
│   │   ├── modeling_transfo_xl.cpython-37.pyc
│   │   ├── modeling_transfo_xl_utilities.cpython-37.pyc
│   │   ├── optimization.cpython-37.pyc
│   │   ├── optimization_openai.cpython-37.pyc
│   │   ├── tokenization.cpython-37.pyc
│   │   ├── tokenization_gpt2.cpython-37.pyc
│   │   ├── tokenization_openai.cpython-37.pyc
│   │   └── tokenization_transfo_xl.cpython-37.pyc
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── modeling.py
│   ├── modeling_gpt2.py
│   ├── modeling_openai.py
│   ├── modeling_transfo_xl.py
│   ├── modeling_transfo_xl_utilities.py
│   ├── optimization.py
│   ├── optimization_openai.py
│   ├── tokenization.py
│   ├── tokenization_gpt2.py
│   ├── tokenization_openai.py
│   └── tokenization_transfo_xl.py
├── run.py
├── train_eval.py
└── utils.py

/ERNIE_pretrain/README.md:
--------------------------------------------------------------------------------
## Put the ERNIE pretrained model here:
pytorch_model.bin
bert_config.json
vocab.txt

## Download link:
http://image.nghuyong.top/ERNIE.zip
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 huwenxing

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Bert-Chinese-Text-Classification-Pytorch
[![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE)

Chinese text classification with Bert and ERNIE, based on PyTorch, ready to use out of the box.

## Introduction
Model descriptions and the data flow: not finished yet; I will post a link to the blog once it is written.
Hardware: one 2080Ti. Training time: 30 minutes.

## Environment
python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
~~pytorch_pretrained_bert~~ (the pretrained-model code is now included in this repo, so this library is no longer needed)


## Chinese dataset
I extracted 200,000 news headlines from [THUCNews](http://thuctc.thunlp.org/) and uploaded them to GitHub. Each headline is 20 to 30 characters long. There are 10 classes with 20,000 headlines each. Text is fed to the model character by character.

Classes: finance, real estate, stocks, education, science & technology, society, politics, sports, games, entertainment.

Dataset split:

Dataset|Size
--|--
Training set|180,000
Validation set|10,000
Test set|10,000


### Using your own dataset
- Format your Chinese dataset the same way as my dataset.


## Results

Model|acc|Notes
--|--|--
bert|94.83%|plain bert
ERNIE|94.61%|so much for it crushing bert on Chinese
bert_CNN|94.44%|bert + CNN
bert_RNN|94.57%|bert + RNN
bert_RCNN|94.51%|bert + RCNN
bert_DPCNN|94.47%|bert + DPCNN

Plain bert already works well; using bert as an embedding layer that feeds other models actually lowered accuracy. I will compare results on long texts later.

For the results of CNN, RNN, DPCNN, RCNN, RNN+Attention, FastText and other models, see my other [repository](https://github.com/649453932/Chinese-Text-Classification-Pytorch).

## Pretrained language models
Put the bert model in the bert_pretrain directory and the ERNIE model in the ERNIE_pretrain directory. Each directory holds three files:
- pytorch_model.bin
- bert_config.json
- vocab.txt

Download links for the pretrained models:
bert_Chinese: model https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
vocabulary https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
from [here](https://github.com/huggingface/pytorch-transformers)
Backup: Baidu Netdisk link for the model: https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw

ERNIE_Chinese: http://image.nghuyong.top/ERNIE.zip
from [here](https://github.com/nghuyong/ERNIE-Pytorch)
Backup: Baidu Netdisk link: https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw

After extracting, put the files in the corresponding directory as described above and double-check that the file names are correct.

## Usage
Once the pretrained models are downloaded, you are ready to run.
```
# Train and test:
# bert
python run.py --model bert

# bert + another model
python run.py --model bert_CNN

# ERNIE
python run.py --model ERNIE
```

### Parameters
All models are in the models directory; each model's hyperparameters are defined in the same file as the model.

## To do
- bert + CNN, RNN, RCNN, DPCNN, etc.
- ERNIE + CNN, RNN, RCNN, DPCNN, etc.
- XLNET
- Try label smoothing and see how it affects the results
- Wrap up a prediction function


## Papers
[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[2] ERNIE: Enhanced Representation through Knowledge Integration
--------------------------------------------------------------------------------
/THUCNews/data/class.txt:
--------------------------------------------------------------------------------
finance
realty
stocks
education
science
society
politics
sports
game
entertainment
--------------------------------------------------------------------------------
/bert_pretrain/README.md:
--------------------------------------------------------------------------------
## Put the bert pretrained model here:
pytorch_model.bin
bert_config.json
vocab.txt

## Download link:
https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
--------------------------------------------------------------------------------
/models/ERNIE.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
# from pytorch_pretrained_bert import BertModel, BertTokenizer
from pytorch_pretrained import BertModel, BertTokenizer

class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'ERNIE'
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './ERNIE_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out
--------------------------------------------------------------------------------
/models/bert.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
# from pytorch_pretrained_bert import BertModel, BertTokenizer
from pytorch_pretrained import BertModel, BertTokenizer


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert'
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out
--------------------------------------------------------------------------------
/models/bert_CNN.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained import BertModel, BertTokenizer


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert_CNN'  # unique name, so the checkpoint does not overwrite saved_dict/bert.ckpt
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
        self.filter_sizes = (2, 3, 4)  # convolution kernel sizes
        self.num_filters = 256  # number of kernels (output channels)
        self.dropout = 0.1


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.hidden_size)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)

        self.fc_cnn = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = encoder_out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc_cnn(out)
        return out
--------------------------------------------------------------------------------
/models/bert_DPCNN.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# from pytorch_pretrained_bert import BertModel, BertTokenizer
from pytorch_pretrained import BertModel, BertTokenizer


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert_DPCNN'  # unique name, so the checkpoint does not overwrite saved_dict/bert.ckpt
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
        self.num_filters = 250  # number of kernels (output channels)


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        # self.fc = nn.Linear(config.hidden_size, config.num_classes)
        self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.hidden_size), stride=1)
        self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)
        self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # top bottom
        self.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # bottom
        self.relu = nn.ReLU()
        self.fc = nn.Linear(config.num_filters, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        x = encoder_out.unsqueeze(1)  # [batch_size, 1, seq_len, embed]
        x = self.conv_region(x)  # [batch_size, 250, seq_len-3+1, 1]

        x = self.padding1(x)  # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, 250, seq_len-3+1, 1]
        x = self.padding1(x)  # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, 250, seq_len-3+1, 1]
        while x.size()[2] > 2:
            x = self._block(x)
        # squeeze only the two trailing size-1 dims; a bare squeeze() would also drop the batch dim when batch_size == 1
        x = x.squeeze(3).squeeze(2)  # [batch_size, num_filters(250)]
        x = self.fc(x)
        return x

    def _block(self, x):
        x = self.padding2(x)
        px = self.max_pool(x)
        x = self.padding1(px)
        x = F.relu(x)
        x = self.conv(x)
        x = self.padding1(x)
        x = F.relu(x)
        x = self.conv(x)
        x = x + px  # shortcut connection
        return x
--------------------------------------------------------------------------------
/models/bert_RCNN.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained import BertModel, BertTokenizer


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert_RCNN'  # unique name, so the checkpoint does not overwrite saved_dict/bert.ckpt
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
        self.filter_sizes = (2, 3, 4)  # convolution kernel sizes
        self.num_filters = 256  # number of kernels (output channels)
        self.dropout = 0.1
        self.rnn_hidden = 256
        self.num_layers = 2


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.lstm = nn.LSTM(config.hidden_size, config.rnn_hidden, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.maxpool = nn.MaxPool1d(config.pad_size)
        self.fc = nn.Linear(config.rnn_hidden * 2 + config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out, _ = self.lstm(encoder_out)
        out = torch.cat((encoder_out, out), 2)
        out = F.relu(out)
        out = out.permute(0, 2, 1)
        # squeeze only the pooled length dim, so a batch of size 1 keeps its batch dim
        out = self.maxpool(out).squeeze(-1)  # [batch_size, rnn_hidden * 2 + hidden_size]
        out = self.fc(out)
        return out
--------------------------------------------------------------------------------
/models/bert_RNN.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained import BertModel, BertTokenizer

class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert_RNN'  # unique name, so the checkpoint does not overwrite saved_dict/bert.ckpt
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]  # list of class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device

        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded (if short) or truncated (if long) to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
        self.filter_sizes = (2, 3, 4)  # convolution kernel sizes
        self.num_filters = 256  # number of kernels (output channels)
        self.dropout = 0.1
        self.rnn_hidden = 768
        self.num_layers = 2


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.lstm = nn.LSTM(config.hidden_size, config.rnn_hidden, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.dropout = nn.Dropout(config.dropout)
        self.fc_rnn = nn.Linear(config.rnn_hidden * 2, config.num_classes)

    def forward(self, x):
        context = x[0]  # input sentence (token ids)
        mask = x[2]  # padding mask, same size as the sentence, 0 at padding positions, e.g. [1, 1, 1, 1, 0, 0]
        encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out, _ = self.lstm(encoder_out)
        out = self.dropout(out)
        out = self.fc_rnn(out[:, -1, :])  # hidden state at the last time step (index pad_size - 1, regardless of the mask)
        return out
--------------------------------------------------------------------------------
/pytorch_pretrained/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.6.2"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer

from .modeling import (BertConfig, BertModel, BertForPreTraining,
                       BertForMaskedLM, BertForNextSentencePrediction,
                       BertForSequenceClassification, BertForMultipleChoice,
                       BertForTokenClassification, BertForQuestionAnswering,
                       load_tf_weights_in_bert)
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                              OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
                              load_tf_weights_in_openai_gpt)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                  load_tf_weights_in_transfo_xl)
from .modeling_gpt2 import (GPT2Config, GPT2Model,
                            GPT2LMHeadModel, GPT2DoubleHeadsModel,
                            load_tf_weights_in_gpt2)

from .optimization import BertAdam
from .optimization_openai import OpenAIAdam

from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
--------------------------------------------------------------------------------
/pytorch_pretrained/__main__.py:
--------------------------------------------------------------------------------
# coding: utf8
def main():
    import sys
    if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
        "convert_tf_checkpoint_to_pytorch",
        "convert_openai_checkpoint",
        "convert_transfo_xl_checkpoint",
        "convert_gpt2_checkpoint",
    ]:
        print(
            "Should be used as one of: \n"
            ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
            ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
            ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
            ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
    else:
        if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
            try:
                from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
            except ImportError:
                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                      "In that case, it requires TensorFlow to be installed. Please see "
                      "https://www.tensorflow.org/install/ for installation instructions.")
                raise

            if len(sys.argv) != 5:
                # pylint: disable=line-too-long
                print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
            else:
                PYTORCH_DUMP_OUTPUT = sys.argv.pop()
                TF_CONFIG = sys.argv.pop()
                TF_CHECKPOINT = sys.argv.pop()
                convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
        elif sys.argv[1] == "convert_openai_checkpoint":
            from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
            OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
            PYTORCH_DUMP_OUTPUT = sys.argv[3]
            if len(sys.argv) == 5:
                OPENAI_GPT_CONFIG = sys.argv[4]
            else:
                OPENAI_GPT_CONFIG = ""
            convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
                                                 OPENAI_GPT_CONFIG,
                                                 PYTORCH_DUMP_OUTPUT)
        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
            try:
                from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
            except ImportError:
                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                      "In that case, it requires TensorFlow to be installed. Please see "
                      "https://www.tensorflow.org/install/ for installation instructions.")
                raise

            if 'ckpt' in sys.argv[2].lower():
                TF_CHECKPOINT = sys.argv[2]
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = sys.argv[2]
                TF_CHECKPOINT = ""
            PYTORCH_DUMP_OUTPUT = sys.argv[3]
            if len(sys.argv) == 5:
                TF_CONFIG = sys.argv[4]
            else:
                TF_CONFIG = ""
            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
        else:
            try:
                from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
            except ImportError:
                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                      "In that case, it requires TensorFlow to be installed. Please see "
                      "https://www.tensorflow.org/install/ for installation instructions.")
                raise

            TF_CHECKPOINT = sys.argv[2]
            PYTORCH_DUMP_OUTPUT = sys.argv[3]
            if len(sys.argv) == 5:
                TF_CONFIG = sys.argv[4]
            else:
                TF_CONFIG = ""
            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/pytorch_pretrained/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT-2 checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path to the TensorFlow checkpoint.") 59 |
parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path to the TensorFlow checkpoint folder.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model.
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path to the TensorFlow checkpoint.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model.
\n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save 
dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained Transformer-XL model.
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted into a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | 
Convert
`url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 
94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 
135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 
193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 
259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /pytorch_pretrained/modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import collections 21 | import copy 22 | import json 23 | import logging 24 | import math 25 | import os 26 | import shutil 27 | import tarfile 28 | import tempfile 29 | import sys 30 | from io import open 31 | 32 | import torch 33 | import torch.nn as nn 34 | from torch.nn import CrossEntropyLoss 35 | from torch.nn.parameter import Parameter 36 | 37 | from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME 38 | from .modeling import BertLayerNorm as LayerNorm 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} 43 | PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} 44 | 45 | def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): 46 | """ Load TF checkpoints into a PyTorch model 47 | """ 48 | try: 49 | import re 50 | import numpy as np 51 | import tensorflow as tf 52 | except ImportError: 53 | print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed.
Please see " 54 | "https://www.tensorflow.org/install/ for installation instructions.") 55 | raise 56 | tf_path = os.path.abspath(gpt2_checkpoint_path) 57 | print("Converting TensorFlow checkpoint from {}".format(tf_path)) 58 | # Load weights from TF model 59 | init_vars = tf.train.list_variables(tf_path) 60 | names = [] 61 | arrays = [] 62 | for name, shape in init_vars: 63 | print("Loading TF weight {} with shape {}".format(name, shape)) 64 | array = tf.train.load_variable(tf_path, name) 65 | names.append(name) 66 | arrays.append(array.squeeze()) 67 | 68 | for name, array in zip(names, arrays): 69 | name = name[6:] # skip "model/" 70 | name = name.split('/') 71 | pointer = model 72 | for m_name in name: 73 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 74 | l = re.split(r'(\d+)', m_name) 75 | else: 76 | l = [m_name] 77 | if l[0] == 'w' or l[0] == 'g': 78 | pointer = getattr(pointer, 'weight') 79 | elif l[0] == 'b': 80 | pointer = getattr(pointer, 'bias') 81 | elif l[0] == 'wpe' or l[0] == 'wte': 82 | pointer = getattr(pointer, l[0]) 83 | pointer = getattr(pointer, 'weight') 84 | else: 85 | pointer = getattr(pointer, l[0]) 86 | if len(l) >= 2: 87 | num = int(l[1]) 88 | pointer = pointer[num] 89 | try: 90 | assert pointer.shape == array.shape 91 | except AssertionError as e: 92 | e.args += (pointer.shape, array.shape) 93 | raise 94 | print("Initialize PyTorch weight {}".format(name)) 95 | pointer.data = torch.from_numpy(array) 96 | return model 97 | 98 | 99 | def gelu(x): 100 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 101 | 102 | 103 | class GPT2Config(object): 104 | """Configuration class to store the configuration of a `GPT2Model`. 105 | """ 106 | 107 | def __init__( 108 | self, 109 | vocab_size_or_config_json_file=50257, 110 | n_positions=1024, 111 | n_ctx=1024, 112 | n_embd=768, 113 | n_layer=12, 114 | n_head=12, 115 | layer_norm_epsilon=1e-5, 116 | initializer_range=0.02, 117 | ): 118 | """Constructs GPT2Config. 
119 | 120 | Args: 121 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `GPT2Model` or a configuration json file. 122 | n_positions: Number of positional embeddings. 123 | n_ctx: Size of the causal mask (usually same as n_positions). 124 | n_embd: Dimensionality of the embeddings and hidden states. 125 | n_layer: Number of hidden layers in the Transformer encoder. 126 | n_head: Number of attention heads for each attention layer in 127 | the Transformer encoder. 128 | layer_norm_epsilon: epsilon to use in the layer norm layers 129 | initializer_range: The stddev of the truncated_normal_initializer for 130 | initializing all weight matrices. 131 | """ 132 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 133 | and isinstance(vocab_size_or_config_json_file, unicode)): 134 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 135 | json_config = json.loads(reader.read()) 136 | for key, value in json_config.items(): 137 | self.__dict__[key] = value 138 | elif isinstance(vocab_size_or_config_json_file, int): 139 | self.vocab_size = vocab_size_or_config_json_file 140 | self.n_ctx = n_ctx 141 | self.n_positions = n_positions 142 | self.n_embd = n_embd 143 | self.n_layer = n_layer 144 | self.n_head = n_head 145 | self.layer_norm_epsilon = layer_norm_epsilon 146 | self.initializer_range = initializer_range 147 | else: 148 | raise ValueError( 149 | "First argument must be either a vocabulary size (int) " 150 | "or the path to a pretrained model config file (str)" 151 | ) 152 | 153 | @classmethod 154 | def from_dict(cls, json_object): 155 | """Constructs a `GPT2Config` from a Python dictionary of parameters.""" 156 | config = GPT2Config(vocab_size_or_config_json_file=-1) 157 | for key, value in json_object.items(): 158 | config.__dict__[key] = value 159 | return config 160 | 161 | @classmethod 162 | def from_json_file(cls, json_file): 163 | """Constructs a `GPT2Config` from a 
json file of parameters."""
164 | with open(json_file, "r", encoding="utf-8") as reader: 165 | text = reader.read() 166 | return cls.from_dict(json.loads(text)) 167 | 168 | def __repr__(self): 169 | return str(self.to_json_string()) 170 | 171 | def to_dict(self): 172 | """Serializes this instance to a Python dictionary.""" 173 | output = copy.deepcopy(self.__dict__) 174 | return output 175 | 176 | def to_json_string(self): 177 | """Serializes this instance to a JSON string.""" 178 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 179 | 180 | def to_json_file(self, json_file_path): 181 | """ Save this instance to a json file.""" 182 | with open(json_file_path, "w", encoding='utf-8') as writer: 183 | writer.write(self.to_json_string()) 184 | 185 | 186 | class Conv1D(nn.Module): 187 | def __init__(self, nf, nx): 188 | super(Conv1D, self).__init__() 189 | self.nf = nf 190 | w = torch.empty(nx, nf) 191 | nn.init.normal_(w, std=0.02) 192 | self.weight = Parameter(w) 193 | self.bias = Parameter(torch.zeros(nf)) 194 | 195 | def forward(self, x): 196 | size_out = x.size()[:-1] + (self.nf,) 197 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 198 | x = x.view(*size_out) 199 | return x 200 | 201 | 202 | class Attention(nn.Module): 203 | def __init__(self, nx, n_ctx, config, scale=False): 204 | super(Attention, self).__init__() 205 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 206 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 207 | assert n_state % config.n_head == 0 208 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 209 | self.n_head = config.n_head 210 | self.split_size = n_state 211 | self.scale = scale 212 | self.c_attn = Conv1D(n_state * 3, nx) 213 | self.c_proj = Conv1D(n_state, nx) 214 | 215 | def _attn(self, q, k, v): 216 | w = torch.matmul(q, k) 217 | if self.scale: 218 | w = w / math.sqrt(v.size(-1)) 219 | nd, ns = w.size(-2), w.size(-1) 220 | b = self.bias[:, :, 
ns-nd:ns, :ns] 221 | w = w * b - 1e4 * (1 - b) 222 | 223 | w = nn.Softmax(dim=-1)(w) 224 | return torch.matmul(w, v) 225 | 226 | def merge_heads(self, x): 227 | x = x.permute(0, 2, 1, 3).contiguous() 228 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 229 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 230 | 231 | def split_heads(self, x, k=False): 232 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 233 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 234 | if k: 235 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 236 | else: 237 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 238 | 239 | def forward(self, x, layer_past=None): 240 | x = self.c_attn(x) 241 | query, key, value = x.split(self.split_size, dim=2) 242 | query = self.split_heads(query) 243 | key = self.split_heads(key, k=True) 244 | value = self.split_heads(value) 245 | if layer_past is not None: 246 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 247 | key = torch.cat((past_key, key), dim=-1) 248 | value = torch.cat((past_value, value), dim=-2) 249 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 250 | a = self._attn(query, key, value) 251 | a = self.merge_heads(a) 252 | a = self.c_proj(a) 253 | return a, present 254 | 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 258 | super(MLP, self).__init__() 259 | nx = config.n_embd 260 | self.c_fc = Conv1D(n_state, nx) 261 | self.c_proj = Conv1D(nx, n_state) 262 | self.act = gelu 263 | 264 | def forward(self, x): 265 | h = self.act(self.c_fc(x)) 266 | h2 = self.c_proj(h) 267 | return h2 268 | 269 | 270 | class Block(nn.Module): 271 | def __init__(self, n_ctx, config, scale=False): 272 | super(Block, self).__init__() 273 | nx = config.n_embd 274 | self.ln_1 = 
LayerNorm(nx, eps=config.layer_norm_epsilon) 275 | self.attn = Attention(nx, n_ctx, config, scale) 276 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 277 | self.mlp = MLP(4 * nx, config) 278 | 279 | def forward(self, x, layer_past=None): 280 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 281 | x = x + a 282 | m = self.mlp(self.ln_2(x)) 283 | x = x + m 284 | return x, present 285 | 286 | 287 | class GPT2LMHead(nn.Module): 288 | """ Language Model Head for the transformer """ 289 | 290 | def __init__(self, model_embeddings_weights, config): 291 | super(GPT2LMHead, self).__init__() 292 | self.n_embd = config.n_embd 293 | self.set_embeddings_weights(model_embeddings_weights) 294 | 295 | def set_embeddings_weights(self, model_embeddings_weights): 296 | embed_shape = model_embeddings_weights.shape 297 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 298 | self.decoder.weight = model_embeddings_weights # Tied weights 299 | 300 | def forward(self, hidden_state): 301 | # Truncated Language modeling logits (we remove the last token) 302 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) 303 | lm_logits = self.decoder(hidden_state) 304 | return lm_logits 305 | 306 | 307 | class GPT2MultipleChoiceHead(nn.Module): 308 | """ Classifier Head for the transformer """ 309 | 310 | def __init__(self, config): 311 | super(GPT2MultipleChoiceHead, self).__init__() 312 | self.n_embd = config.n_embd 313 | self.linear = nn.Linear(config.n_embd, 1) 314 | 315 | nn.init.normal_(self.linear.weight, std=0.02) 316 | nn.init.normal_(self.linear.bias, 0) 317 | 318 | def forward(self, hidden_states, mc_token_ids): 319 | # Classification logits 320 | # hidden_state (bsz, num_choices, seq_length, hidden_size) 321 | # mc_token_ids (bsz, num_choices) 322 | mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) 323 | # (bsz, num_choices, 1, hidden_size) 324 | multiple_choice_h = hidden_states.gather(2, 
mc_token_ids).squeeze(2) 325 | # (bsz, num_choices, hidden_size) 326 | multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) 327 | # (bsz, num_choices) 328 | return multiple_choice_logits 329 | 330 | 331 | class GPT2PreTrainedModel(nn.Module): 332 | """ An abstract class to handle weights initialization and 333 | a simple interface for downloading and loading pretrained models. 334 | """ 335 | 336 | def __init__(self, config, *inputs, **kwargs): 337 | super(GPT2PreTrainedModel, self).__init__() 338 | if not isinstance(config, GPT2Config): 339 | raise ValueError( 340 | "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " 341 | "To create a model from a pretrained model use " 342 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 343 | self.__class__.__name__, self.__class__.__name__ 344 | ) 345 | ) 346 | self.config = config 347 | 348 | def set_tied(self): 349 | pass 350 | 351 | def init_weights(self, module): 352 | """ Initialize the weights. 353 | """ 354 | if isinstance(module, (nn.Linear, nn.Embedding)): 355 | # Slightly different from the TF version which uses truncated_normal for initialization 356 | # cf https://github.com/pytorch/pytorch/pull/5617 357 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 358 | elif isinstance(module, LayerNorm): 359 | module.bias.data.zero_() 360 | module.weight.data.fill_(1.0) 361 | if isinstance(module, nn.Linear) and module.bias is not None: 362 | module.bias.data.zero_() 363 | 364 | @classmethod 365 | def from_pretrained( 366 | cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs 367 | ): 368 | """ 369 | Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict. 370 | Download and cache the pre-trained model file if needed.
371 | 372 | Params: 373 | pretrained_model_name_or_path: either: 374 | - a str with the name of a pre-trained model to load selected in the list of: 375 | . `gpt2` 376 | - a path or url to a pretrained model archive containing: 377 | . `gpt2_config.json` a configuration file for the model 378 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance 379 | - a path or url to a pretrained model archive containing: 380 | . `gpt2_config.json` a configuration file for the model 381 | . a TensorFlow checkpoint with trained weights 382 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 383 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 384 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 385 | *inputs, **kwargs: additional input for the specific GPT class 386 | """ 387 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 388 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 389 | config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] 390 | else: 391 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 392 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 393 | # redirect to the cache, if necessary 394 | try: 395 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 396 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir) 397 | except EnvironmentError: 398 | logger.error( 399 | "Model name '{}' was not found in model name list ({}). 
" 400 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 401 | "at this path or url.".format( 402 | pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, 403 | archive_file, config_file 404 | ) 405 | ) 406 | return None 407 | if resolved_archive_file == archive_file and resolved_config_file == config_file: 408 | logger.info("loading weights file {}".format(archive_file)) 409 | logger.info("loading configuration file {}".format(config_file)) 410 | else: 411 | logger.info("loading weights file {} from cache at {}".format( 412 | archive_file, resolved_archive_file)) 413 | logger.info("loading configuration file {} from cache at {}".format( 414 | config_file, resolved_config_file)) 415 | # Load config 416 | config = GPT2Config.from_json_file(resolved_config_file) 417 | logger.info("Model config {}".format(config)) 418 | # Instantiate model. 419 | model = cls(config, *inputs, **kwargs) 420 | if state_dict is None and not from_tf: 421 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 422 | if from_tf: 423 | # Directly load from a TensorFlow checkpoint (stored as NumPy array) 424 | return load_tf_weights_in_gpt2(model, resolved_archive_file) 425 | 426 | old_keys = [] 427 | new_keys = [] 428 | for key in state_dict.keys(): 429 | new_key = None 430 | if key.endswith(".g"): 431 | new_key = key[:-2] + ".weight" 432 | elif key.endswith(".b"): 433 | new_key = key[:-2] + ".bias" 434 | elif key.endswith(".w"): 435 | new_key = key[:-2] + ".weight" 436 | if new_key: 437 | old_keys.append(key) 438 | new_keys.append(new_key) 439 | for old_key, new_key in zip(old_keys, new_keys): 440 | state_dict[new_key] = state_dict.pop(old_key) 441 | 442 | missing_keys = [] 443 | unexpected_keys = [] 444 | error_msgs = [] 445 | # copy state_dict so _load_from_state_dict can modify it 446 | metadata = getattr(state_dict, "_metadata", None) 447 | state_dict = state_dict.copy() 448 | if metadata is not 
None: 449 | state_dict._metadata = metadata 450 | 451 | def load(module, prefix=""): 452 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 453 | module._load_from_state_dict( 454 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs 455 | ) 456 | for name, child in module._modules.items(): 457 | if child is not None: 458 | load(child, prefix + name + ".") 459 | 460 | start_model = model 461 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()): 462 | start_model = model.transformer 463 | load(start_model, prefix="") 464 | 465 | if len(missing_keys) > 0: 466 | logger.info( 467 | "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys) 468 | ) 469 | if len(unexpected_keys) > 0: 470 | logger.info( 471 | "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys) 472 | ) 473 | if len(error_msgs) > 0: 474 | raise RuntimeError( 475 | "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) 476 | ) 477 | 478 | # Make sure we are still sharing the output and input embeddings after loading weights 479 | model.set_tied() 480 | return model 481 | 482 | 483 | class GPT2Model(GPT2PreTrainedModel): 484 | """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners"). 485 | 486 | Params: 487 | config: a GPT2Config class instance with the configuration to build a new model 488 | 489 | Inputs: 490 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] 491 | where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ 492 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 493 | with the position indices (selected in the range [0, config.n_positions - 1]).
494 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 495 | You can use it to add a third type of embedding to each input token in the sequence 496 | (the previous two being the word and position embeddings). 497 | The input, position and token_type embeddings are summed inside the Transformer before the first 498 | self-attention block. 499 | `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states 500 | (keys and values in the attention blocks) to speed up sequential decoding 501 | (this is the presents output of the model, cf. below). 502 | 503 | Outputs a tuple consisting of: 504 | `hidden_states`: the encoded-hidden-states at the top of the model 505 | as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] 506 | (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids) 507 | `presents`: a list of pre-computed hidden-states (keys and values in each attention block) as 508 | torch.FloatTensors. They can be reused to speed up sequential decoding.
509 | 510 | Example usage: 511 | ```python 512 | # Already been converted into BPE token ids 513 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 514 | 515 | config = modeling_gpt2.GPT2Config() 516 | 517 | model = modeling_gpt2.GPT2Model(config) 518 | hidden_states, presents = model(input_ids) 519 | ``` 520 | """ 521 | 522 | def __init__(self, config): 523 | super(GPT2Model, self).__init__(config) 524 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 525 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 526 | block = Block(config.n_ctx, config, scale=True) 527 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 528 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 529 | 530 | self.apply(self.init_weights) 531 | 532 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None): 533 | if past is None: 534 | past_length = 0 535 | past = [None] * len(self.h) 536 | else: 537 | past_length = past[0][0].size(-2) 538 | if position_ids is None: 539 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) 540 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 541 | 542 | input_shape = input_ids.size() 543 | input_ids = input_ids.view(-1, input_ids.size(-1)) 544 | position_ids = position_ids.view(-1, position_ids.size(-1)) 545 | 546 | inputs_embeds = self.wte(input_ids) 547 | position_embeds = self.wpe(position_ids) 548 | if token_type_ids is not None: 549 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 550 | token_type_embeds = self.wte(token_type_ids) 551 | else: 552 | token_type_embeds = 0 553 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 554 | presents = [] 555 | for block, layer_past in zip(self.h, past): 556 | hidden_states, present = block(hidden_states, layer_past) 557 | presents.append(present) 558 | hidden_states = self.ln_f(hidden_states) 559 | 
output_shape = input_shape + (hidden_states.size(-1),) 560 | return hidden_states.view(*output_shape), presents 561 | 562 | 563 | class GPT2LMHeadModel(GPT2PreTrainedModel): 564 | """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners"). 565 | 566 | Params: 567 | config: a GPT2Config class instance with the configuration to build a new model 568 | 569 | Inputs: 570 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] 571 | where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ 572 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 573 | with the position indices (selected in the range [0, config.n_positions - 1]). 574 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 575 | You can use it to add a third type of embedding to each input token in the sequence 576 | (the previous two being the word and position embeddings). 577 | The input, position and token_type embeddings are summed inside the Transformer before the first 578 | self-attention block. 579 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 580 | with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked), the loss 581 | is only computed for the labels set in [0, ..., vocab_size - 1] 582 | `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states 583 | (keys and values in the attention blocks) to speed up sequential decoding 584 | (this is the presents output of the model, cf. below). 585 | 586 | Outputs: 587 | if `lm_labels` is not `None`: 588 | Outputs the language modeling loss.
589 | else a tuple: 590 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size] 591 | (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ... d_n are the dimensions of input_ids) 592 | `presents`: a list of pre-computed hidden-states (keys and values in each attention block) as 593 | torch.FloatTensors. They can be reused to speed up sequential decoding. 594 | 595 | Example usage: 596 | ```python 597 | # Already been converted into BPE token ids 598 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 599 | 600 | config = modeling_gpt2.GPT2Config() 601 | 602 | model = modeling_gpt2.GPT2LMHeadModel(config) 603 | lm_logits, presents = model(input_ids) 604 | ``` 605 | """ 606 | 607 | def __init__(self, config): 608 | super(GPT2LMHeadModel, self).__init__(config) 609 | self.transformer = GPT2Model(config) 610 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) 611 | self.apply(self.init_weights) 612 | 613 | def set_tied(self): 614 | """ Make sure we are sharing the embeddings 615 | """ 616 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 617 | 618 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): 619 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 620 | lm_logits = self.lm_head(hidden_states) 621 | if lm_labels is not None: 622 | # Shift along the sequence dimension so that tokens < n predict n 623 | shift_logits = lm_logits[..., :-1, :].contiguous() 624 | shift_labels = lm_labels[..., 1:].contiguous() 625 | 626 | # Flatten the tokens 627 | loss_fct = CrossEntropyLoss(ignore_index=-1) 628 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 629 | shift_labels.view(-1)) 630 | return loss 631 | return lm_logits, presents 632 | 633 | 634 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel): 635 | """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are
Unsupervised Multitask Learners"). 636 | 637 | Params: 638 | config: a GPT2Config class instance with the configuration to build a new model 639 | 640 | Inputs: 641 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token 642 | indices selected in the range [0, config.vocab_size[ 643 | `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from 644 | which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence) 645 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 646 | with the position indices (selected in the range [0, config.n_positions - 1]). 647 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 648 | You can use it to add a third type of embedding to each input token in the sequence 649 | (the previous two being the word and position embeddings). 650 | The input, position and token_type embeddings are summed inside the Transformer before the first 651 | self-attention block. 652 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length] 653 | with indices selected in [-1, 0, ..., config.vocab_size - 1]. All labels set to -1 are ignored (masked), the loss 654 | is only computed for the labels set in [0, ..., config.vocab_size - 1] 655 | `mc_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] 656 | with indices selected in [0, ..., num_choices - 1]. 657 | `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states 658 | (keys and values in the attention blocks) to speed up sequential decoding 659 | (this is the presents output of the model, cf. below). 660 | 661 | Outputs: 662 | if `lm_labels` or `mc_labels` is not `None`: 663 | Outputs a list with the language modeling loss and/or the multiple choice loss (in that order).
664 | else: a tuple with 665 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size] 666 | `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices] 667 | `presents`: a list of pre-computed hidden-states (keys and values in each attention block) as 668 | torch.FloatTensors. They can be reused to speed up sequential decoding. 669 | 670 | Example usage: 671 | ```python 672 | # Already been converted into BPE token ids 673 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choices, seq length) 674 | mc_token_ids = torch.LongTensor([[2, 1]]) # (bsz, number of choices) 675 | 676 | config = modeling_gpt2.GPT2Config() 677 | 678 | model = modeling_gpt2.GPT2DoubleHeadsModel(config) 679 | lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids) 680 | ``` 681 | """ 682 | 683 | def __init__(self, config): 684 | super(GPT2DoubleHeadsModel, self).__init__(config) 685 | self.transformer = GPT2Model(config) 686 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) 687 | self.multiple_choice_head = GPT2MultipleChoiceHead(config) 688 | self.apply(self.init_weights) 689 | 690 | def set_tied(self): 691 | """ Make sure we are sharing the embeddings 692 | """ 693 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 694 | 695 | def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None): 696 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 697 | lm_logits = self.lm_head(hidden_states) 698 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) 699 | losses = [] 700 | if lm_labels is not None: 701 | shift_logits = lm_logits[..., :-1, :].contiguous() # shift along the sequence dim, not the choices dim 702 | shift_labels = lm_labels[..., 1:].contiguous() 703 | loss_fct = CrossEntropyLoss(ignore_index=-1) 704 |
losses.append(loss_fct(shift_logits.view(-1, 705 | shift_logits.size(-1)), shift_labels.view(-1))) 706 | if mc_labels is not None: 707 | loss_fct = CrossEntropyLoss() 708 | losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) 709 | if losses: 710 | return losses 711 | return lm_logits, mc_logits, presents 712 | -------------------------------------------------------------------------------- /pytorch_pretrained/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 
18 | """ 19 | 20 | from collections import defaultdict 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 30 | 31 | class ProjectedAdaptiveLogSoftmax(nn.Module): 32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 33 | keep_order=False): 34 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 35 | 36 | self.n_token = n_token 37 | self.d_embed = d_embed 38 | self.d_proj = d_proj 39 | 40 | self.cutoffs = cutoffs + [n_token] 41 | self.cutoff_ends = [0] + self.cutoffs 42 | self.div_val = div_val 43 | 44 | self.shortlist_size = self.cutoffs[0] 45 | self.n_clusters = len(self.cutoffs) - 1 46 | self.head_size = self.shortlist_size + self.n_clusters 47 | 48 | if self.n_clusters > 0: 49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 51 | 52 | self.out_layers = nn.ModuleList() 53 | self.out_projs = nn.ParameterList() 54 | 55 | if div_val == 1: 56 | for i in range(len(self.cutoffs)): 57 | if d_proj != d_embed: 58 | self.out_projs.append( 59 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 60 | ) 61 | else: 62 | self.out_projs.append(None) 63 | 64 | self.out_layers.append(nn.Linear(d_embed, n_token)) 65 | else: 66 | for i in range(len(self.cutoffs)): 67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 68 | d_emb_i = d_embed // (div_val ** i) 69 | 70 | self.out_projs.append( 71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 72 | ) 73 | 74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 75 | 76 | self.keep_order = keep_order 77 | 78 | def _compute_logit(self, hidden, weight, bias, proj): 79 | if proj is None: 80 | logit = F.linear(hidden, weight, bias=bias) 81 | else: 82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 83 | proj_hid = F.linear(hidden, 
proj.t().contiguous()) 84 | logit = F.linear(proj_hid, weight, bias=bias) 85 | # else: 86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 87 | # if bias is not None: 88 | # logit = logit + bias 89 | 90 | return logit 91 | 92 | def forward(self, hidden, target=None, keep_order=False): 93 | ''' 94 | Params: 95 | hidden :: [len*bsz x d_proj] 96 | target :: [len*bsz] 97 | Return: 98 | if target is None: 99 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary 100 | else: 101 | out :: [len*bsz] Negative log likelihood 102 | We could replace this implementation by the native PyTorch one 103 | if theirs had an option to set bias on all clusters, 104 | see: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if target is not None: 108 | target = target.view(-1) 109 | if hidden.size(0) != target.size(0): 110 | raise RuntimeError('Input and target should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if target is not None: 117 | out = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, target.unsqueeze(1)).squeeze(1) 119 | else: 120 | out = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 |
weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if target is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if target is not None: 158 | mask_i = (target >= l_idx) & (target < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = target.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if target is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if target is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if target is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i, -logprob_i) 191 | 
else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py 201 | Args: 202 | hidden (Tensor): a minibatch of examples 203 | Returns: 204 | log-probabilities for each class :math:`c` 205 | in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is the 206 | ``n_token`` parameter passed to the ``ProjectedAdaptiveLogSoftmax`` constructor. 207 | Shape: 208 | - Input: :math:`(N, in\_features)` 209 | - Output: :math:`(N, n\_classes)` 210 | """ 211 | if self.n_clusters == 0: 212 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 213 | self.out_layers[0].bias, self.out_projs[0]) 214 | return F.log_softmax(logit, dim=-1) 215 | else: 216 | # construct weights and biases 217 | weights, biases = [], [] 218 | for i in range(len(self.cutoffs)): 219 | if self.div_val == 1: 220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 221 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 222 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 223 | else: 224 | weight_i = self.out_layers[i].weight 225 | bias_i = self.out_layers[i].bias 226 | 227 | if i == 0: 228 | weight_i = torch.cat( 229 | [weight_i, self.cluster_weight], dim=0) 230 | bias_i = torch.cat( 231 | [bias_i, self.cluster_bias], dim=0) 232 | 233 | weights.append(weight_i) 234 | biases.append(bias_i) 235 | 236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 238 | 239 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 240 | head_logprob = F.log_softmax(head_logit, dim=1) 241 | 242 | cutoff_values = [0] + self.cutoffs 243 | for i in range(len(cutoff_values) - 1): 244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] 245 |
246 | if i == 0: 247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 248 | else: 249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 250 | 251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) 252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 253 | cluster_prob_idx = self.cutoffs[0] + i - 1 # same cluster index as in forward() 254 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 255 | out[:, start_idx:stop_idx] = logprob_i 256 | 257 | return out 258 | 259 | 260 | class LogUniformSampler(object): 261 | def __init__(self, range_max, n_sample): 262 | """ 263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 265 | 266 | expected count can be approximated by 1 - (1 - p)^n 267 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 268 | 269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 270 | """ 271 | with torch.no_grad(): 272 | self.range_max = range_max 273 | log_indices = torch.arange(1., range_max+2., 1.).log_() 274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 275 | # print('P', self.dist.numpy().tolist()[-30:]) 276 | 277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 278 | 279 | self.n_sample = n_sample 280 | 281 | def sample(self, labels): 282 | """ 283 | labels: [b1, b2] 284 | Return 285 | true_log_probs: [b1, b2] 286 | samp_log_probs: [n_sample] 287 | neg_samples: [n_sample] 288 | """ 289 | 290 | # neg_samples = torch.empty(0).long() 291 | n_sample = self.n_sample 292 | n_tries = 2 * n_sample 293 | 294 | with torch.no_grad(): 295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 296 | device = labels.device 297 | neg_samples = neg_samples.to(device) 298 | true_log_probs = self.log_q[labels].to(device) 299 | samp_log_probs = self.log_q[neg_samples].to(device) 300 |
return true_log_probs, samp_log_probs, neg_samples 301 | 302 | def sample_logits(embedding, bias, labels, inputs, sampler): 303 | """ 304 | embedding: an nn.Embedding layer 305 | bias: [n_vocab] 306 | labels: [b1, b2] 307 | inputs: [b1, b2, n_emb] 308 | sampler: you may use a LogUniformSampler 309 | Return 310 | logits: [b1, b2, 1 + n_sample] 311 | """ 312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 313 | n_sample = neg_samples.size(0) 314 | b1, b2 = labels.size(0), labels.size(1) 315 | all_ids = torch.cat([labels.view(-1), neg_samples]) 316 | all_w = embedding(all_ids) 317 | true_w = all_w[: -n_sample].view(b1, b2, -1) 318 | sample_w = all_w[- n_sample:].view(n_sample, -1) 319 | 320 | all_b = bias[all_ids] 321 | true_b = all_b[: -n_sample].view(b1, b2) 322 | sample_b = all_b[- n_sample:] 323 | 324 | hit = (labels[:, :, None] == neg_samples).detach() 325 | 326 | true_logits = torch.einsum('ijk,ijk->ij', 327 | [true_w, inputs]) + true_b - true_log_probs 328 | sample_logits = torch.einsum('lk,ijk->ijl', 329 | [sample_w, inputs]) + sample_b - samp_log_probs 330 | sample_logits.masked_fill_(hit, -1e30) 331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 332 | 333 | return logits 334 | 335 | 336 | # class LogUniformSampler(object): 337 | # def __init__(self, range_max, unique=False): 338 | # """ 339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 341 | # """ 342 | # self.range_max = range_max 343 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 345 | 346 | # self.unique = unique 347 | 348 | # if self.unique: 349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 350 | 351 | # def sample(self, n_sample, labels): 352 | # pos_sample, new_labels = labels.unique(return_inverse=True) 353 | # 
353 | # n_pos_sample = pos_sample.size(0) 354 | # n_neg_sample = n_sample - n_pos_sample 355 | 356 | # if self.unique: 357 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 359 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 360 | # else: 361 | # sample_dist = self.dist 362 | 363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 364 | 365 | # sample = torch.cat([pos_sample, neg_sample]) 366 | # sample_prob = self.dist[sample] 367 | 368 | # return new_labels, sample, sample_prob 369 | 370 | 371 | if __name__ == '__main__': 372 | S, B = 3, 4 373 | n_vocab = 10000 374 | n_sample = 5 375 | H = 32 376 | 377 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 378 | 379 | # sampler = LogUniformSampler(n_vocab, unique=False) 380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 381 | 382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True) 383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 384 | 385 | # print('true_probs', true_probs.numpy().tolist()) 386 | # print('samp_probs', samp_probs.numpy().tolist()) 387 | # print('neg_samples', neg_samples.numpy().tolist()) 388 | 389 | # print('sum', torch.sum(sampler.dist).item()) 390 | 391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 392 | 393 | embedding = nn.Embedding(n_vocab, H) 394 | bias = torch.zeros(n_vocab) 395 | inputs = torch.Tensor(S, B, H).normal_() 396 | 397 | logits = sample_logits(embedding, bias, labels, inputs, sampler) 398 | print('logits', logits.detach().numpy().tolist()) 399 | print('logits shape', logits.size()) 400 | # sample_logits returns only the logits, shaped [S, B, 1 + number of unique negative samples] 401 | 402 | 403 | -------------------------------------------------------------------------------- /pytorch_pretrained/optimization.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 
50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active for schedules that set warn_t_total = True, e.g. warmup_linear) 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see _LRSchedule 96 | :param t_total: see _LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 
138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 
191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see above). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default: `'warmup_linear'` 194 | b1: Adam's b1. Default: 0.9 195 | b2: Adam's b2. Default: 0.999 196 | e: Adam's epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. 
" 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() 
+ group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /pytorch_pretrained/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. 
" 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 
- beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting 
on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_basic_tokenize=True 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum sequence length; `convert_tokens_to_ids` logs a warning 87 | for longer sequences (nothing is truncated). Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_basic_tokenize=True 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text): 108 | split_tokens = [] 109 | if self.do_basic_tokenize: 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab[token]) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). 
Running this" " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids into wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | def save_vocabulary(self, vocab_path): 138 | """Save the tokenizer vocabulary to a directory or file.""" 139 | index = 0 140 | # vocab_path may be a directory (VOCAB_NAME is written inside it) or a file path 141 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) if os.path.isdir(vocab_path) else vocab_path 142 | with open(vocab_file, "w", encoding="utf-8") as writer: 143 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 144 | if index != token_index: 145 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 146 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 147 | index = token_index 148 | writer.write(token + u'\n') 149 | index += 1 150 | return vocab_file 151 | 152 | @classmethod 153 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 154 | """ 155 | Instantiate a BertTokenizer from a pre-trained vocabulary file. 156 | Download and cache the vocabulary file if needed. 157 | """ 158 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 159 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 160 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 161 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 162 | "`do_lower_case` to False. 
We are setting `do_lower_case=False` for you but " 163 | "you may want to check this behavior.") 164 | kwargs['do_lower_case'] = False 165 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 166 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 167 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 168 | "but you may want to check this behavior.") 169 | kwargs['do_lower_case'] = True 170 | else: 171 | vocab_file = pretrained_model_name_or_path 172 | if os.path.isdir(vocab_file): 173 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 174 | # redirect to the cache, if necessary 175 | try: 176 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 177 | except EnvironmentError: 178 | logger.error( 179 | "Model name '{}' was not found in model name list ({}). " 180 | "We assumed '{}' was a path or url but couldn't find any file " 181 | "associated to this path or url.".format( 182 | pretrained_model_name_or_path, 183 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 184 | vocab_file)) 185 | return None 186 | if resolved_vocab_file == vocab_file: 187 | logger.info("loading vocabulary file {}".format(vocab_file)) 188 | else: 189 | logger.info("loading vocabulary file {} from cache at {}".format( 190 | vocab_file, resolved_vocab_file)) 191 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 192 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer 193 | # than the number of positional embeddings 194 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 195 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 196 | # Instantiate tokenizer. 
197 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 198 | return tokenizer 199 | 200 | 201 | class BasicTokenizer(object): 202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 203 | 204 | def __init__(self, 205 | do_lower_case=True, 206 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 207 | """Constructs a BasicTokenizer. 208 | 209 | Args: 210 | do_lower_case: Whether to lower case the input. 211 | """ 212 | self.do_lower_case = do_lower_case 213 | self.never_split = never_split 214 | 215 | def tokenize(self, text): 216 | """Tokenizes a piece of text.""" 217 | text = self._clean_text(text) 218 | # This was added on November 1st, 2018 for the multilingual and Chinese 219 | # models. This is also applied to the English models now, but it doesn't 220 | # matter since the English models were not trained on any Chinese data 221 | # and generally don't have any Chinese data in them (there are Chinese 222 | # characters in the vocabulary because Wikipedia does have some Chinese 223 | # words in the English Wikipedia.). 
224 | text = self._tokenize_chinese_chars(text) 225 | orig_tokens = whitespace_tokenize(text) 226 | split_tokens = [] 227 | for token in orig_tokens: 228 | if self.do_lower_case and token not in self.never_split: 229 | token = token.lower() 230 | token = self._run_strip_accents(token) 231 | split_tokens.extend(self._run_split_on_punc(token)) 232 | 233 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 234 | return output_tokens 235 | 236 | def _run_strip_accents(self, text): 237 | """Strips accents from a piece of text.""" 238 | text = unicodedata.normalize("NFD", text) 239 | output = [] 240 | for char in text: 241 | cat = unicodedata.category(char) 242 | if cat == "Mn": 243 | continue 244 | output.append(char) 245 | return "".join(output) 246 | 247 | def _run_split_on_punc(self, text): 248 | """Splits punctuation on a piece of text.""" 249 | if text in self.never_split: 250 | return [text] 251 | chars = list(text) 252 | i = 0 253 | start_new_word = True 254 | output = [] 255 | while i < len(chars): 256 | char = chars[i] 257 | if _is_punctuation(char): 258 | output.append([char]) 259 | start_new_word = True 260 | else: 261 | if start_new_word: 262 | output.append([]) 263 | start_new_word = False 264 | output[-1].append(char) 265 | i += 1 266 | 267 | return ["".join(x) for x in output] 268 | 269 | def _tokenize_chinese_chars(self, text): 270 | """Adds whitespace around any CJK character.""" 271 | output = [] 272 | for char in text: 273 | cp = ord(char) 274 | if self._is_chinese_char(cp): 275 | output.append(" ") 276 | output.append(char) 277 | output.append(" ") 278 | else: 279 | output.append(char) 280 | return "".join(output) 281 | 282 | def _is_chinese_char(self, cp): 283 | """Checks whether CP is the codepoint of a CJK character.""" 284 | # This defines a "chinese character" as anything in the CJK Unicode block: 285 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 286 | # 287 | # Note that the CJK Unicode block is NOT all 
Japanese and Korean characters, 288 | # despite its name. The modern Korean Hangul alphabet is a different block, 289 | # as are Japanese Hiragana and Katakana. Those alphabets are used to write 290 | # space-separated words, so they are not treated specially and are handled 291 | # like all of the other languages. 292 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 293 | (cp >= 0x3400 and cp <= 0x4DBF) or # 294 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 295 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 296 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 297 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 298 | (cp >= 0xF900 and cp <= 0xFAFF) or # 299 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 300 | return True 301 | 302 | return False 303 | 304 | def _clean_text(self, text): 305 | """Performs invalid character removal and whitespace cleanup on text.""" 306 | output = [] 307 | for char in text: 308 | cp = ord(char) 309 | if cp == 0 or cp == 0xfffd or _is_control(char): 310 | continue 311 | if _is_whitespace(char): 312 | output.append(" ") 313 | else: 314 | output.append(char) 315 | return "".join(output) 316 | 317 | 318 | class WordpieceTokenizer(object): 319 | """Runs WordPiece tokenization.""" 320 | 321 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 322 | self.vocab = vocab 323 | self.unk_token = unk_token 324 | self.max_input_chars_per_word = max_input_chars_per_word 325 | 326 | def tokenize(self, text): 327 | """Tokenizes a piece of text into its word pieces. 328 | 329 | This uses a greedy longest-match-first algorithm to perform tokenization 330 | using the given vocabulary. 331 | 332 | For example: 333 | input = "unaffable" 334 | output = ["un", "##aff", "##able"] 335 | 336 | Args: 337 | text: A single token or whitespace-separated tokens. This should have 338 | already been passed through `BasicTokenizer`. 339 | 340 | Returns: 341 | A list of wordpiece tokens.
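The greedy longest-match-first loop described in this docstring can be sketched on a single token, outside the class. A minimal sketch, assuming a plain Python `set` as the vocabulary and the `##` continuation prefix used by this tokenizer; `wordpiece` is a hypothetical helper, not part of `pytorch_pretrained`.

```python
from typing import List, Set


def wordpiece(token: str, vocab: Set[str], unk: str = "[UNK]",
              max_chars: int = 100) -> List[str]:
    # Hypothetical standalone version of greedy longest-match-first WordPiece.
    if len(token) > max_chars:
        return [unk]
    pieces = []
    start = 0
    while start < len(token):
        end = len(token)
        match = None
        # Shrink the window from the right until a vocabulary entry matches.
        while start < end:
            piece = token[start:end]
            if start > 0:
                piece = "##" + piece  # non-initial pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # any unmatched span makes the whole word unknown
        pieces.append(match)
        start = end
    return pieces


vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", vocab))  # ['un', '##aff', '##able']
print(wordpiece("xyz", vocab))        # ['[UNK]']
```

Returning a single `[UNK]` when any fragment fails to match mirrors the `is_bad` flag in `WordpieceTokenizer.tokenize`, so one unknown fragment maps the entire word to one unknown token.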
342 | """ 343 | 344 | output_tokens = [] 345 | for token in whitespace_tokenize(text): 346 | chars = list(token) 347 | if len(chars) > self.max_input_chars_per_word: 348 | output_tokens.append(self.unk_token) 349 | continue 350 | 351 | is_bad = False 352 | start = 0 353 | sub_tokens = [] 354 | while start < len(chars): 355 | end = len(chars) 356 | cur_substr = None 357 | while start < end: 358 | substr = "".join(chars[start:end]) 359 | if start > 0: 360 | substr = "##" + substr 361 | if substr in self.vocab: 362 | cur_substr = substr 363 | break 364 | end -= 1 365 | if cur_substr is None: 366 | is_bad = True 367 | break 368 | sub_tokens.append(cur_substr) 369 | start = end 370 | 371 | if is_bad: 372 | output_tokens.append(self.unk_token) 373 | else: 374 | output_tokens.extend(sub_tokens) 375 | return output_tokens 376 | 377 | 378 | def _is_whitespace(char): 379 | """Checks whether `char` is a whitespace character.""" 380 | # \t, \n, and \r are technically control characters but we treat them 381 | # as whitespace since they are generally considered as such. 382 | if char == " " or char == "\t" or char == "\n" or char == "\r": 383 | return True 384 | cat = unicodedata.category(char) 385 | if cat == "Zs": 386 | return True 387 | return False 388 | 389 | 390 | def _is_control(char): 391 | """Checks whether `char` is a control character.""" 392 | # These are technically control characters but we count them as whitespace 393 | # characters. 394 | if char == "\t" or char == "\n" or char == "\r": 395 | return False 396 | cat = unicodedata.category(char) 397 | if cat.startswith("C"): 398 | return True 399 | return False 400 | 401 | 402 | def _is_punctuation(char): 403 | """Checks whether `char` is a punctuation character.""" 404 | cp = ord(char) 405 | # We treat all non-letter/number ASCII as punctuation. 406 | # Characters such as "^", "$", and "`" are not in the Unicode 407 | # Punctuation class but we treat them as punctuation anyway, for 408 | # consistency.
409 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 410 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 411 | return True 412 | cat = unicodedata.category(char) 413 | if cat.startswith("P"): 414 | return True 415 | return False 416 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 
31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | } 41 | PRETRAINED_MERGES_ARCHIVE_MAP = { 42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 43 | } 44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 45 | 'gpt2': 1024, 46 | } 47 | VOCAB_NAME = 'vocab.json' 48 | MERGES_NAME = 'merges.txt' 49 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 50 | 51 | @lru_cache() 52 | def bytes_to_unicode(): 53 | """ 54 | Returns a dict mapping each utf-8 byte to a corresponding unicode string. 55 | The reversible bpe codes work on unicode strings. 56 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 57 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 58 | This is a significant percentage of your normal, say, 32K bpe vocab. 59 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 60 | It also avoids mapping to whitespace/control characters that the bpe code barfs on. 61 | """ 62 | _chr = unichr if sys.version_info[0] == 2 else chr 63 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 64 | cs = bs[:] 65 | n = 0 66 | for b in range(2**8): 67 | if b not in bs: 68 | bs.append(b) 69 | cs.append(2**8+n) 70 | n += 1 71 | cs = [_chr(n) for n in cs] 72 | return dict(zip(bs, cs)) 73 | 74 | def get_pairs(word): 75 | """Return set of symbol pairs in a word. 76 | 77 | Word is represented as a tuple of symbols (symbols being variable-length strings). 78 | """ 79 | pairs = set() 80 | prev_char = word[0] 81 | for char in word[1:]: 82 | pairs.add((prev_char, char)) 83 | prev_char = char 84 | return pairs 85 | 86 | class GPT2Tokenizer(object): 87 | """ 88 | GPT-2 BPE tokenizer.
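The byte-to-unicode table defined in `bytes_to_unicode` is what makes this BPE byte-level: text is first encoded to UTF-8, and every byte is replaced by a printable stand-in character before merging. A minimal standalone sketch of that round trip, assuming Python 3; `byte_level` and `undo_byte_level` are hypothetical helpers, and `bytes_to_unicode` is re-implemented here rather than imported.

```python
from typing import Dict


def bytes_to_unicode() -> Dict[int, str]:
    # Printable Latin-1 bytes map to themselves.
    bs = (list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    # Remaining bytes (space, control characters, etc.) are shifted to code points 256+.
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


def byte_level(text: str) -> str:
    # Encode to UTF-8 first, then map each byte (an int 0..255) to its stand-in character.
    table = bytes_to_unicode()
    return "".join(table[b] for b in text.encode("utf-8"))


def undo_byte_level(text: str) -> str:
    # Invert the table and decode the recovered bytes as UTF-8.
    inverse = {v: k for k, v in bytes_to_unicode().items()}
    return bytes(inverse[c] for c in text).decode("utf-8")


print(byte_level(" hi"))                         # Ġhi
print(len(byte_level("中")))                     # 3 (one stand-in per UTF-8 byte)
print(undo_byte_level(byte_level("中文 text")))  # 中文 text
```

Because the mapping covers all 256 byte values, any UTF-8 input, including Chinese text, can be represented without an unknown token.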
Peculiarities: 89 | - Byte-level BPE 90 | """ 91 | @classmethod 92 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 93 | """ 94 | Instantiate a GPT2Tokenizer from pre-trained vocabulary and merges files. 95 | Download and cache the pre-trained files if needed. 96 | """ 97 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 98 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 99 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 100 | special_tokens_file = None 101 | else: 102 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 103 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 104 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 105 | if not os.path.exists(special_tokens_file): 106 | special_tokens_file = None 107 | else: 108 | logger.info("loading special tokens file {}".format(special_tokens_file)) 109 | # redirect to the cache, if necessary 110 | try: 111 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 112 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 113 | except EnvironmentError: 114 | logger.error( 115 | "Model name '{}' was not found in model name list ({}). "
" 116 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 117 | "at this path or url.".format( 118 | pretrained_model_name_or_path, 119 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 120 | pretrained_model_name_or_path, 121 | vocab_file, merges_file)) 122 | return None 123 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 124 | logger.info("loading vocabulary file {}".format(vocab_file)) 125 | logger.info("loading merges file {}".format(merges_file)) 126 | else: 127 | logger.info("loading vocabulary file {} from cache at {}".format( 128 | vocab_file, resolved_vocab_file)) 129 | logger.info("loading merges file {} from cache at {}".format( 130 | merges_file, resolved_merges_file)) 131 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 132 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer 133 | # than the number of positional embeddings 134 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 135 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 136 | # Instantiate tokenizer.
137 | if special_tokens_file and 'special_tokens' not in kwargs: 138 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 139 | else: 140 | special_tokens = kwargs.pop('special_tokens', []) 141 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 142 | return tokenizer 143 | 144 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 145 | self.max_len = max_len if max_len is not None else int(1e12) 146 | self.encoder = json.load(open(vocab_file, encoding='utf-8')) 147 | self.decoder = {v:k for k,v in self.encoder.items()} 148 | self.errors = errors # how to handle errors in decoding 149 | self.byte_encoder = bytes_to_unicode() 150 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 151 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 152 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 153 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 154 | self.cache = {} 155 | 156 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 157 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 158 | 159 | self.special_tokens = {} 160 | self.special_tokens_decoder = {} 161 | self.set_special_tokens(special_tokens) 162 | 163 | def __len__(self): 164 | return len(self.encoder) + len(self.special_tokens) 165 | 166 | def set_special_tokens(self, special_tokens): 167 | """ Add a list of additional tokens to the encoder. 168 | The additional tokens are indexed starting right after the last index of the 169 | current vocabulary, in the order of the `special_tokens` list.
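The indexing rule in this docstring can be illustrated with a toy vocabulary: the first special token receives id `len(encoder)`, the next one `len(encoder) + 1`, and so on. A minimal sketch mirroring the dictionary comprehension in `set_special_tokens`; `assign_special_ids` and the toy tokens are hypothetical.

```python
from typing import Dict, List, Tuple


def assign_special_ids(encoder: Dict[str, int],
                       special_tokens: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    # Special tokens are numbered from len(encoder) upward, in list order,
    # so they never collide with ids 0..len(encoder)-1 of the base vocabulary.
    base = len(encoder)
    special = {tok: base + i for i, tok in enumerate(special_tokens)}
    # Reverse map used when decoding ids back to tokens.
    decoder = {idx: tok for tok, idx in special.items()}
    return special, decoder


encoder = {"hello": 0, "world": 1, "!": 2}
special, decoder = assign_special_ids(encoder, ["_start_", "_classify_"])
print(special)  # {'_start_': 3, '_classify_': 4}
print(decoder)  # {3: '_start_', 4: '_classify_'}
```

Because the ids are appended past the base vocabulary, the model's embedding matrix must be resized to at least `len(encoder) + len(special_tokens)` rows before these ids are used.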
170 | """ 171 | if not special_tokens: 172 | self.special_tokens = {} 173 | self.special_tokens_decoder = {} 174 | return 175 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 176 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 177 | logger.info("Special tokens {}".format(self.special_tokens)) 178 | 179 | def bpe(self, token): 180 | if token in self.cache: 181 | return self.cache[token] 182 | word = tuple(token) 183 | pairs = get_pairs(word) 184 | 185 | if not pairs: 186 | return token 187 | 188 | while True: 189 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 190 | if bigram not in self.bpe_ranks: 191 | break 192 | first, second = bigram 193 | new_word = [] 194 | i = 0 195 | while i < len(word): 196 | try: 197 | j = word.index(first, i) 198 | new_word.extend(word[i:j]) 199 | i = j 200 | except ValueError: 201 | new_word.extend(word[i:]) 202 | break 203 | 204 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 205 | new_word.append(first+second) 206 | i += 2 207 | else: 208 | new_word.append(word[i]) 209 | i += 1 210 | new_word = tuple(new_word) 211 | word = new_word 212 | if len(word) == 1: 213 | break 214 | else: 215 | pairs = get_pairs(word) 216 | word = ' '.join(word) 217 | self.cache[token] = word 218 | return word 219 | 220 | def tokenize(self, text): 221 | """ Tokenize a string. """ 222 | bpe_tokens = [] 223 | for token in re.findall(self.pat, text): 224 | token = ''.join(self.byte_encoder[b] for b in bytearray(token.encode('utf-8'))) 225 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 226 | return bpe_tokens 227 | 228 | def convert_tokens_to_ids(self, tokens): 229 | """ Converts a sequence of tokens into ids using the vocab.
""" 230 | ids = [] 231 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 232 | if tokens in self.special_tokens: 233 | return self.special_tokens[tokens] 234 | else: 235 | return self.encoder.get(tokens, 0) 236 | for token in tokens: 237 | if token in self.special_tokens: 238 | ids.append(self.special_tokens[token]) 239 | else: 240 | ids.append(self.encoder.get(token, 0)) 241 | if len(ids) > self.max_len: 242 | logger.warning( 243 | "Token indices sequence length is longer than the specified maximum" 244 | " sequence length for this GPT-2 model ({} > {}). Running this" 245 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 246 | ) 247 | return ids 248 | 249 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 250 | """Converts a sequence of ids into BPE tokens using the vocab.""" 251 | tokens = [] 252 | for i in ids: 253 | if i in self.special_tokens_decoder: 254 | if not skip_special_tokens: 255 | tokens.append(self.special_tokens_decoder[i]) 256 | else: 257 | tokens.append(self.decoder[i]) 258 | return tokens 259 | 260 | def encode(self, text): 261 | return self.convert_tokens_to_ids(self.tokenize(text)) 262 | 263 | def decode(self, tokens): 264 | text = ''.join([self.decoder[token] for token in tokens]) 265 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 266 | return text 267 | 268 | def save_vocabulary(self, vocab_path): 269 | """Save the tokenizer vocabulary and merge files to a directory.""" 270 | if not os.path.isdir(vocab_path): 271 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 272 | return 273 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 274 | merge_file = os.path.join(vocab_path, MERGES_NAME) 275 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 276 | 277 | with open(vocab_file, 'w', encoding='utf-8') as f: 278 | f.write(json.dumps(self.encoder,
ensure_ascii=False)) 279 | 280 | index = 0 281 | with open(merge_file, "w", encoding="utf-8") as writer: 282 | writer.write(u'#version: 0.2\n') 283 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 284 | if index != token_index: 285 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 286 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 287 | index = token_index 288 | writer.write(' '.join(bpe_tokens) + u'\n') 289 | index += 1 290 | 291 | index = len(self.encoder) 292 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 293 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 294 | if index != token_index: 295 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 296 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 297 | index = token_index 298 | writer.write(token + u'\n') 299 | index += 1 300 | 301 | return vocab_file, merge_file, special_tokens_file 302 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 45 | 46 | def get_pairs(word): 47 | """ 48 | Return set of symbol pairs in a word. 49 | word is represented as tuple of symbols (symbols being variable-length strings) 50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | def text_standardize(text): 59 | """ 60 | fixes some issues the spacy tokenizer had on books corpus 61 | also does some whitespace standardization 62 | """ 63 | text = text.replace('—', '-') 64 | text = text.replace('–', '-') 65 | text = text.replace('―', '-') 66 | text = text.replace('…', '...') 67 | text = text.replace('´', "'") 68 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 69 | text = re.sub(r'\s*\n\s*', ' \n ', text) 70 | text = re.sub(r'[^\S\n]+', ' ', text) 71 | return text.strip() 72 | 73 | class OpenAIGPTTokenizer(object): 74 | """ 75 | BPE tokenizer. 
Peculiarities: 76 | - lowercase all inputs 77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fall back to BERT's BasicTokenizer if not. 78 | - argument special_tokens and function set_special_tokens: 79 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 80 | """ 81 | @classmethod 82 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 83 | """ 84 | Instantiate an OpenAIGPTTokenizer from pre-trained vocabulary and merges files. 85 | Download and cache the pre-trained files if needed. 86 | """ 87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 90 | special_tokens_file = None 91 | else: 92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 95 | if not os.path.exists(special_tokens_file): 96 | special_tokens_file = None 97 | else: 98 | logger.info("loading special tokens file {}".format(special_tokens_file)) 99 | # redirect to the cache, if necessary 100 | try: 101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 103 | except EnvironmentError: 104 | logger.error( 105 | "Model name '{}' was not found in model name list ({}). "
" 106 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 107 | "at this path or url.".format( 108 | pretrained_model_name_or_path, 109 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 110 | pretrained_model_name_or_path, 111 | vocab_file, merges_file)) 112 | return None 113 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 114 | logger.info("loading vocabulary file {}".format(vocab_file)) 115 | logger.info("loading merges file {}".format(merges_file)) 116 | else: 117 | logger.info("loading vocabulary file {} from cache at {}".format( 118 | vocab_file, resolved_vocab_file)) 119 | logger.info("loading merges file {} from cache at {}".format( 120 | merges_file, resolved_merges_file)) 121 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 122 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer 123 | # than the number of positional embeddings 124 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 125 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 126 | # Instantiate tokenizer.
127 | if special_tokens_file and 'special_tokens' not in kwargs: 128 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 129 | else: 130 | special_tokens = kwargs.pop('special_tokens', []) 131 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 132 | return tokenizer 133 | 134 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 135 | try: 136 | import ftfy 137 | import spacy 138 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 139 | self.fix_text = ftfy.fix_text 140 | except ImportError: 141 | logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.") 142 | self.nlp = BasicTokenizer(do_lower_case=True, 143 | never_split=special_tokens if special_tokens is not None else []) 144 | self.fix_text = None 145 | 146 | self.max_len = max_len if max_len is not None else int(1e12) 147 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 148 | self.decoder = {v:k for k,v in self.encoder.items()} 149 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 150 | merges = [tuple(merge.split()) for merge in merges] 151 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 152 | self.cache = {} 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | self.set_special_tokens(special_tokens) 156 | 157 | def __len__(self): 158 | return len(self.encoder) + len(self.special_tokens) 159 | 160 | def set_special_tokens(self, special_tokens): 161 | """ Add a list of additional tokens to the encoder. 162 | The additional tokens are indexed starting right after the last index of the 163 | current vocabulary, in the order of the `special_tokens` list.
164 | """ 165 | if not special_tokens: 166 | self.special_tokens = {} 167 | self.special_tokens_decoder = {} 168 | return 169 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 170 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 171 | if self.fix_text is None: 172 | # Using BERT's BasicTokenizer: we can update the tokenizer 173 | self.nlp.never_split = special_tokens 174 | logger.info("Special tokens {}".format(self.special_tokens)) 175 | 176 | def bpe(self, token): 177 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 178 | if token in self.cache: 179 | return self.cache[token] 180 | pairs = get_pairs(word) 181 | 182 | if not pairs: 183 | return token+'</w>' 184 | 185 | while True: 186 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 187 | if bigram not in self.bpe_ranks: 188 | break 189 | first, second = bigram 190 | new_word = [] 191 | i = 0 192 | while i < len(word): 193 | try: 194 | j = word.index(first, i) 195 | new_word.extend(word[i:j]) 196 | i = j 197 | except ValueError: 198 | new_word.extend(word[i:]) 199 | break 200 | 201 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 202 | new_word.append(first+second) 203 | i += 2 204 | else: 205 | new_word.append(word[i]) 206 | i += 1 207 | new_word = tuple(new_word) 208 | word = new_word 209 | if len(word) == 1: 210 | break 211 | else: 212 | pairs = get_pairs(word) 213 | word = ' '.join(word) 214 | if word == '\n  </w>': 215 | word = '\n</w>' 216 | self.cache[token] = word 217 | return word 218 | 219 | def tokenize(self, text): 220 | """ Tokenize a string.
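The merge loop in `bpe` can be sketched on its own with a toy merge table. A minimal sketch, assuming merges are ranked by their position in a list (lower rank merges first) and that the last symbol carries an `</w>` end-of-word marker, the convention of the original OpenAI GPT BPE vocabulary; `bpe_merge` and the merge table are hypothetical, and the sketch omits the cache and the special handling of newlines.

```python
from typing import Dict, List, Tuple


def get_pairs(word: Tuple[str, ...]) -> set:
    # Adjacent symbol pairs in the current segmentation.
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}


def bpe_merge(token: str, ranks: Dict[Tuple[str, str], int]) -> List[str]:
    # Start from single characters; mark the end of the word on the last one.
    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    while len(word) > 1:
        pairs = get_pairs(word)
        # Pick the pair that was learned earliest (lowest rank).
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no learned merge applies any more
        first, second = best
        merged = []
        i = 0
        while i < len(word):
            # Merge every non-overlapping occurrence of the best pair.
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return list(word)


merges = [("l", "o"), ("lo", "w</w>"), ("e", "r</w>")]
ranks = {pair: rank for rank, pair in enumerate(merges)}
print(bpe_merge("low", ranks))    # ['low</w>']
print(bpe_merge("lower", ranks))  # ['lo', 'w', 'er</w>']
```

The `</w>` suffix lets the vocabulary distinguish a word-final piece (`er</w>`) from the same characters inside a word, which is why decoding replaces the marker with a space.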
""" 221 | split_tokens = [] 222 | if self.fix_text is None: 223 | # Using BERT's BasicTokenizer 224 | text = self.nlp.tokenize(text) 225 | for token in text: 226 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 227 | else: 228 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 229 | text = self.nlp(text_standardize(self.fix_text(text))) 230 | for token in text: 231 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 232 | return split_tokens 233 | 234 | def convert_tokens_to_ids(self, tokens): 235 | """ Converts a sequence of tokens into ids using the vocab. """ 236 | ids = [] 237 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 238 | if tokens in self.special_tokens: 239 | return self.special_tokens[tokens] 240 | else: 241 | return self.encoder.get(tokens, 0) 242 | for token in tokens: 243 | if token in self.special_tokens: 244 | ids.append(self.special_tokens[token]) 245 | else: 246 | ids.append(self.encoder.get(token, 0)) 247 | if len(ids) > self.max_len: 248 | logger.warning( 249 | "Token indices sequence length is longer than the specified maximum" 250 | " sequence length for this OpenAI GPT model ({} > {}).
Running this" 251 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 252 | ) 253 | return ids 254 | 255 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 256 | """Converts a sequence of ids into BPE tokens using the vocab.""" 257 | tokens = [] 258 | for i in ids: 259 | if i in self.special_tokens_decoder: 260 | if not skip_special_tokens: 261 | tokens.append(self.special_tokens_decoder[i]) 262 | else: 263 | tokens.append(self.decoder[i]) 264 | return tokens 265 | 266 | def encode(self, text): 267 | return self.convert_tokens_to_ids(self.tokenize(text)) 268 | 269 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 270 | """Converts a sequence of ids into a string.""" 271 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 272 | out_string = ''.join(tokens).replace('</w>', ' ').strip() 273 | if clean_up_tokenization_spaces: 274 | out_string = out_string.replace('<unk>', '') 275 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 276 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 277 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 278 | return out_string 279 | 280 | def save_vocabulary(self, vocab_path): 281 | """Save the tokenizer vocabulary and merge files to a directory.""" 282 | if not os.path.isdir(vocab_path): 283 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 284 | return 285 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 286 | merge_file = os.path.join(vocab_path, MERGES_NAME) 287 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 288 | 289 | with open(vocab_file, 'w', encoding='utf-8') as f: 290 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 291 | 292 | index = 0 293 | with open(merge_file, "w", encoding="utf-8") as writer: 294 |
writer.write(u'#version: 0.2\n') 295 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 296 | if index != token_index: 297 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 298 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 299 | index = token_index 300 | writer.write(' '.join(bpe_tokens) + u'\n') 301 | index += 1 302 | 303 | index = len(self.encoder) 304 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 305 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 306 | if index != token_index: 307 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 308 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 309 | index = token_index 310 | writer.write(token + u'\n') 311 | index += 1 312 | 313 | return vocab_file, merge_file, special_tokens_file 314 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Tokenization classes for Transformer XL model. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | from __future__ import (absolute_import, division, print_function, 20 | unicode_literals) 21 | 22 | import glob 23 | import logging 24 | import os 25 | import sys 26 | from collections import Counter, OrderedDict 27 | from io import open 28 | import unicodedata 29 | 30 | import torch 31 | import numpy as np 32 | 33 | from .file_utils import cached_path 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", 45 | } 46 | VOCAB_NAME = 'vocab.bin' 47 | 48 | PRETRAINED_CORPUS_ARCHIVE_MAP = { 49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", 50 | } 51 | CORPUS_NAME = 'corpus.bin' 52 | 53 | class TransfoXLTokenizer(object): 54 | """ 55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 56 | """ 57 | @classmethod 58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 59 | """ 60 | Instantiate a TransfoXLTokenizer from a pre-trained vocabulary file. 61 | Download and cache the vocabulary file if needed. 62 | """ 63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 65 | else: 66 | if os.path.isdir(pretrained_model_name_or_path): 67 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 68 | else: 69 | vocab_file = pretrained_model_name_or_path 70 | # redirect to the cache, if necessary 71 | try: 72 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 73 | except EnvironmentError: 74 | logger.error( 75 | "Model name '{}' was not found in model name list ({}). "
" 76 | "We assumed '{}' was a path or url but couldn't find files {} " 77 | "at this path or url.".format( 78 | pretrained_model_name_or_path, 79 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 80 | pretrained_model_name_or_path, 81 | vocab_file)) 82 | return None 83 | if resolved_vocab_file == vocab_file: 84 | logger.info("loading vocabulary file {}".format(vocab_file)) 85 | else: 86 | logger.info("loading vocabulary file {} from cache at {}".format( 87 | vocab_file, resolved_vocab_file)) 88 | 89 | # Instantiate tokenizer. 90 | tokenizer = cls(*inputs, **kwargs) 91 | vocab_dict = torch.load(resolved_vocab_file) 92 | for key, value in vocab_dict.items(): 93 | tokenizer.__dict__[key] = value 94 | return tokenizer 95 | 96 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, 97 | delimiter=None, vocab_file=None, never_split=("", "", "")): 98 | self.counter = Counter() 99 | self.special = special 100 | self.min_freq = min_freq 101 | self.max_size = max_size 102 | self.lower_case = lower_case 103 | self.delimiter = delimiter 104 | self.vocab_file = vocab_file 105 | self.never_split = never_split 106 | 107 | def count_file(self, path, verbose=False, add_eos=False): 108 | if verbose: print('counting file {} ...'.format(path)) 109 | assert os.path.exists(path) 110 | 111 | sents = [] 112 | with open(path, 'r', encoding='utf-8') as f: 113 | for idx, line in enumerate(f): 114 | if verbose and idx > 0 and idx % 500000 == 0: 115 | print(' line {}'.format(idx)) 116 | symbols = self.tokenize(line, add_eos=add_eos) 117 | self.counter.update(symbols) 118 | sents.append(symbols) 119 | 120 | return sents 121 | 122 | def count_sents(self, sents, verbose=False): 123 | """ 124 | sents : a list of sentences, each a list of tokenized symbols 125 | """ 126 | if verbose: print('counting {} sents ...'.format(len(sents))) 127 | for idx, symbols in enumerate(sents): 128 | if verbose and idx > 0 and idx % 500000 == 0: 129 | print(' line {}'.format(idx)) 130 | 
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        if '<UNK>' in self.sym2idx:
            self.unk_idx = self.sym2idx['<UNK>']
        elif '<unk>' in self.sym2idx:
            self.unk_idx = self.sym2idx['<unk>']
        else:
            raise ValueError('No <unk> token in vocabulary')

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        else:
            vocab_file = vocab_path
        torch.save(self.__dict__, vocab_file)
        return vocab_file

    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq: break
                self.add_symbol(sym)

            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                    add_double_eos=False):
        if verbose: print('encoding file {} ...'.format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print(' line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos,
                                        add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print(' line {}'.format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            # print('encounter unk {}'.format(sym))
            # assert '<eos>' not in sym
            if hasattr(self, 'unk_idx'):
                return self.sym2idx.get(sym, self.unk_idx)
            # Backward compatibility with pre-trained models
            elif '<unk>' in self.sym2idx:
                return self.sym2idx['<unk>']
            elif '<UNK>' in self.sym2idx:
                return self.sym2idx['<UNK>']
            else:
                raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

    def convert_ids_to_tokens(self, indices):
        """Converts a sequence of indices into symbols using the vocab."""
        return [self.get_sym(idx) for idx in indices]

    def convert_tokens_to_ids(self, symbols):
        """Converts a sequence of symbols into ids using the vocab."""
        return [self.get_idx(sym) for sym in symbols]
    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    def decode(self, indices, exclude=None):
        """Converts a sequence of indices into a string."""
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False, so compare explicitly;
        # split into a list of characters so the eos tokens below can be appended
        if self.delimiter == '':
            symbols = list(line)
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos: # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols


class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
        """
            data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches.
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        if bptt is None: bptt = self.bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i+1:i+1+seq_len]

        data_out = data.transpose(0, 1).contiguous().to(self.device)
        target_out = target.transpose(0, 1).contiguous().to(self.device)

        return data_out, target_out, seq_len

    def get_fixlen_iter(self, start=0):
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()


class LMShuffledIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
        """
            data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
            else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
                            streams[i][:n_new]
                        target[n_filled:n_filled+n_new, i] = \
                            streams[i][1:n_new+1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data_out = data.transpose(0, 1).contiguous().to(self.device)
            target_out = target.transpose(0, 1).contiguous().to(self.device)

            yield data_out, target_out, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch


class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
                 shuffle=False):

        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class TransfoXLCorpus(object):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
        """
        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Corpus '{}' was not found in corpus list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    corpus_file))
            return None
        if resolved_corpus_file == corpus_file:
            logger.info("loading corpus file {}".format(corpus_file))
        else:
            logger.info("loading corpus file {} from cache at {}".format(
                corpus_file, resolved_corpus_file))

        # Instantiate corpus.
        corpus = cls(*inputs, **kwargs)
        corpus_dict = torch.load(resolved_corpus_file)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        if corpus.train is not None:
            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
        if corpus.valid is not None:
            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
        if corpus.test is not None:
            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
        return corpus

    def __init__(self, *args, **kwargs):
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None

    def build_corpus(self, path, dataset):
        self.dataset = dataset

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True,
                add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter


def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, 'cache.pt')
    fn_pickle = os.path.join(datadir, 'cache.pkl')
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        print('Loading cached dataset from pickle...')
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            pass

        # TransfoXLCorpus.__init__ only forwards kwargs to the tokenizer;
        # the data itself is counted and encoded by build_corpus().
        corpus = TransfoXLCorpus(**kwargs)
        corpus.build_corpus(datadir, dataset)
        torch.save(corpus, fn)

    return corpus

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif

parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True,
                    help='choose a model: bert, bert_CNN, bert_RNN, bert_RCNN, bert_DPCNN, ERNIE')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # dataset directory

    model_name = args.model  # e.g. bert
    x = import_module('models.'
                      + model_name)
    config = x.Config(dataset)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible across runs

    start_time = time.time()
    print("Loading data...")
    train_data, dev_data, test_data = build_dataset(config)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    model = x.Model(config).to(config.device)
    train(config, model, train_iter, dev_iter, test_iter)

--------------------------------------------------------------------------------
/train_eval.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from pytorch_pretrained.optimization import BertAdam


# Weight initialization (xavier by default)
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                # xavier/kaiming need at least 2 dims; skip 1-D weights (e.g. LayerNorm)
                if w.dim() < 2:
                    continue
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass


def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=0.05,
                         t_total=len(train_iter) * config.num_epochs)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index at which the validation loss last improved
    flag = False  # set when there has been no improvement for too long
    model.train()
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # every 100 batches, report performance on the training and validation sets
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # stop training if the validation loss has not decreased
                # for more than config.require_improvement batches
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, model, test_iter)


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# coding: UTF-8
import torch
from tqdm import tqdm
import time
from datetime import timedelta

PAD, CLS = '[PAD]', '[CLS]'  # padding token; BERT's [CLS] token, whose output summarizes the sequence


def build_dataset(config):

    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, label = lin.split('\t')
                token = config.tokenizer.tokenize(content)
                token = [CLS] + token
                seq_len = len(token)
                mask = []
                token_ids = config.tokenizer.convert_tokens_to_ids(token)

                if pad_size:
                    if len(token) < pad_size:
                        mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                        token_ids += ([0] * (pad_size - len(token)))
                    else:
                        mask = [1] * pad_size
                        token_ids = token_ids[:pad_size]
                        seq_len = pad_size
                contents.append((token_ids, int(label), seq_len, mask))
        return contents
    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return train, dev, test


class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # True if the last batch is partial (sample count not a multiple of batch_size)
        if len(batches) % self.batch_size != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)

        # length before padding (capped at pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index >= self.n_batches:
            # all full batches (and the residue batch, if any) have been returned
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def build_iterator(dataset, config):
    data_iter = DatasetIterater(dataset, config.batch_size, config.device)
    return data_iter


def get_time_dif(start_time):
    """Return the elapsed time since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

--------------------------------------------------------------------------------
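The padding and batching rules in `utils.py` can be checked in isolation. The sketch below is a plain-Python restatement (no torch) under stated assumptions: `pad_or_truncate` and `batch_slices` are hypothetical helper names, not part of this repository; `pad_or_truncate` mirrors the padding/truncation branch of `load_dataset` (pad id 0 assumed, as used there), and `batch_slices` mirrors how `DatasetIterater` splits the samples into full batches plus one partial residue batch.

```python
# Hypothetical helpers (not in the repository) restating the padding and
# batching logic of utils.py in plain Python, without torch.


def pad_or_truncate(token_ids, pad_size, pad_id=0):
    """Pad short sequences with pad_id and truncate long ones to pad_size.

    Returns (ids, seq_len, mask): seq_len is the unpadded length capped at
    pad_size, and mask marks real tokens with 1 and padding with 0.
    """
    seq_len = len(token_ids)
    if seq_len < pad_size:
        mask = [1] * seq_len + [0] * (pad_size - seq_len)
        ids = token_ids + [pad_id] * (pad_size - seq_len)
    else:
        mask = [1] * pad_size
        ids = token_ids[:pad_size]
        seq_len = pad_size
    return ids, seq_len, mask


def batch_slices(n_samples, batch_size):
    """Return (start, end) index pairs: every full batch, then one residue
    batch if n_samples is not a multiple of batch_size."""
    n_full = n_samples // batch_size
    slices = [(i * batch_size, (i + 1) * batch_size) for i in range(n_full)]
    if n_samples % batch_size != 0:
        slices.append((n_full * batch_size, n_samples))
    return slices


if __name__ == "__main__":
    print(pad_or_truncate([1, 2, 3], pad_size=5))  # ([1, 2, 3, 0, 0], 3, [1, 1, 1, 0, 0])
    print(batch_slices(66, 32))                    # [(0, 32), (32, 64), (64, 66)]
```

With 66 samples and `batch_size=32`, this yields two full batches plus a residue batch of 2 samples, which is the case where `len(DatasetIterater)` returns `n_batches + 1`.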