├── .gitignore ├── ERNIE_pretrain └── README.md ├── README.md ├── THUCNews ├── data │ ├── class.txt │ ├── dev.txt │ ├── embedding_SougouNews.npz │ ├── embedding_Tencent.npz │ ├── test.txt │ ├── train.txt │ └── vocab.pkl └── saved_dict │ └── README.md ├── bert_pretrain └── README.md ├── models ├── DPCNN.py ├── ERNIE.py ├── FastText.py ├── TextCNN.py ├── TextRCNN.py ├── TextRNN.py ├── TextRNN_Att.py ├── Transformer.py ├── bert.py ├── bert_CNN.py ├── bert_DPCNN.py ├── bert_RCNN.py └── bert_RNN.py ├── predict.py ├── pretrain_eval.py ├── pretrain_predict.py ├── pretrain_run.py ├── pretrain_utils.py ├── pytorch_pretrained ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── optimization.py ├── optimization_openai.py ├── tokenization.py ├── tokenization_gpt2.py ├── tokenization_openai.py └── tokenization_transfo_xl.py ├── run.py ├── train_eval.py ├── utils.py └── utils_fasttext.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | THUCNews/log/ 4 | THUCNews/saved_dict/*.ckpt 5 | 6 | bert_pretrain/*.bin 7 | bert_pretrain/*.json 8 | bert_pretrain/*.text 9 | 10 | ERNIE_pretrain/*.bin 11 | ERNIE_pretrain/*.json 12 | ERNIE_pretrain/*.text 13 | -------------------------------------------------------------------------------- /ERNIE_pretrain/README.md: -------------------------------------------------------------------------------- 1 | ## 此处存放ERNIE预训练模型: 2 | pytorch_model.bin 3 | bert_config.json 4 | vocab.txt 5 | 6 | ## 下载地址: 7 | http://image.nghuyong.top/ERNIE.zip -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chinese-Text-Classification 2 | 3 | 中文文本分类,基于pytorch,开箱即用。 4 | 5 | - 神经网络模型:TextCNN,TextRNN,FastText,TextRCNN,BiLSTM_Attention, DPCNN, Transformer 6 | 7 | - 预训练模型:Bert,ERNIE 8 | 9 | 10 | 11 | ## 介绍 12 | 13 | ### 神经网络模型 14 | 15 | 模型介绍、数据流动过程:[参考](https://zhuanlan.zhihu.com/p/73176084) 16 | 17 | 数据以字为单位输入模型,预训练词向量使用 [搜狗新闻 Word+Character 300d](https://github.com/Embedding/Chinese-Word-Vectors),[点这里下载](https://pan.baidu.com/s/14k-9jsspp43ZhMxqPmsWMQ) 18 | 19 | | 模型 | 介绍 | 20 | | ----------- | --------------------------------- | 21 | | TextCNN | Kim 2014 经典的CNN文本分类 | 22 | | TextRNN | BiLSTM | 23 | | TextRNN_Att | BiLSTM+Attention | 24 | | TextRCNN | BiLSTM+池化 | 25 | | FastText | bow+bigram+trigram, 效果出奇的好 | 26 | | DPCNN | 深层金字塔CNN | 27 | | Transformer | 效果较差 | 28 | 29 | ### 预训练模型 30 | 31 | | 模型 | 介绍 | 备注 | 32 | | ---------- | ------------------------------------------------------------ | ------------ | 33 | | bert | 原始的bert | | 34 | | ERNIE | ERNIE | | 35 | | bert_CNN | bert作为Embedding层,接入三种卷积核的CNN | bert + CNN | 36 | | bert_RNN | bert作为Embedding层,接入LSTM | bert + RNN | 37 | | bert_RCNN | bert作为Embedding层,通过LSTM与bert输出拼接,经过一层最大池化层 | bert + RCNN | 38 | | bert_DPCNN | bert作为Embedding层,经过一个包含三个不同卷积特征提取器的region embedding层,可以看作输出的是embedding,然后经过两层的等长卷积来为接下来的特征抽取提供更宽的感受眼,(提高embdding的丰富性),然后会重复通过一个1/2池化的残差块,1/2池化不断提高词位的语义,其中固定了feature_maps,残差网络的引入是为了解决在训练的过程中梯度消失和梯度爆炸的问题。 | bert + DPCNN | 39 | 40 | 参考: 41 | 42 | - [ERNIE - 详解](https://baijiahao.baidu.com/s?id=1648169054540877476) 43 | - [DPCNN 
模型详解](https://zhuanlan.zhihu.com/p/372904980) 44 | - [从经典文本分类模型TextCNN到深度模型DPCNN](https://zhuanlan.zhihu.com/p/35457093) 45 | 46 | ## 环境 47 | python 3.7 48 | pytorch 1.1 49 | tqdm 50 | sklearn 51 | tensorboardX 52 | ~~pytorch_pretrained_bert~~(预训练代码也上传了, 不需要这个库了) 53 | 54 | 55 | ## 中文数据集 56 | 我从[THUCNews](http://thuctc.thunlp.org/)中抽取了20万条新闻标题,已上传至github,文本长度在20到30之间。一共10个类别,每类2万条。数据以字为单位输入模型。 57 | 58 | 类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。 59 | 60 | 数据集划分: 61 | 62 | 数据集|数据量 63 | --|-- 64 | 训练集|18万 65 | 验证集|1万 66 | 测试集|1万 67 | 68 | 69 | ### 更换数据集 70 | - 按照THUCNews数据集的格式来格式化自己的中文数据集。 71 | - 对于神经网络模型: 72 | - 如果用字,按照数据集的格式来格式化你的数据。 73 | - 如果用词,提前分好词,词之间用空格隔开,`python run.py --model TextCNN --word True` 74 | - 使用预训练词向量:utils.py的main函数可以提取词表对应的预训练词向量。 75 | 76 | 77 | ## 实验效果 78 | 79 | 机器:一块2080Ti , 训练时间:30分钟。 80 | 81 | 模型|acc|备注 82 | --|--|-- 83 | TextCNN|91.22%|Kim 2014 经典的CNN文本分类 84 | TextRNN|91.12%|BiLSTM 85 | TextRNN_Att|90.90%|BiLSTM+Attention 86 | TextRCNN|91.54%|BiLSTM+池化 87 | FastText|92.23%|bow+bigram+trigram, 效果出奇的好 88 | DPCNN|91.25%|深层金字塔CNN 89 | Transformer|89.91%|效果较差 90 | bert|94.83%|单纯的bert 91 | ERNIE|94.61%|说好的中文碾压bert呢 92 | bert_CNN|94.44%|bert + CNN 93 | bert_RNN|94.57%|bert + RNN 94 | bert_RCNN|94.51%|bert + RCNN 95 | bert_DPCNN|94.47%|bert + DPCNN 96 | 97 | 原始的bert效果就很好了,把bert当作embedding层送入其它模型,效果反而降了,之后会尝试长文本的效果对比。 98 | 99 | ## 预训练语言模型 100 | bert模型放在 bert_pretain目录下,ERNIE模型放在ERNIE_pretrain目录下,每个目录下都是三个文件: 101 | - pytorch_model.bin 102 | - bert_config.json 103 | - vocab.txt 104 | 105 | 预训练模型下载地址: 106 | 107 | bert_Chinese: 模型 https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz 108 | 词表 https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt 109 | 110 | 来自[这里](https://github.com/huggingface/pytorch-transformers) 111 | 112 | 备用:模型的网盘地址:https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw 113 | 114 | ERNIE_Chinese: https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw 115 | 116 | 来自[这里](https://github.com/nghuyong/ERNIE-Pytorch) 117 | 118 | 解压后,按照上面说的放在对应目录下,文件名称确认无误即可。 119 | 120 | ## 使用说明 121 | 122 | ### 神经网络方法 123 | 124 | ``` 125 | # 训练并测试: 126 | # TextCNN 127 | python run.py --model TextCNN 128 | 129 | # TextRNN 130 | python run.py --model TextRNN 131 | 132 | # TextRNN_Att 133 | python run.py --model TextRNN_Att 134 | 135 | # TextRCNN 136 | python run.py --model TextRCNN 137 | 138 | # FastText, embedding层是随机初始化的 139 | python run.py --model FastText --embedding random 140 | 141 | # DPCNN 142 | python run.py --model DPCNN 143 | 144 | # Transformer 145 | python run.py --model Transformer 146 | ``` 147 | 148 | ### 预训练方法 149 | 150 | 下载好预训练模型就可以跑了: 151 | ``` 152 | # 预训练模型训练并测试: 153 | # bert 154 | python pretrain_run.py --model bert 155 | 156 | # bert + 其它 157 | python pretrain_run.py --model bert_CNN 158 | 159 | # ERNIE 160 | python pretrain_run.py --model ERNIE 161 | ``` 162 | 163 | ### 预测 164 | 165 | 预训练模型: 166 | 167 | ``` 168 | python pretrain_predict.py 169 | ``` 170 | 171 | 神经网络模型: 172 | 173 | ``` 174 | python predict.py 175 | ``` 176 | 177 | 178 | ### 参数 179 | 模型都在models目录下,超参定义和模型定义在同一文件中。 180 | 181 | ## 参考 182 | 183 | ### 论文 184 | 185 | [1] Convolutional Neural Networks for Sentence Classification 186 | 187 | [2] Recurrent Neural Network for Text Classification with Multi-Task Learning 188 | 189 | [3] Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification 190 | 191 | [4] Recurrent Convolutional Neural Networks for Text Classification 192 | 193 | [5] Bag of Tricks for Efficient Text Classification 194 | 195 | 
[6] Deep Pyramid Convolutional Neural Networks for Text Categorization 196 | 197 | [7] Attention Is All You Need 198 | 199 | [8] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 200 | 201 | [9] ERNIE: Enhanced Representation through Knowledge Integration 202 | 203 | ### 仓库 204 | 205 | 本项目基于以下仓库继续开发优化: 206 | 207 | - https://github.com/649453932/Chinese-Text-Classification-Pytorch 208 | - https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch 209 | 210 | -------------------------------------------------------------------------------- /THUCNews/data/class.txt: -------------------------------------------------------------------------------- 1 | finance 2 | realty 3 | stocks 4 | education 5 | science 6 | society 7 | politics 8 | sports 9 | game 10 | entertainment -------------------------------------------------------------------------------- /THUCNews/data/embedding_SougouNews.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/fcba006841db6f170f9a3cf56bf7b038c9eeb51c/THUCNews/data/embedding_SougouNews.npz -------------------------------------------------------------------------------- /THUCNews/data/embedding_Tencent.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/fcba006841db6f170f9a3cf56bf7b038c9eeb51c/THUCNews/data/embedding_Tencent.npz -------------------------------------------------------------------------------- /THUCNews/data/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/fcba006841db6f170f9a3cf56bf7b038c9eeb51c/THUCNews/data/vocab.pkl -------------------------------------------------------------------------------- /THUCNews/saved_dict/README.md: -------------------------------------------------------------------------------- 1 | 该文件夹存放训练的模型 -------------------------------------------------------------------------------- /bert_pretrain/README.md: -------------------------------------------------------------------------------- 1 | ## 此处存放bert预训练模型: 2 | pytorch_model.bin 3 | bert_config.json 4 | vocab.txt 5 | 6 | ## 下载地址: 7 | https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz -------------------------------------------------------------------------------- /models/DPCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'DPCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.num_filters = 250 # 卷积核数量(channels数) 37 | 38 | 39 | '''Deep Pyramid Convolutional Neural Networks for Text Categorization''' 40 | 41 | 42 | class Model(nn.Module): 43 | def __init__(self, config): 44 | super(Model, self).__init__() 45 | if config.embedding_pretrained is not None: 46 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 47 | else: 48 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 49 | self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1) 50 | self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1) 51 | self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2) 52 | self.padding1 = nn.ZeroPad2d((0, 0, 1, 1)) # top bottom 53 | self.padding2 = nn.ZeroPad2d((0, 0, 0, 1)) # bottom 54 | self.relu = nn.ReLU() 55 | self.fc = nn.Linear(config.num_filters, config.num_classes) 56 | 57 | def forward(self, x): 58 | x = x[0] 59 | x = self.embedding(x) 60 | x = x.unsqueeze(1) # [batch_size, 250, seq_len, 1] 61 | x = self.conv_region(x) # [batch_size, 250, seq_len-3+1, 1] 62 | 63 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 64 | x = self.relu(x) 65 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 66 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 67 | x = self.relu(x) 68 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 69 | while x.size()[2] > 2: 70 | x = self._block(x) 71 | x = x.squeeze() # [batch_size, num_filters(250)] 72 | x = self.fc(x) 73 | return x 74 | 75 | def _block(self, x): 76 | x = self.padding2(x) 77 | px = self.max_pool(x) 78 | 79 | x = self.padding1(px) 80 | x = F.relu(x) 81 | x = self.conv(x) 82 | 83 | x = self.padding1(x) 84 | x = F.relu(x) 85 | x = self.conv(x) 86 | 87 | # Short Cut 88 | x = x + px 89 | return x 90 | -------------------------------------------------------------------------------- /models/ERNIE.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | # from pytorch_pretrained_bert import BertModel, BertTokenizer 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | class Config(object): 8 | 9 | """配置参数""" 10 | def __init__(self, dataset): 11 | self.model_name = 'ERNIE' 12 | self.train_path = dataset + '/data/train.txt' # 训练集 13 | self.dev_path = dataset + '/data/dev.txt' # 验证集 14 | self.test_path = dataset + '/data/test.txt' # 测试集 15 | self.class_list = [x.strip() for x in open( 16 | dataset + '/data/class.txt').readlines()] # 类别名单 17 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 18 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 19 | 20 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 21 | self.num_classes = len(self.class_list) # 类别数 22 | self.num_epochs = 3 # epoch数 23 | self.batch_size = 128 # mini-batch大小 24 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 25 | self.learning_rate = 
5e-5 # 学习率 26 | self.bert_path = './ERNIE_pretrain' 27 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 28 | print(self.tokenizer) 29 | self.hidden_size = 768 30 | 31 | 32 | class Model(nn.Module): 33 | 34 | def __init__(self, config): 35 | super(Model, self).__init__() 36 | self.bert = BertModel.from_pretrained(config.bert_path) 37 | for param in self.bert.parameters(): 38 | param.requires_grad = True 39 | self.fc = nn.Linear(config.hidden_size, config.num_classes) 40 | 41 | def forward(self, x): 42 | context = x[0] # 输入的句子 43 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 44 | _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 45 | out = self.fc(pooled) 46 | return out 47 | -------------------------------------------------------------------------------- /models/FastText.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'FastText' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.hidden_size = 256 # 隐藏层大小 37 | self.n_gram_vocab = 250499 # ngram 词表大小 38 | 39 | 40 | '''Bag of Tricks for Efficient Text Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.embedding_ngram2 = nn.Embedding(config.n_gram_vocab, config.embed) 51 | self.embedding_ngram3 = nn.Embedding(config.n_gram_vocab, config.embed) 52 | self.dropout = nn.Dropout(config.dropout) 53 | self.fc1 = nn.Linear(config.embed * 3, config.hidden_size) 54 | # self.dropout2 = nn.Dropout(config.dropout) 55 | self.fc2 = nn.Linear(config.hidden_size, config.num_classes) 56 | 57 | def forward(self, x): 58 | 59 | out_word = self.embedding(x[0]) 60 | out_bigram = self.embedding_ngram2(x[2]) 61 | out_trigram = self.embedding_ngram3(x[3]) 62 | out = torch.cat((out_word, out_bigram, out_trigram), -1) 63 | 64 | out = 
out.mean(dim=1) 65 | out = self.dropout(out) 66 | out = self.fc1(out) 67 | out = F.relu(out) 68 | out = self.fc2(out) 69 | return out 70 | -------------------------------------------------------------------------------- /models/TextCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.filter_sizes = (2, 3, 4) # 卷积核尺寸 37 | self.num_filters = 256 # 卷积核数量(channels数) 38 | 39 | 40 | '''Convolutional Neural Networks for Sentence Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.convs = nn.ModuleList( 51 | [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes]) 52 | self.dropout = nn.Dropout(config.dropout) 53 | self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes) 54 | 55 | def conv_and_pool(self, x, conv): 56 | x = F.relu(conv(x)).squeeze(3) 57 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 58 | return x 59 | 60 | def forward(self, x): 61 | out = self.embedding(x[0]) 62 | out = out.unsqueeze(1) 63 | out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1) 64 | out = self.dropout(out) 65 | out = self.fc(out) 66 | return out 67 | -------------------------------------------------------------------------------- /models/TextRCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextRCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 
测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 1.0 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 10 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 36 | self.hidden_size = 256 # lstm隐藏层 37 | self.num_layers = 1 # lstm层数 38 | 39 | 40 | '''Recurrent Convolutional Neural Networks for Text Classification''' 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config): 45 | super(Model, self).__init__() 46 | if config.embedding_pretrained is not None: 47 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 48 | else: 49 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 50 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 51 | bidirectional=True, batch_first=True, dropout=config.dropout) 52 | self.maxpool = nn.MaxPool1d(config.pad_size) 53 | self.fc = nn.Linear(config.hidden_size * 2 + config.embed, config.num_classes) 54 | 55 | def forward(self, x): 56 | x, _ = x 57 | embed = self.embedding(x) # [batch_size, seq_len, embeding]=[64, 32, 64] 58 | out, _ = self.lstm(embed) 59 | out = torch.cat((embed, out), 2) 60 | out = F.relu(out) 61 | out = out.permute(0, 2, 1) 62 | out = self.maxpool(out).squeeze() 63 | out = self.fc(out) 64 | return out 65 | -------------------------------------------------------------------------------- /models/TextRNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | class Config(object): 8 | 9 | """配置参数""" 10 | def __init__(self, dataset, embedding): 11 | self.model_name = 'TextRNN' 12 | self.train_path = dataset + '/data/train.txt' # 训练集 13 | self.dev_path = dataset + '/data/dev.txt' # 验证集 14 | self.test_path = dataset + '/data/test.txt' # 测试集 15 | self.class_list = [x.strip() for x in open( 16 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 17 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.log_path = dataset + '/log/' + self.model_name 20 | self.embedding_pretrained = torch.tensor( 21 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 22 | if embedding != 'random' else None # 预训练词向量 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 24 | 25 | self.dropout = 0.5 # 随机失活 26 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 27 | self.num_classes = len(self.class_list) # 类别数 28 | self.n_vocab = 0 # 词表大小,在运行时赋值 29 | self.num_epochs = 10 # epoch数 
30 | self.batch_size = 128 # mini-batch大小 31 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 32 | self.learning_rate = 1e-3 # 学习率 33 | self.embed = self.embedding_pretrained.size(1)\ 34 | if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 35 | self.hidden_size = 128 # lstm隐藏层 36 | self.num_layers = 2 # lstm层数 37 | 38 | 39 | '''Recurrent Neural Network for Text Classification with Multi-Task Learning''' 40 | 41 | 42 | class Model(nn.Module): 43 | def __init__(self, config): 44 | super(Model, self).__init__() 45 | if config.embedding_pretrained is not None: 46 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 47 | else: 48 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 49 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 50 | bidirectional=True, batch_first=True, dropout=config.dropout) 51 | self.fc = nn.Linear(config.hidden_size * 2, config.num_classes) 52 | 53 | def forward(self, x): 54 | x, _ = x 55 | out = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] 56 | out, _ = self.lstm(out) 57 | out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state 58 | return out 59 | 60 | '''变长RNN,效果差不多,甚至还低了点...''' 61 | # def forward(self, x): 62 | # x, seq_len = x 63 | # out = self.embedding(x) 64 | # _, idx_sort = torch.sort(seq_len, dim=0, descending=True) # 长度从长到短排序(index) 65 | # _, idx_unsort = torch.sort(idx_sort) # 排序后,原序列的 index 66 | # out = torch.index_select(out, 0, idx_sort) 67 | # seq_len = list(seq_len[idx_sort]) 68 | # out = nn.utils.rnn.pack_padded_sequence(out, seq_len, batch_first=True) 69 | # # [batche_size, seq_len, num_directions * hidden_size] 70 | # out, (hn, _) = self.lstm(out) 71 | # out = torch.cat((hn[2], hn[3]), -1) 72 | # # out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True) 73 | # out = out.index_select(0, idx_unsort) 74 | # out = self.fc(out) 75 | # return out 76 | -------------------------------------------------------------------------------- /models/TextRNN_Att.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'TextRNN_Att' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 10 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 1e-3 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | 
if self.embedding_pretrained is not None else 300 # 字向量维度, 若使用了预训练词向量,则维度统一 36 | self.hidden_size = 128 # lstm隐藏层 37 | self.num_layers = 2 # lstm层数 38 | self.hidden_size2 = 64 39 | 40 | 41 | '''Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification''' 42 | 43 | 44 | class Model(nn.Module): 45 | def __init__(self, config): 46 | super(Model, self).__init__() 47 | if config.embedding_pretrained is not None: 48 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 49 | else: 50 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 51 | self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, 52 | bidirectional=True, batch_first=True, dropout=config.dropout) 53 | self.tanh1 = nn.Tanh() 54 | # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2)) 55 | self.w = nn.Parameter(torch.zeros(config.hidden_size * 2)) 56 | self.tanh2 = nn.Tanh() 57 | self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2) 58 | self.fc = nn.Linear(config.hidden_size2, config.num_classes) 59 | 60 | def forward(self, x): 61 | x, _ = x 62 | emb = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] 63 | H, _ = self.lstm(emb) # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256] 64 | 65 | M = self.tanh1(H) # [128, 32, 256] 66 | # M = torch.tanh(torch.matmul(H, self.u)) 67 | alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1) # [128, 32, 1] 68 | out = H * alpha # [128, 32, 256] 69 | out = torch.sum(out, 1) # [128, 256] 70 | out = F.relu(out) 71 | out = self.fc1(out) 72 | out = self.fc(out) # [128, 64] 73 | return out 74 | -------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset, embedding): 12 | self.model_name = 'Transformer' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt', encoding='utf-8').readlines()] # 类别名单 18 | self.vocab_path = dataset + '/data/vocab.pkl' # 词表 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.log_path = dataset + '/log/' + self.model_name 21 | self.embedding_pretrained = torch.tensor( 22 | np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\ 23 | if embedding != 'random' else None # 预训练词向量 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 25 | 26 | self.dropout = 0.5 # 随机失活 27 | self.require_improvement = 2000 # 若超过1000batch效果还没提升,则提前结束训练 28 | self.num_classes = len(self.class_list) # 类别数 29 | self.n_vocab = 0 # 词表大小,在运行时赋值 30 | self.num_epochs = 20 # epoch数 31 | self.batch_size = 128 # mini-batch大小 32 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 33 | self.learning_rate = 5e-4 # 学习率 34 | self.embed = self.embedding_pretrained.size(1)\ 35 | if self.embedding_pretrained is not None else 300 # 字向量维度 36 | self.dim_model = 300 37 | self.hidden = 1024 38 | self.last_hidden = 512 39 | self.num_head = 5 40 | self.num_encoder = 2 41 | 42 | 43 | '''Attention Is All You Need''' 44 | 45 | 46 | 
class Model(nn.Module): 47 | def __init__(self, config): 48 | super(Model, self).__init__() 49 | if config.embedding_pretrained is not None: 50 | self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) 51 | else: 52 | self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) 53 | 54 | self.postion_embedding = Positional_Encoding(config.embed, config.pad_size, config.dropout, config.device) 55 | self.encoder = Encoder(config.dim_model, config.num_head, config.hidden, config.dropout) 56 | self.encoders = nn.ModuleList([ 57 | copy.deepcopy(self.encoder) 58 | # Encoder(config.dim_model, config.num_head, config.hidden, config.dropout) 59 | for _ in range(config.num_encoder)]) 60 | 61 | self.fc1 = nn.Linear(config.pad_size * config.dim_model, config.num_classes) 62 | # self.fc2 = nn.Linear(config.last_hidden, config.num_classes) 63 | # self.fc1 = nn.Linear(config.dim_model, config.num_classes) 64 | 65 | def forward(self, x): 66 | out = self.embedding(x[0]) 67 | out = self.postion_embedding(out) 68 | for encoder in self.encoders: 69 | out = encoder(out) 70 | out = out.view(out.size(0), -1) 71 | # out = torch.mean(out, 1) 72 | out = self.fc1(out) 73 | return out 74 | 75 | 76 | class Encoder(nn.Module): 77 | def __init__(self, dim_model, num_head, hidden, dropout): 78 | super(Encoder, self).__init__() 79 | self.attention = Multi_Head_Attention(dim_model, num_head, dropout) 80 | self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout) 81 | 82 | def forward(self, x): 83 | out = self.attention(x) 84 | out = self.feed_forward(out) 85 | return out 86 | 87 | 88 | class Positional_Encoding(nn.Module): 89 | def __init__(self, embed, pad_size, dropout, device): 90 | super(Positional_Encoding, self).__init__() 91 | self.device = device 92 | self.pe = torch.tensor([[pos / (10000.0 ** (i // 2 * 2.0 / embed)) for i in range(embed)] for pos in range(pad_size)]) 93 | self.pe[:, 0::2] = np.sin(self.pe[:, 0::2]) 94 | self.pe[:, 1::2] = np.cos(self.pe[:, 1::2]) 95 | self.dropout = nn.Dropout(dropout) 96 | 97 | def forward(self, x): 98 | out = x + nn.Parameter(self.pe, requires_grad=False).to(self.device) 99 | out = self.dropout(out) 100 | return out 101 | 102 | 103 | class Scaled_Dot_Product_Attention(nn.Module): 104 | '''Scaled Dot-Product Attention ''' 105 | def __init__(self): 106 | super(Scaled_Dot_Product_Attention, self).__init__() 107 | 108 | def forward(self, Q, K, V, scale=None): 109 | ''' 110 | Args: 111 | Q: [batch_size, len_Q, dim_Q] 112 | K: [batch_size, len_K, dim_K] 113 | V: [batch_size, len_V, dim_V] 114 | scale: 缩放因子 论文为根号dim_K 115 | Return: 116 | self-attention后的张量,以及attention张量 117 | ''' 118 | attention = torch.matmul(Q, K.permute(0, 2, 1)) 119 | if scale: 120 | attention = attention * scale 121 | # if mask: # TODO change this 122 | # attention = attention.masked_fill_(mask == 0, -1e9) 123 | attention = F.softmax(attention, dim=-1) 124 | context = torch.matmul(attention, V) 125 | return context 126 | 127 | 128 | class Multi_Head_Attention(nn.Module): 129 | def __init__(self, dim_model, num_head, dropout=0.0): 130 | super(Multi_Head_Attention, self).__init__() 131 | self.num_head = num_head 132 | assert dim_model % num_head == 0 133 | self.dim_head = dim_model // self.num_head 134 | self.fc_Q = nn.Linear(dim_model, num_head * self.dim_head) 135 | self.fc_K = nn.Linear(dim_model, num_head * self.dim_head) 136 | self.fc_V = nn.Linear(dim_model, num_head * self.dim_head) 137 | self.attention = 
Scaled_Dot_Product_Attention() 138 | self.fc = nn.Linear(num_head * self.dim_head, dim_model) 139 | self.dropout = nn.Dropout(dropout) 140 | self.layer_norm = nn.LayerNorm(dim_model) 141 | 142 | def forward(self, x): 143 | batch_size = x.size(0) 144 | Q = self.fc_Q(x) 145 | K = self.fc_K(x) 146 | V = self.fc_V(x) 147 | Q = Q.view(batch_size * self.num_head, -1, self.dim_head) 148 | K = K.view(batch_size * self.num_head, -1, self.dim_head) 149 | V = V.view(batch_size * self.num_head, -1, self.dim_head) 150 | # if mask: # TODO 151 | # mask = mask.repeat(self.num_head, 1, 1) # TODO change this 152 | scale = K.size(-1) ** -0.5 # 缩放因子 153 | context = self.attention(Q, K, V, scale) 154 | 155 | context = context.view(batch_size, -1, self.dim_head * self.num_head) 156 | out = self.fc(context) 157 | out = self.dropout(out) 158 | out = out + x # 残差连接 159 | out = self.layer_norm(out) 160 | return out 161 | 162 | 163 | class Position_wise_Feed_Forward(nn.Module): 164 | def __init__(self, dim_model, hidden, dropout=0.0): 165 | super(Position_wise_Feed_Forward, self).__init__() 166 | self.fc1 = nn.Linear(dim_model, hidden) 167 | self.fc2 = nn.Linear(hidden, dim_model) 168 | self.dropout = nn.Dropout(dropout) 169 | self.layer_norm = nn.LayerNorm(dim_model) 170 | 171 | def forward(self, x): 172 | out = self.fc1(x) 173 | out = F.relu(out) 174 | out = self.fc2(out) 175 | out = self.dropout(out) 176 | out = out + x # 残差连接 177 | out = self.layer_norm(out) 178 | return out 179 | -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | # from pytorch_pretrained_bert import BertModel, BertTokenizer 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset): 12 | self.model_name = 'bert' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt').readlines()] # 类别名单 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 20 | 21 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 22 | self.num_classes = len(self.class_list) # 类别数 23 | self.num_epochs = 3 # epoch数 24 | self.batch_size = 128 # mini-batch大小 25 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 26 | self.learning_rate = 5e-5 # 学习率 27 | self.bert_path = './bert_pretrain' 28 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 29 | self.hidden_size = 768 30 | 31 | 32 | class Model(nn.Module): 33 | 34 | def __init__(self, config): 35 | super(Model, self).__init__() 36 | self.bert = BertModel.from_pretrained(config.bert_path) 37 | for param in self.bert.parameters(): 38 | param.requires_grad = True 39 | self.fc = nn.Linear(config.hidden_size, config.num_classes) 40 | 41 | def forward(self, x): 42 | context = x[0] # 输入的句子 43 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 44 | _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 45 | out = self.fc(pooled) 46 | return out 47 | -------------------------------------------------------------------------------- /models/bert_CNN.py: 
-------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset): 12 | self.model_name = 'bert_CNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt').readlines()] # 类别名单 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 20 | 21 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 22 | self.num_classes = len(self.class_list) # 类别数 23 | self.num_epochs = 3 # epoch数 24 | self.batch_size = 128 # mini-batch大小 25 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 26 | self.learning_rate = 5e-5 # 学习率 27 | self.bert_path = './bert_pretrain' 28 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 29 | self.hidden_size = 768 30 | self.filter_sizes = (2, 3, 4) # 卷积核尺寸 31 | self.num_filters = 256 # 卷积核数量(channels数) 32 | self.dropout = 0.1 33 | 34 | 35 | class Model(nn.Module): 36 | 37 | def __init__(self, config): 38 | super(Model, self).__init__() 39 | self.bert = BertModel.from_pretrained(config.bert_path) 40 | for param in self.bert.parameters(): 41 | param.requires_grad = True 42 | self.convs = nn.ModuleList( 43 | [nn.Conv2d(1, config.num_filters, (k, config.hidden_size)) for k in config.filter_sizes]) 44 | self.dropout = nn.Dropout(config.dropout) 45 | 46 | self.fc_cnn = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes) 47 | 48 | def conv_and_pool(self, x, conv): 49 | x = F.relu(conv(x)).squeeze(3) 50 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 51 | return x 52 | 53 | def forward(self, x): 54 | context = x[0] # 输入的句子 55 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 56 | encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 57 | out = encoder_out.unsqueeze(1) 58 | out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1) 59 | out = self.dropout(out) 60 | out = self.fc_cnn(out) 61 | return out 62 | -------------------------------------------------------------------------------- /models/bert_DPCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # from pytorch_pretrained_bert import BertModel, BertTokenizer 6 | from pytorch_pretrained import BertModel, BertTokenizer 7 | 8 | 9 | class Config(object): 10 | 11 | """配置参数""" 12 | def __init__(self, dataset): 13 | self.model_name = 'bert_DPCNN' 14 | self.train_path = dataset + '/data/train.txt' # 训练集 15 | self.dev_path = dataset + '/data/dev.txt' # 验证集 16 | self.test_path = dataset + '/data/test.txt' # 测试集 17 | self.class_list = [x.strip() for x in open( 18 | dataset + '/data/class.txt').readlines()] # 类别名单 19 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 20 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 21 | 22 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 23 | self.num_classes = len(self.class_list) # 类别数 24 | self.num_epochs = 3 # 
epoch数 25 | self.batch_size = 128 # mini-batch大小 26 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 27 | self.learning_rate = 5e-5 # 学习率 28 | self.bert_path = './bert_pretrain' 29 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 30 | self.hidden_size = 768 31 | self.num_filters = 250 # 卷积核数量(channels数) 32 | 33 | 34 | class Model(nn.Module): 35 | 36 | def __init__(self, config): 37 | super(Model, self).__init__() 38 | self.bert = BertModel.from_pretrained(config.bert_path) 39 | for param in self.bert.parameters(): 40 | param.requires_grad = True 41 | # self.fc = nn.Linear(config.hidden_size, config.num_classes) 42 | self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.hidden_size), stride=1) 43 | self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1) 44 | self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2) 45 | self.padding1 = nn.ZeroPad2d((0, 0, 1, 1)) # top bottom 46 | self.padding2 = nn.ZeroPad2d((0, 0, 0, 1)) # bottom 47 | self.relu = nn.ReLU() 48 | self.fc = nn.Linear(config.num_filters, config.num_classes) 49 | 50 | def forward(self, x): 51 | context = x[0] # 输入的句子 52 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 53 | encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 54 | x = encoder_out.unsqueeze(1) # [batch_size, 1, seq_len, embed] 55 | x = self.conv_region(x) # [batch_size, 250, seq_len-3+1, 1] 56 | 57 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 58 | x = self.relu(x) 59 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 60 | x = self.padding1(x) # [batch_size, 250, seq_len, 1] 61 | x = self.relu(x) 62 | x = self.conv(x) # [batch_size, 250, seq_len-3+1, 1] 63 | while x.size()[2] > 2: 64 | x = self._block(x) 65 | x = x.squeeze() # [batch_size, num_filters(250)] 66 | x = self.fc(x) 67 | return x 68 | 69 | def _block(self, x): 70 | x = self.padding2(x) 71 | px = self.max_pool(x) 72 | x = self.padding1(px) 73 | x = F.relu(x) 74 | x = self.conv(x) 75 | x = self.padding1(x) 76 | x = F.relu(x) 77 | x = self.conv(x) 78 | x = x + px # short cut 79 | return x 80 | -------------------------------------------------------------------------------- /models/bert_RCNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset): 12 | self.model_name = 'bert_RCNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt').readlines()] # 类别名单 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 20 | 21 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 22 | self.num_classes = len(self.class_list) # 类别数 23 | self.num_epochs = 3 # epoch数 24 | self.batch_size = 128 # mini-batch大小 25 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 26 | self.learning_rate = 5e-5 # 学习率 27 | self.bert_path = './bert_pretrain' 28 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 29 | self.hidden_size = 768 30 | self.filter_sizes = (2, 3, 4) # 卷积核尺寸 31 | self.num_filters = 256 # 卷积核数量(channels数) 32 | 
self.dropout = 0.1 33 | self.rnn_hidden = 256 34 | self.num_layers = 2 35 | 36 | 37 | class Model(nn.Module): 38 | 39 | def __init__(self, config): 40 | super(Model, self).__init__() 41 | self.bert = BertModel.from_pretrained(config.bert_path) 42 | for param in self.bert.parameters(): 43 | param.requires_grad = True 44 | self.lstm = nn.LSTM(config.hidden_size, config.rnn_hidden, config.num_layers, 45 | bidirectional=True, batch_first=True, dropout=config.dropout) 46 | self.maxpool = nn.MaxPool1d(config.pad_size) 47 | self.fc = nn.Linear(config.rnn_hidden * 2 + config.hidden_size, config.num_classes) 48 | 49 | def forward(self, x): 50 | context = x[0] # 输入的句子 51 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 52 | encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 53 | out, _ = self.lstm(encoder_out) 54 | out = torch.cat((encoder_out, out), 2) 55 | out = F.relu(out) 56 | out = out.permute(0, 2, 1) 57 | out = self.maxpool(out).squeeze() 58 | out = self.fc(out) 59 | return out 60 | -------------------------------------------------------------------------------- /models/bert_RNN.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from pytorch_pretrained import BertModel, BertTokenizer 6 | 7 | 8 | class Config(object): 9 | 10 | """配置参数""" 11 | def __init__(self, dataset): 12 | self.model_name = 'bert_RNN' 13 | self.train_path = dataset + '/data/train.txt' # 训练集 14 | self.dev_path = dataset + '/data/dev.txt' # 验证集 15 | self.test_path = dataset + '/data/test.txt' # 测试集 16 | self.class_list = [x.strip() for x in open( 17 | dataset + '/data/class.txt').readlines()] # 类别名单 18 | self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果 19 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备 20 | 21 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 22 | self.num_classes = len(self.class_list) # 类别数 23 | self.num_epochs = 3 # epoch数 24 | self.batch_size = 128 # mini-batch大小 25 | self.pad_size = 32 # 每句话处理成的长度(短填长切) 26 | self.learning_rate = 5e-5 # 学习率 27 | self.bert_path = './bert_pretrain' 28 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 29 | self.hidden_size = 768 30 | self.filter_sizes = (2, 3, 4) # 卷积核尺寸 31 | self.num_filters = 256 # 卷积核数量(channels数) 32 | self.dropout = 0.1 33 | self.rnn_hidden = 768 34 | self.num_layers = 2 35 | 36 | 37 | class Model(nn.Module): 38 | 39 | def __init__(self, config): 40 | super(Model, self).__init__() 41 | self.bert = BertModel.from_pretrained(config.bert_path) 42 | for param in self.bert.parameters(): 43 | param.requires_grad = True 44 | self.lstm = nn.LSTM(config.hidden_size, config.rnn_hidden, config.num_layers, 45 | bidirectional=True, batch_first=True, dropout=config.dropout) 46 | self.dropout = nn.Dropout(config.dropout) 47 | self.fc_rnn = nn.Linear(config.rnn_hidden * 2, config.num_classes) 48 | 49 | def forward(self, x): 50 | context = x[0] # 输入的句子 51 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 52 | encoder_out, text_cls = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) 53 | out, _ = self.lstm(encoder_out) 54 | out = self.dropout(out) 55 | out = self.fc_rnn(out[:, -1, :]) # 句子最后时刻的 hidden state 56 | return out 57 | -------------------------------------------------------------------------------- /predict.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | # !/usr/bin/env python 4 | # -*- coding: UTF-8 -*- 5 | import torch 6 | import pickle as pkl 7 | import numpy as np 8 | from importlib import import_module 9 | 10 | key = { 11 | 0: 'finance', 12 | 1: 'realty', 13 | 2: 'stocks', 14 | 3: 'education', 15 | 4: 'science', 16 | 5: 'society', 17 | 6: 'politics', 18 | 7: 'sports', 19 | 8: 'game', 20 | 9: 'entertainment' 21 | } 22 | 23 | 24 | class Predict: 25 | def __init__(self, model_name='TextCNN', dataset='THUCNews', embedding='embedding_SougouNews.npz', use_word=False): 26 | if use_word: 27 | self.tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level 28 | else: 29 | self.tokenizer = lambda x: [y for y in x] # char-level 30 | self.x = import_module('models.' + model_name) 31 | self.config = self.x.Config(dataset, embedding) 32 | self.vocab = pkl.load(open(self.config.vocab_path, 'rb')) 33 | self.pad_size = self.config.pad_size 34 | self.model = self.x.Model(self.config).to('cpu') 35 | self.model.load_state_dict(torch.load(self.config.save_path, map_location='cpu')) 36 | 37 | def build_predict_text(self, texts): 38 | words_lines = [] 39 | seq_lens = [] 40 | for text in texts: 41 | words_line = [] 42 | token = self.tokenizer(text) 43 | seq_len = len(token) 44 | if self.pad_size: 45 | if len(token) < self.pad_size: 46 | token.extend([''] * (self.pad_size - len(token))) 47 | else: 48 | token = token[:self.pad_size] 49 | seq_len = self.pad_size 50 | # word to id 51 | for word in token: 52 | words_line.append(self.vocab.get(word, self.vocab.get(''))) 53 | words_lines.append(words_line) 54 | seq_lens.append(seq_len) 55 | 56 | return torch.LongTensor(words_lines), torch.LongTensor(seq_lens) 57 | 58 | def predict(self, query): 59 | query = [query] 60 | # 返回预测的索引 61 | data = self.build_predict_text(query) 62 | with torch.no_grad(): 63 | outputs = self.model(data) 64 | num = torch.argmax(outputs) 65 | return key[int(num)] 66 | 67 | def predict_list(self, querys): 68 | # 返回预测的索引 69 | data = self.build_predict_text(querys) 70 | with torch.no_grad(): 71 | outputs = self.model(data) 72 | num = torch.argmax(outputs, dim=1) 73 | pred = [key[index] for index in list(np.array(num))] 74 | return pred 75 | 76 | 77 | if __name__ == "__main__": 78 | pred = Predict('TextCNN') 79 | # 预测一条 80 | query = "学费太贵怎么办?" 
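# predict() returns the English class name from the `key` mapping above
# (the actual label depends on the TextCNN checkpoint loaded from saved_dict);
# predict_list() returns one label per input string.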
81 | print(pred.predict(query)) 82 | # 预测一个列表 83 | querys = ["学费太贵怎么办?", "金融怎么样"] 84 | print(pred.predict_list(querys)) 85 | -------------------------------------------------------------------------------- /pretrain_eval.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from sklearn import metrics 7 | import time 8 | from pretrain_utils import get_time_dif 9 | from pytorch_pretrained.optimization import BertAdam 10 | 11 | 12 | # 权重初始化,默认xavier 13 | def init_network(model, method='xavier', exclude='embedding', seed=123): 14 | for name, w in model.named_parameters(): 15 | if exclude not in name: 16 | if len(w.size()) < 2: 17 | continue 18 | if 'weight' in name: 19 | if method == 'xavier': 20 | nn.init.xavier_normal_(w) 21 | elif method == 'kaiming': 22 | nn.init.kaiming_normal_(w) 23 | else: 24 | nn.init.normal_(w) 25 | elif 'bias' in name: 26 | nn.init.constant_(w, 0) 27 | else: 28 | pass 29 | 30 | 31 | def train(config, model, train_iter, dev_iter, test_iter): 32 | start_time = time.time() 33 | model.train() 34 | param_optimizer = list(model.named_parameters()) 35 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 36 | optimizer_grouped_parameters = [ 37 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 38 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 39 | # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 40 | optimizer = BertAdam(optimizer_grouped_parameters, 41 | lr=config.learning_rate, 42 | warmup=0.05, 43 | t_total=len(train_iter) * config.num_epochs) 44 | total_batch = 0 # 记录进行到多少batch 45 | dev_best_loss = float('inf') 46 | last_improve = 0 # 记录上次验证集loss下降的batch数 47 | flag = False # 记录是否很久没有效果提升 48 | model.train() 49 | for epoch in range(config.num_epochs): 50 | print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs)) 51 | for i, (trains, labels) in enumerate(train_iter): 52 | outputs = model(trains) 53 | model.zero_grad() 54 | loss = F.cross_entropy(outputs, labels) 55 | loss.backward() 56 | optimizer.step() 57 | if total_batch % 100 == 0: 58 | # 每多少轮输出在训练集和验证集上的效果 59 | true = labels.data.cpu() 60 | predic = torch.max(outputs.data, 1)[1].cpu() 61 | train_acc = metrics.accuracy_score(true, predic) 62 | dev_acc, dev_loss = evaluate(config, model, dev_iter) 63 | if dev_loss < dev_best_loss: 64 | dev_best_loss = dev_loss 65 | torch.save(model.state_dict(), config.save_path) 66 | improve = '*' 67 | last_improve = total_batch 68 | else: 69 | improve = '' 70 | time_dif = get_time_dif(start_time) 71 | msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}' 72 | print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)) 73 | model.train() 74 | total_batch += 1 75 | if total_batch - last_improve > config.require_improvement: 76 | # 验证集loss超过1000batch没下降,结束训练 77 | print("No optimization for a long time, auto-stopping...") 78 | flag = True 79 | break 80 | if flag: 81 | break 82 | test(config, model, test_iter) 83 | 84 | 85 | def test(config, model, test_iter): 86 | # test 87 | model.load_state_dict(torch.load(config.save_path)) 88 | model.eval() 89 | start_time = time.time() 90 | test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True) 91 | msg = 
'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}' 92 | print(msg.format(test_loss, test_acc)) 93 | print("Precision, Recall and F1-Score...") 94 | print(test_report) 95 | print("Confusion Matrix...") 96 | print(test_confusion) 97 | time_dif = get_time_dif(start_time) 98 | print("Time usage:", time_dif) 99 | 100 | 101 | def evaluate(config, model, data_iter, test=False): 102 | model.eval() 103 | loss_total = 0 104 | predict_all = np.array([], dtype=int) 105 | labels_all = np.array([], dtype=int) 106 | with torch.no_grad(): 107 | for texts, labels in data_iter: 108 | outputs = model(texts) 109 | loss = F.cross_entropy(outputs, labels) 110 | loss_total += loss 111 | labels = labels.data.cpu().numpy() 112 | predic = torch.max(outputs.data, 1)[1].cpu().numpy() 113 | labels_all = np.append(labels_all, labels) 114 | predict_all = np.append(predict_all, predic) 115 | 116 | acc = metrics.accuracy_score(labels_all, predict_all) 117 | if test: 118 | report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4) 119 | confusion = metrics.confusion_matrix(labels_all, predict_all) 120 | return acc, loss_total / len(data_iter), report, confusion 121 | return acc, loss_total / len(data_iter) 122 | -------------------------------------------------------------------------------- /pretrain_predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import torch 4 | from importlib import import_module 5 | 6 | 7 | key = { 8 | 0: 'finance', 9 | 1: 'realty', 10 | 2: 'stocks', 11 | 3: 'education', 12 | 4: 'science', 13 | 5: 'society', 14 | 6: 'politics', 15 | 7: 'sports', 16 | 8: 'game', 17 | 9: 'entertainment' 18 | } 19 | 20 | 21 | class Predict: 22 | def __init__(self, model_name='bert', dataset='THUCNews'): 23 | self.x = import_module('models.' + model_name) 24 | self.config = self.x.Config(dataset) 25 | self.model = self.x.Model(self.config).to('cpu') 26 | self.model.load_state_dict(torch.load(self.config.save_path, map_location='cpu')) 27 | 28 | def build_predict_text(self, text): 29 | token = self.config.tokenizer.tokenize(text) 30 | token = ['[CLS]'] + token 31 | seq_len = len(token) 32 | mask = [] 33 | token_ids = self.config.tokenizer.convert_tokens_to_ids(token) 34 | pad_size = self.config.pad_size 35 | if pad_size: 36 | if len(token) < pad_size: 37 | mask = [1] * len(token_ids) + ([0] * (pad_size - len(token))) 38 | token_ids += ([0] * (pad_size - len(token))) 39 | else: 40 | mask = [1] * pad_size 41 | token_ids = token_ids[:pad_size] 42 | seq_len = pad_size 43 | ids = torch.LongTensor([token_ids]) 44 | seq_len = torch.LongTensor([seq_len]) 45 | mask = torch.LongTensor([mask]) 46 | return ids, seq_len, mask 47 | 48 | def predict(self, query): 49 | # 返回预测的索引 50 | data = self.build_predict_text(query) 51 | with torch.no_grad(): 52 | outputs = self.model(data) 53 | num = torch.argmax(outputs) 54 | return key[int(num)] 55 | 56 | def predict_list(self, querys): 57 | pred = [] 58 | for query in querys: 59 | pred.append(self.predict(query)) 60 | return pred 61 | 62 | 63 | if __name__ == "__main__": 64 | pred = Predict('bert') 65 | # 预测一条 66 | query = "学费太贵怎么办?" 
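# Same interface as predict.py: predict() maps the argmax of the model output to a
# class name via `key`, and predict_list() simply calls predict() once per query,
# so batch prediction here runs one sentence at a time.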
67 | print(pred.predict(query)) 68 | # 预测一个列表 69 | querys = ["学费太贵怎么办?", "金融怎么样"] 70 | print(pred.predict_list(querys)) 71 | -------------------------------------------------------------------------------- /pretrain_run.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import time 3 | import torch 4 | import numpy as np 5 | from pretrain_eval import train, init_network 6 | from importlib import import_module 7 | import argparse 8 | from pretrain_utils import build_dataset, build_iterator, get_time_dif 9 | 10 | parser = argparse.ArgumentParser(description='Chinese Text Classification') 11 | parser.add_argument('--model', type=str, required=True, help='choose a model: Bert, ERNIE') 12 | args = parser.parse_args() 13 | 14 | 15 | if __name__ == '__main__': 16 | dataset = 'THUCNews' # 数据集 17 | 18 | model_name = args.model # bert 19 | x = import_module('models.' + model_name) 20 | config = x.Config(dataset) 21 | np.random.seed(1) 22 | torch.manual_seed(1) 23 | torch.cuda.manual_seed_all(1) 24 | torch.backends.cudnn.deterministic = True # 保证每次结果一样 25 | 26 | start_time = time.time() 27 | print("Loading data...") 28 | train_data, dev_data, test_data = build_dataset(config) 29 | train_iter = build_iterator(train_data, config) 30 | dev_iter = build_iterator(dev_data, config) 31 | test_iter = build_iterator(test_data, config) 32 | time_dif = get_time_dif(start_time) 33 | print("Time usage:", time_dif) 34 | 35 | # train 36 | model = x.Model(config).to(config.device) 37 | train(config, model, train_iter, dev_iter, test_iter) 38 | -------------------------------------------------------------------------------- /pretrain_utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from datetime import timedelta 6 | 7 | PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号 8 | 9 | 10 | def build_dataset(config): 11 | 12 | def load_dataset(path, pad_size=32): 13 | contents = [] 14 | with open(path, 'r', encoding='UTF-8') as f: 15 | for line in tqdm(f): 16 | lin = line.strip() 17 | if not lin: 18 | continue 19 | content, label = lin.split('\t') 20 | token = config.tokenizer.tokenize(content) 21 | token = [CLS] + token 22 | seq_len = len(token) 23 | mask = [] 24 | token_ids = config.tokenizer.convert_tokens_to_ids(token) 25 | 26 | if pad_size: 27 | if len(token) < pad_size: 28 | mask = [1] * len(token_ids) + [0] * (pad_size - len(token)) 29 | token_ids += ([0] * (pad_size - len(token))) 30 | else: 31 | mask = [1] * pad_size 32 | token_ids = token_ids[:pad_size] 33 | seq_len = pad_size 34 | contents.append((token_ids, int(label), seq_len, mask)) 35 | return contents 36 | train = load_dataset(config.train_path, config.pad_size) 37 | dev = load_dataset(config.dev_path, config.pad_size) 38 | test = load_dataset(config.test_path, config.pad_size) 39 | return train, dev, test 40 | 41 | 42 | class DatasetIterater(object): 43 | def __init__(self, batches, batch_size, device): 44 | self.batch_size = batch_size 45 | self.batches = batches 46 | self.n_batches = len(batches) // batch_size 47 | self.residue = False # 记录batch数量是否为整数 48 | if len(batches) % self.n_batches != 0: 49 | self.residue = True 50 | self.index = 0 51 | self.device = device 52 | 53 | def _to_tensor(self, datas): 54 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 55 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 56 | 57 | # pad前的长度(超过pad_size的设为pad_size) 
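# Worked padding example (illustration only, assuming pad_size=32 and roughly one token
# per Chinese character): a 10-character title becomes [CLS] + 10 tokens, so
#   token_ids = 11 real ids + [0]*21, mask = [1]*11 + [0]*21, seq_len = 11.
# If [CLS] + tokens already reaches pad_size, the ids are cut to 32, mask = [1]*32 and
# seq_len is set to 32, matching the comment above.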
58 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 59 | mask = torch.LongTensor([_[3] for _ in datas]).to(self.device) 60 | return (x, seq_len, mask), y 61 | 62 | def __next__(self): 63 | if self.residue and self.index == self.n_batches: 64 | batches = self.batches[self.index * self.batch_size: len(self.batches)] 65 | self.index += 1 66 | batches = self._to_tensor(batches) 67 | return batches 68 | 69 | elif self.index >= self.n_batches: 70 | self.index = 0 71 | raise StopIteration 72 | else: 73 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] 74 | self.index += 1 75 | batches = self._to_tensor(batches) 76 | return batches 77 | 78 | def __iter__(self): 79 | return self 80 | 81 | def __len__(self): 82 | if self.residue: 83 | return self.n_batches + 1 84 | else: 85 | return self.n_batches 86 | 87 | 88 | def build_iterator(dataset, config): 89 | iter = DatasetIterater(dataset, config.batch_size, config.device) 90 | return iter 91 | 92 | 93 | def get_time_dif(start_time): 94 | """获取已使用时间""" 95 | end_time = time.time() 96 | time_dif = end_time - start_time 97 | return timedelta(seconds=int(round(time_dif))) 98 | -------------------------------------------------------------------------------- /pytorch_pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | -------------------------------------------------------------------------------- /pytorch_pretrained/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT 
PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /pytorch_pretrained/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. 
If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 
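# Descriptive note (not in the upstream source): the ETag is folded into the cache file
# name by url_to_filename() above, sha256(url) plus '.' plus sha256(etag), so a changed
# remote file produces a new cache entry instead of overwriting the old one. For s3://
# URLs the ETag comes from boto3 via s3_etag(); for http(s) it comes from a HEAD request,
# and a request failure leaves etag as None, which triggers the offline fallback below.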
193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /pytorch_pretrained/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | 20 | from collections import defaultdict 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 30 | 31 | class ProjectedAdaptiveLogSoftmax(nn.Module): 32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 33 | keep_order=False): 34 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 35 | 36 | self.n_token = n_token 37 | self.d_embed = d_embed 38 | self.d_proj = d_proj 39 | 40 | self.cutoffs = cutoffs + [n_token] 41 | self.cutoff_ends = [0] + self.cutoffs 42 | self.div_val = div_val 43 | 44 | self.shortlist_size = self.cutoffs[0] 45 | self.n_clusters = len(self.cutoffs) - 1 46 | self.head_size = self.shortlist_size + self.n_clusters 47 | 48 | if self.n_clusters > 0: 49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 51 | 52 | self.out_layers = nn.ModuleList() 53 | self.out_projs = nn.ParameterList() 54 | 55 | if div_val == 1: 56 | for i in range(len(self.cutoffs)): 57 | if d_proj != d_embed: 58 | self.out_projs.append( 59 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 60 | ) 61 | else: 62 | self.out_projs.append(None) 63 | 64 | self.out_layers.append(nn.Linear(d_embed, n_token)) 65 | else: 66 | for i in range(len(self.cutoffs)): 67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 68 | d_emb_i = d_embed // (div_val ** i) 69 | 70 | self.out_projs.append( 71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 72 | ) 73 | 74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 75 | 76 | self.keep_order = keep_order 77 | 78 | def _compute_logit(self, hidden, weight, bias, proj): 79 | if proj is None: 80 | logit = F.linear(hidden, weight, bias=bias) 81 | else: 82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 83 | proj_hid = F.linear(hidden, proj.t().contiguous()) 84 | logit = F.linear(proj_hid, weight, bias=bias) 85 | # else: 86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 87 | # if bias is not None: 88 | # logit = logit + bias 89 | 90 | return logit 91 | 92 | def forward(self, hidden, target=None, keep_order=False): 93 | ''' 94 | Params: 95 | hidden :: [len*bsz x d_proj] 96 | target :: [len*bsz] 97 | Return: 98 | if target is None: 99 | out :: [len*bsz] Negative log likelihood 100 | else: 101 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary 102 | We could replace this implementation by the native PyTorch one 103 | if their's had an option to set bias on all clusters in the native one. 
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if target is not None: 108 | target = target.view(-1) 109 | if hidden.size(0) != target.size(0): 110 | raise RuntimeError('Input and target should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if target is not None: 117 | output = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, target.unsqueeze(1)).squeeze(1) 119 | else: 120 | output = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 | weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if target is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if target is not None: 158 | mask_i = (target >= l_idx) & (target < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = target.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if target is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if target is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if target is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i, -logprob_i) 191 | else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: 
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py 201 | Args: 202 | hidden (Tensor): a minibatch of examples 203 | Returns: 204 | log-probabilities of for each class :math:`c` 205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a 206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. 207 | Shape: 208 | - Input: :math:`(N, in\_features)` 209 | - Output: :math:`(N, n\_classes)` 210 | """ 211 | if self.n_clusters == 0: 212 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 213 | self.out_layers[0].bias, self.out_projs[0]) 214 | return F.log_softmax(logit, dim=-1) 215 | else: 216 | # construct weights and biases 217 | weights, biases = [], [] 218 | for i in range(len(self.cutoffs)): 219 | if self.div_val == 1: 220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 221 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 222 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 223 | else: 224 | weight_i = self.out_layers[i].weight 225 | bias_i = self.out_layers[i].bias 226 | 227 | if i == 0: 228 | weight_i = torch.cat( 229 | [weight_i, self.cluster_weight], dim=0) 230 | bias_i = torch.cat( 231 | [bias_i, self.cluster_bias], dim=0) 232 | 233 | weights.append(weight_i) 234 | biases.append(bias_i) 235 | 236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 238 | 239 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 240 | head_logprob = F.log_softmax(head_logit, dim=1) 241 | 242 | cutoff_values = [0] + self.cutoffs 243 | for i in range(len(cutoff_values) - 1): 244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] 245 | 246 | if i == 0: 247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 248 | else: 249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 250 | 251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) 252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 253 | 254 | logprob_i = head_logprob[:, -i] + tail_logprob_i 255 | out[:, start_idx, stop_idx] = logprob_i 256 | 257 | return out 258 | 259 | 260 | class LogUniformSampler(object): 261 | def __init__(self, range_max, n_sample): 262 | """ 263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 265 | 266 | expected count can be approximated by 1 - (1 - p)^n 267 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 268 | 269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 270 | """ 271 | with torch.no_grad(): 272 | self.range_max = range_max 273 | log_indices = torch.arange(1., range_max+2., 1.).log_() 274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 275 | # print('P', self.dist.numpy().tolist()[-30:]) 276 | 277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 278 | 279 | self.n_sample = n_sample 280 | 281 | def sample(self, labels): 282 | """ 283 | labels: [b1, b2] 284 | Return 285 | true_log_probs: [b1, b2] 286 | samp_log_probs: [n_sample] 287 | neg_samples: [n_sample] 288 | """ 289 | 290 | # neg_samples = torch.empty(0).long() 291 | n_sample = self.n_sample 292 | n_tries = 2 * n_sample 293 | 294 | with torch.no_grad(): 295 | neg_samples = torch.multinomial(self.dist, n_tries, 
replacement=True).unique() 296 | device = labels.device 297 | neg_samples = neg_samples.to(device) 298 | true_log_probs = self.log_q[labels].to(device) 299 | samp_log_probs = self.log_q[neg_samples].to(device) 300 | return true_log_probs, samp_log_probs, neg_samples 301 | 302 | def sample_logits(embedding, bias, labels, inputs, sampler): 303 | """ 304 | embedding: an nn.Embedding layer 305 | bias: [n_vocab] 306 | labels: [b1, b2] 307 | inputs: [b1, b2, n_emb] 308 | sampler: you may use a LogUniformSampler 309 | Return 310 | logits: [b1, b2, 1 + n_sample] 311 | """ 312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 313 | n_sample = neg_samples.size(0) 314 | b1, b2 = labels.size(0), labels.size(1) 315 | all_ids = torch.cat([labels.view(-1), neg_samples]) 316 | all_w = embedding(all_ids) 317 | true_w = all_w[: -n_sample].view(b1, b2, -1) 318 | sample_w = all_w[- n_sample:].view(n_sample, -1) 319 | 320 | all_b = bias[all_ids] 321 | true_b = all_b[: -n_sample].view(b1, b2) 322 | sample_b = all_b[- n_sample:] 323 | 324 | hit = (labels[:, :, None] == neg_samples).detach() 325 | 326 | true_logits = torch.einsum('ijk,ijk->ij', 327 | [true_w, inputs]) + true_b - true_log_probs 328 | sample_logits = torch.einsum('lk,ijk->ijl', 329 | [sample_w, inputs]) + sample_b - samp_log_probs 330 | sample_logits.masked_fill_(hit, -1e30) 331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 332 | 333 | return logits 334 | 335 | 336 | # class LogUniformSampler(object): 337 | # def __init__(self, range_max, unique=False): 338 | # """ 339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 341 | # """ 342 | # self.range_max = range_max 343 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 345 | 346 | # self.unique = unique 347 | 348 | # if self.unique: 349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 350 | 351 | # def sample(self, n_sample, labels): 352 | # pos_sample, new_labels = labels.unique(return_inverse=True) 353 | # n_pos_sample = pos_sample.size(0) 354 | # n_neg_sample = n_sample - n_pos_sample 355 | 356 | # if self.unique: 357 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 359 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 360 | # else: 361 | # sample_dist = self.dist 362 | 363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 364 | 365 | # sample = torch.cat([pos_sample, neg_sample]) 366 | # sample_prob = self.dist[sample] 367 | 368 | # return new_labels, sample, sample_prob 369 | 370 | 371 | if __name__ == '__main__': 372 | S, B = 3, 4 373 | n_vocab = 10000 374 | n_sample = 5 375 | H = 32 376 | 377 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 378 | 379 | # sampler = LogUniformSampler(n_vocab, unique=False) 380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 381 | 382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True) 383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 384 | 385 | # print('true_probs', true_probs.numpy().tolist()) 386 | # print('samp_probs', samp_probs.numpy().tolist()) 387 | # print('neg_samples', neg_samples.numpy().tolist()) 388 | 389 | # print('sum', torch.sum(sampler.dist).item()) 390 | 391 | # assert 
torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 392 | 393 | embedding = nn.Embedding(n_vocab, H) 394 | bias = torch.zeros(n_vocab) 395 | inputs = torch.Tensor(S, B, H).normal_() 396 | 397 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 398 | print('logits', logits.detach().numpy().tolist()) 399 | print('logits shape', logits.size()) 400 | print('out_labels', out_labels.detach().numpy().tolist()) 401 | print('out_labels shape', out_labels.size()) 402 | 403 | -------------------------------------------------------------------------------- /pytorch_pretrained/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 
67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. 
+ math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. 
" 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() + group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /pytorch_pretrained/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 
76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 - beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
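# ---------------------------------------------------------------------------
# Editor's usage sketch for the optimizers defined above in optimization.py /
# optimization_openai.py; it is not part of this repository file. The tiny
# Linear model and the step count are hypothetical placeholders.
import torch
from pytorch_pretrained.optimization import BertAdam, WarmupLinearSchedule

model = torch.nn.Linear(768, 10)        # stand-in for a BERT-based classifier head
num_train_steps = 1000                  # total number of optimizer updates
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_params = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = BertAdam(grouped_params, lr=5e-5,
                     warmup=0.05,               # linear warmup over the first 5% of updates
                     t_total=num_train_steps,   # required so the schedule knows the total step count
                     schedule='warmup_linear')
# Equivalently, a ready-made schedule object can be passed instead of warmup/t_total:
# optimizer = BertAdam(grouped_params, lr=5e-5,
#                      schedule=WarmupLinearSchedule(warmup=0.05, t_total=num_train_steps))
# OpenAIAdam from optimization_openai.py is constructed the same way.
# ---------------------------------------------------------------------------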
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_wordpiece_only=False 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum length to truncate tokenized sequences to; 87 | Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_wordpiece_only=False 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text): 108 | split_tokens = [] 109 | if self.do_basic_tokenize: 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab[token]) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). Running this" 126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | def save_vocabulary(self, vocab_path): 138 | """Save the tokenizer vocabulary to a directory or file.""" 139 | index = 0 140 | if os.path.isdir(vocab_path): 141 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 142 | with open(vocab_file, "w", encoding="utf-8") as writer: 143 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 144 | if index != token_index: 145 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 146 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 147 | index = token_index 148 | writer.write(token + u'\n') 149 | index += 1 150 | return vocab_file 151 | 152 | @classmethod 153 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 154 | """ 155 | Instantiate a PreTrainedBertModel from a pre-trained model file. 156 | Download and cache the pre-trained model file if needed. 157 | """ 158 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 159 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 160 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 161 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 162 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 163 | "you may want to check this behavior.") 164 | kwargs['do_lower_case'] = False 165 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 166 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 167 | "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " 168 | "but you may want to check this behavior.") 169 | kwargs['do_lower_case'] = True 170 | else: 171 | vocab_file = pretrained_model_name_or_path 172 | if os.path.isdir(vocab_file): 173 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 174 | # redirect to the cache, if necessary 175 | try: 176 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 177 | except EnvironmentError: 178 | logger.error( 179 | "Model name '{}' was not found in model name list ({}). " 180 | "We assumed '{}' was a path or url but couldn't find any file " 181 | "associated to this path or url.".format( 182 | pretrained_model_name_or_path, 183 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 184 | vocab_file)) 185 | return None 186 | if resolved_vocab_file == vocab_file: 187 | logger.info("loading vocabulary file {}".format(vocab_file)) 188 | else: 189 | logger.info("loading vocabulary file {} from cache at {}".format( 190 | vocab_file, resolved_vocab_file)) 191 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 192 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 193 | # than the number of positional embeddings 194 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 195 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 196 | # Instantiate tokenizer. 197 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 198 | return tokenizer 199 | 200 | 201 | class BasicTokenizer(object): 202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 203 | 204 | def __init__(self, 205 | do_lower_case=True, 206 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 207 | """Constructs a BasicTokenizer. 208 | 209 | Args: 210 | do_lower_case: Whether to lower case the input. 211 | """ 212 | self.do_lower_case = do_lower_case 213 | self.never_split = never_split 214 | 215 | def tokenize(self, text): 216 | """Tokenizes a piece of text.""" 217 | text = self._clean_text(text) 218 | # This was added on November 1st, 2018 for the multilingual and Chinese 219 | # models. This is also applied to the English models now, but it doesn't 220 | # matter since the English models were not trained on any Chinese data 221 | # and generally don't have any Chinese data in them (there are Chinese 222 | # characters in the vocabulary because Wikipedia does have some Chinese 223 | # words in the English Wikipedia.). 
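        # Editor's note, not in the original file: for a mixed input such as
        # "今天NBA比赛", the call below pads every CJK character with spaces,
        # giving roughly " 今  天 NBA 比  赛 ", so whitespace_tokenize() then
        # yields ['今', '天', 'NBA', '比', '赛'] and Chinese text is effectively
        # tokenized character by character ('NBA' is later lower-cased when
        # do_lower_case=True).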
224 | text = self._tokenize_chinese_chars(text) 225 | orig_tokens = whitespace_tokenize(text) 226 | split_tokens = [] 227 | for token in orig_tokens: 228 | if self.do_lower_case and token not in self.never_split: 229 | token = token.lower() 230 | token = self._run_strip_accents(token) 231 | split_tokens.extend(self._run_split_on_punc(token)) 232 | 233 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 234 | return output_tokens 235 | 236 | def _run_strip_accents(self, text): 237 | """Strips accents from a piece of text.""" 238 | text = unicodedata.normalize("NFD", text) 239 | output = [] 240 | for char in text: 241 | cat = unicodedata.category(char) 242 | if cat == "Mn": 243 | continue 244 | output.append(char) 245 | return "".join(output) 246 | 247 | def _run_split_on_punc(self, text): 248 | """Splits punctuation on a piece of text.""" 249 | if text in self.never_split: 250 | return [text] 251 | chars = list(text) 252 | i = 0 253 | start_new_word = True 254 | output = [] 255 | while i < len(chars): 256 | char = chars[i] 257 | if _is_punctuation(char): 258 | output.append([char]) 259 | start_new_word = True 260 | else: 261 | if start_new_word: 262 | output.append([]) 263 | start_new_word = False 264 | output[-1].append(char) 265 | i += 1 266 | 267 | return ["".join(x) for x in output] 268 | 269 | def _tokenize_chinese_chars(self, text): 270 | """Adds whitespace around any CJK character.""" 271 | output = [] 272 | for char in text: 273 | cp = ord(char) 274 | if self._is_chinese_char(cp): 275 | output.append(" ") 276 | output.append(char) 277 | output.append(" ") 278 | else: 279 | output.append(char) 280 | return "".join(output) 281 | 282 | def _is_chinese_char(self, cp): 283 | """Checks whether CP is the codepoint of a CJK character.""" 284 | # This defines a "chinese character" as anything in the CJK Unicode block: 285 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 286 | # 287 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 288 | # despite its name. The modern Korean Hangul alphabet is a different block, 289 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 290 | # space-separated words, so they are not treated specially and handled 291 | # like the all of the other languages. 292 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 293 | (cp >= 0x3400 and cp <= 0x4DBF) or # 294 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 295 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 296 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 297 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 298 | (cp >= 0xF900 and cp <= 0xFAFF) or # 299 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 300 | return True 301 | 302 | return False 303 | 304 | def _clean_text(self, text): 305 | """Performs invalid character removal and whitespace cleanup on text.""" 306 | output = [] 307 | for char in text: 308 | cp = ord(char) 309 | if cp == 0 or cp == 0xfffd or _is_control(char): 310 | continue 311 | if _is_whitespace(char): 312 | output.append(" ") 313 | else: 314 | output.append(char) 315 | return "".join(output) 316 | 317 | 318 | class WordpieceTokenizer(object): 319 | """Runs WordPiece tokenization.""" 320 | 321 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 322 | self.vocab = vocab 323 | self.unk_token = unk_token 324 | self.max_input_chars_per_word = max_input_chars_per_word 325 | 326 | def tokenize(self, text): 327 | """Tokenizes a piece of text into its word pieces. 
328 | 329 | This uses a greedy longest-match-first algorithm to perform tokenization 330 | using the given vocabulary. 331 | 332 | For example: 333 | input = "unaffable" 334 | output = ["un", "##aff", "##able"] 335 | 336 | Args: 337 | text: A single token or whitespace separated tokens. This should have 338 | already been passed through `BasicTokenizer`. 339 | 340 | Returns: 341 | A list of wordpiece tokens. 342 | """ 343 | 344 | output_tokens = [] 345 | for token in whitespace_tokenize(text): 346 | chars = list(token) 347 | if len(chars) > self.max_input_chars_per_word: 348 | output_tokens.append(self.unk_token) 349 | continue 350 | 351 | is_bad = False 352 | start = 0 353 | sub_tokens = [] 354 | while start < len(chars): 355 | end = len(chars) 356 | cur_substr = None 357 | while start < end: 358 | substr = "".join(chars[start:end]) 359 | if start > 0: 360 | substr = "##" + substr 361 | if substr in self.vocab: 362 | cur_substr = substr 363 | break 364 | end -= 1 365 | if cur_substr is None: 366 | is_bad = True 367 | break 368 | sub_tokens.append(cur_substr) 369 | start = end 370 | 371 | if is_bad: 372 | output_tokens.append(self.unk_token) 373 | else: 374 | output_tokens.extend(sub_tokens) 375 | return output_tokens 376 | 377 | 378 | def _is_whitespace(char): 379 | """Checks whether `chars` is a whitespace character.""" 380 | # \t, \n, and \r are technically contorl characters but we treat them 381 | # as whitespace since they are generally considered as such. 382 | if char == " " or char == "\t" or char == "\n" or char == "\r": 383 | return True 384 | cat = unicodedata.category(char) 385 | if cat == "Zs": 386 | return True 387 | return False 388 | 389 | 390 | def _is_control(char): 391 | """Checks whether `chars` is a control character.""" 392 | # These are technically control characters but we count them as whitespace 393 | # characters. 394 | if char == "\t" or char == "\n" or char == "\r": 395 | return False 396 | cat = unicodedata.category(char) 397 | if cat.startswith("C"): 398 | return True 399 | return False 400 | 401 | 402 | def _is_punctuation(char): 403 | """Checks whether `chars` is a punctuation character.""" 404 | cp = ord(char) 405 | # We treat all non-letter/number ASCII as punctuation. 406 | # Characters such as "^", "$", and "`" are not in the Unicode 407 | # Punctuation class but we treat them as punctuation anyways, for 408 | # consistency. 409 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 410 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 411 | return True 412 | cat = unicodedata.category(char) 413 | if cat.startswith("P"): 414 | return True 415 | return False 416 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
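# ---------------------------------------------------------------------------
# Editor's usage sketch for the BertTokenizer defined above in tokenization.py;
# it is not part of this repository file. The path assumes the vocab.txt
# described in the README has been placed under ./bert_pretrain.
from pytorch_pretrained.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')  # directory containing vocab.txt
tokens = ['[CLS]'] + tokenizer.tokenize('今天股市大涨') + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)      # ['[CLS]', '今', '天', '股', '市', '大', '涨', '[SEP]'] if all
print(token_ids)   # characters are in the vocabulary; unknown ones become '[UNK]'
# BasicTokenizer splits Chinese input per character (see _tokenize_chinese_chars)
# and WordpieceTokenizer then looks each piece up in the vocabulary.
# ---------------------------------------------------------------------------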
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | } 41 | PRETRAINED_MERGES_ARCHIVE_MAP = { 42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 43 | } 44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 45 | 'gpt2': 1024, 46 | } 47 | VOCAB_NAME = 'vocab.json' 48 | MERGES_NAME = 'merges.txt' 49 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 50 | 51 | @lru_cache() 52 | def bytes_to_unicode(): 53 | """ 54 | Returns list of utf-8 byte and a corresponding list of unicode strings. 55 | The reversible bpe codes work on unicode strings. 56 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 57 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 58 | This is a signficant percentage of your normal, say, 32K bpe vocab. 59 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 60 | And avoids mapping to whitespace/control characters the bpe code barfs on. 61 | """ 62 | _chr = unichr if sys.version_info[0] == 2 else chr 63 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 64 | cs = bs[:] 65 | n = 0 66 | for b in range(2**8): 67 | if b not in bs: 68 | bs.append(b) 69 | cs.append(2**8+n) 70 | n += 1 71 | cs = [_chr(n) for n in cs] 72 | return dict(zip(bs, cs)) 73 | 74 | def get_pairs(word): 75 | """Return set of symbol pairs in a word. 76 | 77 | Word is represented as tuple of symbols (symbols being variable-length strings). 78 | """ 79 | pairs = set() 80 | prev_char = word[0] 81 | for char in word[1:]: 82 | pairs.add((prev_char, char)) 83 | prev_char = char 84 | return pairs 85 | 86 | class GPT2Tokenizer(object): 87 | """ 88 | GPT-2 BPE tokenizer. Peculiarities: 89 | - Byte-level BPE 90 | """ 91 | @classmethod 92 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 93 | """ 94 | Instantiate a PreTrainedBertModel from a pre-trained model file. 95 | Download and cache the pre-trained model file if needed. 
96 | """ 97 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 98 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 99 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 100 | special_tokens_file = None 101 | else: 102 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 103 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 104 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 105 | if not os.path.exists(special_tokens_file): 106 | special_tokens_file = None 107 | else: 108 | logger.info("loading special tokens file {}".format(special_tokens_file)) 109 | # redirect to the cache, if necessary 110 | try: 111 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 112 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 113 | except EnvironmentError: 114 | logger.error( 115 | "Model name '{}' was not found in model name list ({}). " 116 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 117 | "at this path or url.".format( 118 | pretrained_model_name_or_path, 119 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 120 | pretrained_model_name_or_path, 121 | vocab_file, merges_file)) 122 | return None 123 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 124 | logger.info("loading vocabulary file {}".format(vocab_file)) 125 | logger.info("loading merges file {}".format(merges_file)) 126 | else: 127 | logger.info("loading vocabulary file {} from cache at {}".format( 128 | vocab_file, resolved_vocab_file)) 129 | logger.info("loading merges file {} from cache at {}".format( 130 | merges_file, resolved_merges_file)) 131 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 132 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 133 | # than the number of positional embeddings 134 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 135 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 136 | # Instantiate tokenizer. 
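        # Editor's note, not in the original file: when a special_tokens.txt was
        # found next to the vocab/merges files, its lines are read here and passed
        # to the constructor; set_special_tokens() then indexes them directly
        # after the byte-level BPE vocabulary.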
137 | if special_tokens_file and 'special_tokens' not in kwargs: 138 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 139 | else: 140 | special_tokens = kwargs.pop('special_tokens', []) 141 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 142 | return tokenizer 143 | 144 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 145 | self.max_len = max_len if max_len is not None else int(1e12) 146 | self.encoder = json.load(open(vocab_file)) 147 | self.decoder = {v:k for k,v in self.encoder.items()} 148 | self.errors = errors # how to handle errors in decoding 149 | self.byte_encoder = bytes_to_unicode() 150 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 151 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 152 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 153 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 154 | self.cache = {} 155 | 156 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 157 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 158 | 159 | self.special_tokens = {} 160 | self.special_tokens_decoder = {} 161 | self.set_special_tokens(special_tokens) 162 | 163 | def __len__(self): 164 | return len(self.encoder) + len(self.special_tokens) 165 | 166 | def set_special_tokens(self, special_tokens): 167 | """ Add a list of additional tokens to the encoder. 168 | The additional tokens are indexed starting from the last index of the 169 | current vocabulary in the order of the `special_tokens` list. 170 | """ 171 | if not special_tokens: 172 | self.special_tokens = {} 173 | self.special_tokens_decoder = {} 174 | return 175 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 176 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 177 | logger.info("Special tokens {}".format(self.special_tokens)) 178 | 179 | def bpe(self, token): 180 | if token in self.cache: 181 | return self.cache[token] 182 | word = tuple(token) 183 | pairs = get_pairs(word) 184 | 185 | if not pairs: 186 | return token 187 | 188 | while True: 189 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 190 | if bigram not in self.bpe_ranks: 191 | break 192 | first, second = bigram 193 | new_word = [] 194 | i = 0 195 | while i < len(word): 196 | try: 197 | j = word.index(first, i) 198 | new_word.extend(word[i:j]) 199 | i = j 200 | except: 201 | new_word.extend(word[i:]) 202 | break 203 | 204 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 205 | new_word.append(first+second) 206 | i += 2 207 | else: 208 | new_word.append(word[i]) 209 | i += 1 210 | new_word = tuple(new_word) 211 | word = new_word 212 | if len(word) == 1: 213 | break 214 | else: 215 | pairs = get_pairs(word) 216 | word = ' '.join(word) 217 | self.cache[token] = word 218 | return word 219 | 220 | def tokenize(self, text): 221 | """ Tokenize a string. """ 222 | bpe_tokens = [] 223 | for token in re.findall(self.pat, text): 224 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 225 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 226 | return bpe_tokens 227 | 228 | def convert_tokens_to_ids(self, tokens): 229 | """ Converts a sequence of tokens into ids using the vocab. 
""" 230 | ids = [] 231 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 232 | if tokens in self.special_tokens: 233 | return self.special_tokens[tokens] 234 | else: 235 | return self.encoder.get(tokens, 0) 236 | for token in tokens: 237 | if token in self.special_tokens: 238 | ids.append(self.special_tokens[token]) 239 | else: 240 | ids.append(self.encoder.get(token, 0)) 241 | if len(ids) > self.max_len: 242 | logger.warning( 243 | "Token indices sequence length is longer than the specified maximum " 244 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 245 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 246 | ) 247 | return ids 248 | 249 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 250 | """Converts a sequence of ids in BPE tokens using the vocab.""" 251 | tokens = [] 252 | for i in ids: 253 | if i in self.special_tokens_decoder: 254 | if not skip_special_tokens: 255 | tokens.append(self.special_tokens_decoder[i]) 256 | else: 257 | tokens.append(self.decoder[i]) 258 | return tokens 259 | 260 | def encode(self, text): 261 | return self.convert_tokens_to_ids(self.tokenize(text)) 262 | 263 | def decode(self, tokens): 264 | text = ''.join([self.decoder[token] for token in tokens]) 265 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 266 | return text 267 | 268 | def save_vocabulary(self, vocab_path): 269 | """Save the tokenizer vocabulary and merge files to a directory.""" 270 | if not os.path.isdir(vocab_path): 271 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 272 | return 273 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 274 | merge_file = os.path.join(vocab_path, MERGES_NAME) 275 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 276 | 277 | with open(vocab_file, 'w', encoding='utf-8') as f: 278 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 279 | 280 | index = 0 281 | with open(merge_file, "w", encoding="utf-8") as writer: 282 | writer.write(u'#version: 0.2\n') 283 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 284 | if index != token_index: 285 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 286 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 287 | index = token_index 288 | writer.write(' '.join(bpe_tokens) + u'\n') 289 | index += 1 290 | 291 | index = len(self.encoder) 292 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 293 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 294 | if index != token_index: 295 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 296 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 297 | index = token_index 298 | writer.write(token + u'\n') 299 | index += 1 300 | 301 | return vocab_file, merge_file, special_tokens_file 302 | -------------------------------------------------------------------------------- /pytorch_pretrained/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 45 | 46 | def get_pairs(word): 47 | """ 48 | Return set of symbol pairs in a word. 49 | word is represented as tuple of symbols (symbols being variable-length strings) 50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | def text_standardize(text): 59 | """ 60 | fixes some issues the spacy tokenizer had on books corpus 61 | also does some whitespace standardization 62 | """ 63 | text = text.replace('—', '-') 64 | text = text.replace('–', '-') 65 | text = text.replace('―', '-') 66 | text = text.replace('…', '...') 67 | text = text.replace('´', "'") 68 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 69 | text = re.sub(r'\s*\n\s*', ' \n ', text) 70 | text = re.sub(r'[^\S\n]+', ' ', text) 71 | return text.strip() 72 | 73 | class OpenAIGPTTokenizer(object): 74 | """ 75 | BPE tokenizer. Peculiarities: 76 | - lower case all inputs 77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 78 | - argument special_tokens and function set_special_tokens: 79 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 80 | """ 81 | @classmethod 82 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 83 | """ 84 | Instantiate a PreTrainedBertModel from a pre-trained model file. 85 | Download and cache the pre-trained model file if needed. 
86 | """ 87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 90 | special_tokens_file = None 91 | else: 92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 95 | if not os.path.exists(special_tokens_file): 96 | special_tokens_file = None 97 | else: 98 | logger.info("loading special tokens file {}".format(special_tokens_file)) 99 | # redirect to the cache, if necessary 100 | try: 101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 103 | except EnvironmentError: 104 | logger.error( 105 | "Model name '{}' was not found in model name list ({}). " 106 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 107 | "at this path or url.".format( 108 | pretrained_model_name_or_path, 109 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 110 | pretrained_model_name_or_path, 111 | vocab_file, merges_file)) 112 | return None 113 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 114 | logger.info("loading vocabulary file {}".format(vocab_file)) 115 | logger.info("loading merges file {}".format(merges_file)) 116 | else: 117 | logger.info("loading vocabulary file {} from cache at {}".format( 118 | vocab_file, resolved_vocab_file)) 119 | logger.info("loading merges file {} from cache at {}".format( 120 | merges_file, resolved_merges_file)) 121 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 122 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 123 | # than the number of positional embeddings 124 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 125 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 126 | # Instantiate tokenizer. 
127 | if special_tokens_file and 'special_tokens' not in kwargs: 128 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 129 | else: 130 | special_tokens = kwargs.pop('special_tokens', []) 131 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 132 | return tokenizer 133 | 134 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 135 | try: 136 | import ftfy 137 | import spacy 138 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 139 | self.fix_text = ftfy.fix_text 140 | except ImportError: 141 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 142 | self.nlp = BasicTokenizer(do_lower_case=True, 143 | never_split=special_tokens if special_tokens is not None else []) 144 | self.fix_text = None 145 | 146 | self.max_len = max_len if max_len is not None else int(1e12) 147 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 148 | self.decoder = {v:k for k,v in self.encoder.items()} 149 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 150 | merges = [tuple(merge.split()) for merge in merges] 151 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 152 | self.cache = {} 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | self.set_special_tokens(special_tokens) 156 | 157 | def __len__(self): 158 | return len(self.encoder) + len(self.special_tokens) 159 | 160 | def set_special_tokens(self, special_tokens): 161 | """ Add a list of additional tokens to the encoder. 162 | The additional tokens are indexed starting from the last index of the 163 | current vocabulary in the order of the `special_tokens` list. 164 | """ 165 | if not special_tokens: 166 | self.special_tokens = {} 167 | self.special_tokens_decoder = {} 168 | return 169 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 170 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 171 | if self.fix_text is None: 172 | # Using BERT's BasicTokenizer: we can update the tokenizer 173 | self.nlp.never_split = special_tokens 174 | logger.info("Special tokens {}".format(self.special_tokens)) 175 | 176 | def bpe(self, token): 177 | word = tuple(token[:-1]) + (token[-1] + '',) 178 | if token in self.cache: 179 | return self.cache[token] 180 | pairs = get_pairs(word) 181 | 182 | if not pairs: 183 | return token+'' 184 | 185 | while True: 186 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 187 | if bigram not in self.bpe_ranks: 188 | break 189 | first, second = bigram 190 | new_word = [] 191 | i = 0 192 | while i < len(word): 193 | try: 194 | j = word.index(first, i) 195 | new_word.extend(word[i:j]) 196 | i = j 197 | except: 198 | new_word.extend(word[i:]) 199 | break 200 | 201 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 202 | new_word.append(first+second) 203 | i += 2 204 | else: 205 | new_word.append(word[i]) 206 | i += 1 207 | new_word = tuple(new_word) 208 | word = new_word 209 | if len(word) == 1: 210 | break 211 | else: 212 | pairs = get_pairs(word) 213 | word = ' '.join(word) 214 | if word == '\n ': 215 | word = '\n' 216 | self.cache[token] = word 217 | return word 218 | 219 | def tokenize(self, text): 220 | """ Tokenize a string. 
""" 221 | split_tokens = [] 222 | if self.fix_text is None: 223 | # Using BERT's BasicTokenizer 224 | text = self.nlp.tokenize(text) 225 | for token in text: 226 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 227 | else: 228 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 229 | text = self.nlp(text_standardize(self.fix_text(text))) 230 | for token in text: 231 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 232 | return split_tokens 233 | 234 | def convert_tokens_to_ids(self, tokens): 235 | """ Converts a sequence of tokens into ids using the vocab. """ 236 | ids = [] 237 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 238 | if tokens in self.special_tokens: 239 | return self.special_tokens[tokens] 240 | else: 241 | return self.encoder.get(tokens, 0) 242 | for token in tokens: 243 | if token in self.special_tokens: 244 | ids.append(self.special_tokens[token]) 245 | else: 246 | ids.append(self.encoder.get(token, 0)) 247 | if len(ids) > self.max_len: 248 | logger.warning( 249 | "Token indices sequence length is longer than the specified maximum " 250 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 251 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 252 | ) 253 | return ids 254 | 255 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 256 | """Converts a sequence of ids in BPE tokens using the vocab.""" 257 | tokens = [] 258 | for i in ids: 259 | if i in self.special_tokens_decoder: 260 | if not skip_special_tokens: 261 | tokens.append(self.special_tokens_decoder[i]) 262 | else: 263 | tokens.append(self.decoder[i]) 264 | return tokens 265 | 266 | def encode(self, text): 267 | return self.convert_tokens_to_ids(self.tokenize(text)) 268 | 269 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 270 | """Converts a sequence of ids in a string.""" 271 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 272 | out_string = ''.join(tokens).replace('', ' ').strip() 273 | if clean_up_tokenization_spaces: 274 | out_string = out_string.replace('', '') 275 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 276 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 277 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 278 | return out_string 279 | 280 | def save_vocabulary(self, vocab_path): 281 | """Save the tokenizer vocabulary and merge files to a directory.""" 282 | if not os.path.isdir(vocab_path): 283 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 284 | return 285 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 286 | merge_file = os.path.join(vocab_path, MERGES_NAME) 287 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 288 | 289 | with open(vocab_file, 'w', encoding='utf-8') as f: 290 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 291 | 292 | index = 0 293 | with open(merge_file, "w", encoding="utf-8") as writer: 294 | writer.write(u'#version: 0.2\n') 295 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 296 | if index != token_index: 297 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
298 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 299 | index = token_index 300 | writer.write(' '.join(bpe_tokens) + u'\n') 301 | index += 1 302 | 303 | index = len(self.encoder) 304 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 305 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 306 | if index != token_index: 307 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 308 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 309 | index = token_index 310 | writer.write(token + u'\n') 311 | index += 1 312 | 313 | return vocab_file, merge_file, special_tokens_file 314 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import time 3 | import torch 4 | import numpy as np 5 | from train_eval import train, init_network 6 | from importlib import import_module 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Chinese Text Classification') 10 | parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer') 11 | parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained') 12 | parser.add_argument('--word', default=False, type=bool, help='True for word, False for char') 13 | args = parser.parse_args() 14 | 15 | 16 | if __name__ == '__main__': 17 | dataset = 'THUCNews' # 数据集 18 | 19 | # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random 20 | embedding = 'embedding_SougouNews.npz' 21 | if args.embedding == 'random': 22 | embedding = 'random' 23 | model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer 24 | if model_name == 'FastText': 25 | from utils_fasttext import build_dataset, build_iterator, get_time_dif 26 | embedding = 'random' 27 | else: 28 | from utils import build_dataset, build_iterator, get_time_dif 29 | 30 | x = import_module('models.' 
+ model_name) 31 | config = x.Config(dataset, embedding) 32 | np.random.seed(1) 33 | torch.manual_seed(1) 34 | torch.cuda.manual_seed_all(1) 35 | torch.backends.cudnn.deterministic = True # 保证每次结果一样 36 | 37 | start_time = time.time() 38 | print("Loading data...") 39 | vocab, train_data, dev_data, test_data = build_dataset(config, args.word) 40 | train_iter = build_iterator(train_data, config) 41 | dev_iter = build_iterator(dev_data, config) 42 | test_iter = build_iterator(test_data, config) 43 | time_dif = get_time_dif(start_time) 44 | print("Time usage:", time_dif) 45 | 46 | # train 47 | config.n_vocab = len(vocab) 48 | model = x.Model(config).to(config.device) 49 | if model_name != 'Transformer': 50 | init_network(model) 51 | print(model.parameters) 52 | 53 | train(config, model, train_iter, dev_iter, test_iter) 54 | -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from sklearn import metrics 7 | import time 8 | from utils import get_time_dif 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | # 权重初始化,默认xavier 13 | def init_network(model, method='xavier', exclude='embedding', seed=123): 14 | for name, w in model.named_parameters(): 15 | if exclude not in name: 16 | if 'weight' in name: 17 | if method == 'xavier': 18 | nn.init.xavier_normal_(w) 19 | elif method == 'kaiming': 20 | nn.init.kaiming_normal_(w) 21 | else: 22 | nn.init.normal_(w) 23 | elif 'bias' in name: 24 | nn.init.constant_(w, 0) 25 | else: 26 | pass 27 | 28 | 29 | def train(config, model, train_iter, dev_iter, test_iter): 30 | start_time = time.time() 31 | model.train() 32 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 33 | 34 | # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率 35 | # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 36 | total_batch = 0 # 记录进行到多少batch 37 | dev_best_loss = float('inf') 38 | last_improve = 0 # 记录上次验证集loss下降的batch数 39 | flag = False # 记录是否很久没有效果提升 40 | writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime())) 41 | for epoch in range(config.num_epochs): 42 | print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs)) 43 | # scheduler.step() # 学习率衰减 44 | for i, (trains, labels) in enumerate(train_iter): 45 | outputs = model(trains) 46 | model.zero_grad() 47 | loss = F.cross_entropy(outputs, labels) 48 | loss.backward() 49 | optimizer.step() 50 | if total_batch % 100 == 0: 51 | # 每多少轮输出在训练集和验证集上的效果 52 | true = labels.data.cpu() 53 | predic = torch.max(outputs.data, 1)[1].cpu() 54 | train_acc = metrics.accuracy_score(true, predic) 55 | dev_acc, dev_loss = evaluate(config, model, dev_iter) 56 | if dev_loss < dev_best_loss: 57 | dev_best_loss = dev_loss 58 | torch.save(model.state_dict(), config.save_path) 59 | improve = '*' 60 | last_improve = total_batch 61 | else: 62 | improve = '' 63 | time_dif = get_time_dif(start_time) 64 | msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}' 65 | print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)) 66 | writer.add_scalar("loss/train", loss.item(), total_batch) 67 | writer.add_scalar("loss/dev", dev_loss, total_batch) 68 | writer.add_scalar("acc/train", train_acc, total_batch) 69 | writer.add_scalar("acc/dev", 
dev_acc, total_batch) 70 | model.train() 71 | total_batch += 1 72 | if total_batch - last_improve > config.require_improvement: 73 | # 验证集loss超过1000batch没下降,结束训练 74 | print("No optimization for a long time, auto-stopping...") 75 | flag = True 76 | break 77 | if flag: 78 | break 79 | writer.close() 80 | test(config, model, test_iter) 81 | 82 | 83 | def test(config, model, test_iter): 84 | # test 85 | model.load_state_dict(torch.load(config.save_path)) 86 | model.eval() 87 | start_time = time.time() 88 | test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True) 89 | msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}' 90 | print(msg.format(test_loss, test_acc)) 91 | print("Precision, Recall and F1-Score...") 92 | print(test_report) 93 | print("Confusion Matrix...") 94 | print(test_confusion) 95 | time_dif = get_time_dif(start_time) 96 | print("Time usage:", time_dif) 97 | 98 | 99 | def evaluate(config, model, data_iter, test=False): 100 | model.eval() 101 | loss_total = 0 102 | predict_all = np.array([], dtype=int) 103 | labels_all = np.array([], dtype=int) 104 | with torch.no_grad(): 105 | for texts, labels in data_iter: 106 | outputs = model(texts) 107 | loss = F.cross_entropy(outputs, labels) 108 | loss_total += loss 109 | labels = labels.data.cpu().numpy() 110 | predic = torch.max(outputs.data, 1)[1].cpu().numpy() 111 | labels_all = np.append(labels_all, labels) 112 | predict_all = np.append(predict_all, predic) 113 | 114 | acc = metrics.accuracy_score(labels_all, predict_all) 115 | if test: 116 | report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4) 117 | confusion = metrics.confusion_matrix(labels_all, predict_all) 118 | return acc, loss_total / len(data_iter), report, confusion 119 | return acc, loss_total / len(data_iter) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle as pkl 6 | from tqdm import tqdm 7 | import time 8 | from datetime import timedelta 9 | 10 | 11 | MAX_VOCAB_SIZE = 10000 # 词表长度限制 12 | UNK, PAD = '', '' # 未知字,padding符号 13 | 14 | 15 | def build_vocab(file_path, tokenizer, max_size, min_freq): 16 | vocab_dic = {} 17 | with open(file_path, 'r', encoding='UTF-8') as f: 18 | for line in tqdm(f): 19 | lin = line.strip() 20 | if not lin: 21 | continue 22 | content = lin.split('\t')[0] 23 | for word in tokenizer(content): 24 | vocab_dic[word] = vocab_dic.get(word, 0) + 1 25 | vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size] 26 | vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)} 27 | vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1}) 28 | return vocab_dic 29 | 30 | 31 | def build_dataset(config, ues_word): 32 | if ues_word: 33 | tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level 34 | else: 35 | tokenizer = lambda x: [y for y in x] # char-level 36 | if os.path.exists(config.vocab_path): 37 | vocab = pkl.load(open(config.vocab_path, 'rb')) 38 | else: 39 | vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) 40 | pkl.dump(vocab, open(config.vocab_path, 'wb')) 41 | print(f"Vocab size: {len(vocab)}") 42 | 43 | def load_dataset(path, pad_size=32): 44 | contents = [] 45 | with open(path, 'r', encoding='UTF-8') as f: 46 | for line 
/utils.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import os
3 | import torch
4 | import numpy as np
5 | import pickle as pkl
6 | from tqdm import tqdm
7 | import time
8 | from datetime import timedelta
9 | 
10 | 
11 | MAX_VOCAB_SIZE = 10000  # vocabulary size limit
12 | UNK, PAD = '<UNK>', '<PAD>'  # unknown-token and padding symbols
13 | 
14 | 
15 | def build_vocab(file_path, tokenizer, max_size, min_freq):
16 |     vocab_dic = {}
17 |     with open(file_path, 'r', encoding='UTF-8') as f:
18 |         for line in tqdm(f):
19 |             lin = line.strip()
20 |             if not lin:
21 |                 continue
22 |             content = lin.split('\t')[0]
23 |             for word in tokenizer(content):
24 |                 vocab_dic[word] = vocab_dic.get(word, 0) + 1
25 |         vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
26 |         vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
27 |         vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
28 |     return vocab_dic
29 | 
30 | 
31 | def build_dataset(config, ues_word):
32 |     if ues_word:
33 |         tokenizer = lambda x: x.split(' ')  # split on spaces, word-level
34 |     else:
35 |         tokenizer = lambda x: [y for y in x]  # char-level
36 |     if os.path.exists(config.vocab_path):
37 |         vocab = pkl.load(open(config.vocab_path, 'rb'))
38 |     else:
39 |         vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
40 |         pkl.dump(vocab, open(config.vocab_path, 'wb'))
41 |     print(f"Vocab size: {len(vocab)}")
42 | 
43 |     def load_dataset(path, pad_size=32):
44 |         contents = []
45 |         with open(path, 'r', encoding='UTF-8') as f:
46 |             for line in tqdm(f):
47 |                 lin = line.strip()
48 |                 if not lin:
49 |                     continue
50 |                 content, label = lin.split('\t')
51 |                 words_line = []
52 |                 token = tokenizer(content)
53 |                 seq_len = len(token)
54 |                 if pad_size:
55 |                     if len(token) < pad_size:
56 |                         token.extend([PAD] * (pad_size - len(token)))
57 |                     else:
58 |                         token = token[:pad_size]
59 |                         seq_len = pad_size
60 |                 # word to id
61 |                 for word in token:
62 |                     words_line.append(vocab.get(word, vocab.get(UNK)))
63 |                 contents.append((words_line, int(label), seq_len))
64 |         return contents  # [([...], 0, 29), ([...], 1, 30), ...]
65 |     train = load_dataset(config.train_path, config.pad_size)
66 |     dev = load_dataset(config.dev_path, config.pad_size)
67 |     test = load_dataset(config.test_path, config.pad_size)
68 |     return vocab, train, dev, test
69 | 
70 | 
71 | class DatasetIterater(object):
72 |     def __init__(self, batches, batch_size, device):
73 |         self.batch_size = batch_size
74 |         self.batches = batches
75 |         self.n_batches = len(batches) // batch_size
76 |         self.residue = False  # True if the last batch is smaller than batch_size
77 |         if len(batches) % self.batch_size != 0:  # note: modulo batch_size (not n_batches) decides whether a partial batch remains
78 |             self.residue = True
79 |         self.index = 0
80 |         self.device = device
81 | 
82 |     def _to_tensor(self, datas):
83 |         x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
84 |         y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
85 | 
86 |         # sequence length before padding (capped at pad_size)
87 |         seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
88 |         return (x, seq_len), y
89 | 
90 |     def __next__(self):
91 |         if self.residue and self.index == self.n_batches:
92 |             batches = self.batches[self.index * self.batch_size: len(self.batches)]
93 |             self.index += 1
94 |             batches = self._to_tensor(batches)
95 |             return batches
96 | 
97 |         elif self.index >= self.n_batches:
98 |             self.index = 0
99 |             raise StopIteration
100 |         else:
101 |             batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
102 |             self.index += 1
103 |             batches = self._to_tensor(batches)
104 |             return batches
105 | 
106 |     def __iter__(self):
107 |         return self
108 | 
109 |     def __len__(self):
110 |         if self.residue:
111 |             return self.n_batches + 1
112 |         else:
113 |             return self.n_batches
114 | 
115 | 
116 | def build_iterator(dataset, config):
117 |     iter = DatasetIterater(dataset, config.batch_size, config.device)
118 |     return iter
119 | 
120 | 
121 | def get_time_dif(start_time):
122 |     """Return the elapsed time since start_time."""
123 |     end_time = time.time()
124 |     time_dif = end_time - start_time
125 |     return timedelta(seconds=int(round(time_dif)))
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     '''Extract pretrained word vectors for the vocabulary.'''
130 |     # Change the directories and file names below as needed.
131 |     train_dir = "./THUCNews/data/train.txt"
132 |     vocab_dir = "./THUCNews/data/vocab.pkl"
133 |     pretrain_dir = "./THUCNews/data/sgns.sogou.char"
134 |     emb_dim = 300
135 |     filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"
136 |     if os.path.exists(vocab_dir):
137 |         word_to_id = pkl.load(open(vocab_dir, 'rb'))
138 |     else:
139 |         # tokenizer = lambda x: x.split(' ')  # build the vocab at word level (words in the dataset separated by spaces)
140 |         tokenizer = lambda x: [y for y in x]  # build the vocab at character level
141 |         word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
142 |         pkl.dump(word_to_id, open(vocab_dir, 'wb'))
143 | 
144 |     embeddings = np.random.rand(len(word_to_id), emb_dim)
145 |     f = open(pretrain_dir, "r", encoding='UTF-8')
146 |     for i, line in enumerate(f.readlines()):
147 |         # if i == 0:  # skip the first line if it is a header
148 |         #     continue
149 |         lin = line.strip().split(" ")
150 |         if lin[0] in word_to_id:
151 |             idx = word_to_id[lin[0]]
152 |             emb = [float(x) for x in lin[1:301]]
153 |             embeddings[idx] = np.asarray(emb, dtype='float32')
154 |     f.close()
155 |     np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
156 | 
--------------------------------------------------------------------------------
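The `__main__` block above writes the trimmed pretrained vectors to `embedding_SougouNews.npz` under the key `embeddings`, with one row per id in `vocab.pkl`. The model files under `models/` are the intended consumers of that file; as a rough illustration of how such an `.npz` can be turned into an embedding layer (the path and the `freeze=False` choice are assumptions here, not code from this repository):

```
# Turning the trimmed vectors into an embedding layer -- illustrative sketch.
import numpy as np
import torch
import torch.nn as nn

npz_path = "./THUCNews/data/embedding_SougouNews.npz"   # written by the block above (savez adds ".npz")
pretrained = torch.tensor(np.load(npz_path)["embeddings"].astype('float32'))

# Row i corresponds to vocab id i from vocab.pkl, so the id mapping carries over unchanged.
embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
print(embedding.weight.shape)   # (vocab size, 300)
```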
/utils_fasttext.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import os
3 | import torch
4 | import numpy as np
5 | import pickle as pkl
6 | from tqdm import tqdm
7 | import time
8 | from datetime import timedelta
9 | 
10 | 
11 | MAX_VOCAB_SIZE = 10000
12 | UNK, PAD = '<UNK>', '<PAD>'
13 | 
14 | 
15 | def build_vocab(file_path, tokenizer, max_size, min_freq):
16 |     vocab_dic = {}
17 |     with open(file_path, 'r', encoding='UTF-8') as f:
18 |         for line in tqdm(f):
19 |             lin = line.strip()
20 |             if not lin:
21 |                 continue
22 |             content = lin.split('\t')[0]
23 |             for word in tokenizer(content):
24 |                 vocab_dic[word] = vocab_dic.get(word, 0) + 1
25 |         vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
26 |         vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
27 |         vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
28 |     return vocab_dic
29 | 
30 | 
31 | def build_dataset(config, ues_word):
32 |     if ues_word:
33 |         tokenizer = lambda x: x.split(' ')  # split on spaces, word-level
34 |     else:
35 |         tokenizer = lambda x: [y for y in x]  # char-level
36 |     if os.path.exists(config.vocab_path):
37 |         vocab = pkl.load(open(config.vocab_path, 'rb'))
38 |     else:
39 |         vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
40 |         pkl.dump(vocab, open(config.vocab_path, 'wb'))
41 |     print(f"Vocab size: {len(vocab)}")
42 | 
43 |     def biGramHash(sequence, t, buckets):
44 |         t1 = sequence[t - 1] if t - 1 >= 0 else 0
45 |         return (t1 * 14918087) % buckets
46 | 
47 |     def triGramHash(sequence, t, buckets):
48 |         t1 = sequence[t - 1] if t - 1 >= 0 else 0
49 |         t2 = sequence[t - 2] if t - 2 >= 0 else 0
50 |         return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
51 | 
52 |     def load_dataset(path, pad_size=32):
53 |         contents = []
54 |         with open(path, 'r', encoding='UTF-8') as f:
55 |             for line in tqdm(f):
56 |                 lin = line.strip()
57 |                 if not lin:
58 |                     continue
59 |                 content, label = lin.split('\t')
60 |                 words_line = []
61 |                 token = tokenizer(content)
62 |                 seq_len = len(token)
63 |                 if pad_size:
64 |                     if len(token) < pad_size:
65 |                         token.extend([PAD] * (pad_size - len(token)))
66 |                     else:
67 |                         token = token[:pad_size]
68 |                         seq_len = pad_size
69 |                 # word to id
70 |                 for word in token:
71 |                     words_line.append(vocab.get(word, vocab.get(UNK)))
72 | 
73 |                 # fasttext ngram
74 |                 buckets = config.n_gram_vocab
75 |                 bigram = []
76 |                 trigram = []
77 |                 # ------ngram------
78 |                 for i in range(pad_size):
79 |                     bigram.append(biGramHash(words_line, i, buckets))
80 |                     trigram.append(triGramHash(words_line, i, buckets))
81 |                 # -----------------
82 |                 contents.append((words_line, int(label), seq_len, bigram, trigram))
83 |         return contents  # [([...], 0, 29, [...], [...]), ...]
84 |     train = load_dataset(config.train_path, config.pad_size)
85 |     dev = load_dataset(config.dev_path, config.pad_size)
86 |     test = load_dataset(config.test_path, config.pad_size)
87 |     return vocab, train, dev, test
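The two hash helpers above give FastText its bigram and trigram features without building an explicit n-gram vocabulary: at position `t`, `biGramHash` hashes the id of the previous token and `triGramHash` hashes the pair of the two previous ids into `config.n_gram_vocab` buckets, falling back to id 0 where there is no left context. A tiny self-contained demonstration, with the helper bodies copied from `load_dataset` and the `buckets` value chosen as an arbitrary stand-in for `config.n_gram_vocab`:

```
# Worked example of the n-gram hashing used in load_dataset -- bucket count is an arbitrary assumption.
def biGramHash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    return (t1 * 14918087) % buckets

def triGramHash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    t2 = sequence[t - 2] if t - 2 >= 0 else 0
    return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

words_line = [5, 17, 42]        # three token ids
buckets = 200000                # stand-in for config.n_gram_vocab
print(biGramHash(words_line, 0, buckets))   # 0: no token to the left of position 0
print(biGramHash(words_line, 2, buckets))   # hashes id 17 (the token at t-1)
print(triGramHash(words_line, 2, buckets))  # hashes the pair (5, 17)
```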
88 | 
89 | 
90 | class DatasetIterater(object):
91 |     def __init__(self, batches, batch_size, device):
92 |         self.batch_size = batch_size
93 |         self.batches = batches
94 |         self.n_batches = len(batches) // batch_size
95 |         self.residue = False  # True if the last batch is smaller than batch_size
96 |         if len(batches) % self.batch_size != 0:  # note: modulo batch_size (not n_batches) decides whether a partial batch remains
97 |             self.residue = True
98 |         self.index = 0
99 |         self.device = device
100 | 
101 |     def _to_tensor(self, datas):
102 |         # xx = [xxx[2] for xxx in datas]
103 |         # indexx = np.argsort(xx)[::-1]
104 |         # datas = np.array(datas)[indexx]
105 |         x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
106 |         y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
107 |         bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)
108 |         trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)
109 | 
110 |         # sequence length before padding (capped at pad_size)
111 |         seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
112 |         return (x, seq_len, bigram, trigram), y
113 | 
114 |     def __next__(self):
115 |         if self.residue and self.index == self.n_batches:
116 |             batches = self.batches[self.index * self.batch_size: len(self.batches)]
117 |             self.index += 1
118 |             batches = self._to_tensor(batches)
119 |             return batches
120 | 
121 |         elif self.index >= self.n_batches:
122 |             self.index = 0
123 |             raise StopIteration
124 |         else:
125 |             batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
126 |             self.index += 1
127 |             batches = self._to_tensor(batches)
128 |             return batches
129 | 
130 |     def __iter__(self):
131 |         return self
132 | 
133 |     def __len__(self):
134 |         if self.residue:
135 |             return self.n_batches + 1
136 |         else:
137 |             return self.n_batches
138 | 
139 | 
140 | def build_iterator(dataset, config):
141 |     iter = DatasetIterater(dataset, config.batch_size, config.device)
142 |     return iter
143 | 
144 | 
145 | def get_time_dif(start_time):
146 |     """Return the elapsed time since start_time."""
147 |     end_time = time.time()
148 |     time_dif = end_time - start_time
149 |     return timedelta(seconds=int(round(time_dif)))
150 | 
151 | if __name__ == "__main__":
152 |     '''Extract pretrained word vectors for the vocabulary.'''
153 |     vocab_dir = "./THUCNews/data/vocab.pkl"
154 |     pretrain_dir = "./THUCNews/data/sgns.sogou.char"
155 |     emb_dim = 300
156 |     filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"
157 |     word_to_id = pkl.load(open(vocab_dir, 'rb'))
158 |     embeddings = np.random.rand(len(word_to_id), emb_dim)
159 |     f = open(pretrain_dir, "r", encoding='UTF-8')
160 |     for i, line in enumerate(f.readlines()):
161 |         # if i == 0:  # skip the first line if it is a header
162 |         #     continue
163 |         lin = line.strip().split(" ")
164 |         if lin[0] in word_to_id:
165 |             idx = word_to_id[lin[0]]
166 |             emb = [float(x) for x in lin[1:301]]
167 |             embeddings[idx] = np.asarray(emb, dtype='float32')
168 |     f.close()
169 |     np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
170 | 
--------------------------------------------------------------------------------
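Taken together, `utils_fasttext.py` yields batches of the form `((x, seq_len, bigram, trigram), y)`, which is the input the FastText model receives inside `train_eval.train`. The sketch below drives the pipeline by hand to inspect one batch; it is illustrative only and assumes `models/FastText.py` exposes a `Config(dataset, embedding)` with the attributes used above (`vocab_path`, `pad_size`, `n_gram_vocab`, `batch_size`, `device`, `train_path`, ...), as run.py expects.

```
# Inspect one FastText batch -- sketch under the assumptions above.
from importlib import import_module

from utils_fasttext import build_dataset, build_iterator

x = import_module('models.FastText')
config = x.Config('THUCNews', 'random')        # 'random': randomly initialised embedding
vocab, train_data, dev_data, test_data = build_dataset(config, False)   # False: char-level input
config.n_vocab = len(vocab)

train_iter = build_iterator(train_data, config)
(tokens, seq_len, bigram, trigram), labels = next(iter(train_iter))
print(tokens.shape, bigram.shape, trigram.shape)   # each is [batch_size, pad_size]
print(seq_len.shape, labels.shape)                 # each is [batch_size]
```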