├── .gitignore
├── README.md
├── bert4pytorch
│   ├── __init__.py
│   ├── layers.py
│   ├── losses.py
│   ├── modeling.py
│   ├── optimization.py
│   ├── snippets.py
│   └── tokenization.py
├── examples
│   ├── README.md
│   ├── basic_language_model.py
│   └── tnews_classify_finetune.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
.idea

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# bert4pytorch

## Updates:

- **2022-04-02 update**: A new developer recently joined bert4pytorch and added some new features (address as follows). They are still being organized and debugged and will be released once that work is finished.
- **2022-03-24 update**: Implemented adversarial training (FGM); tested successfully on a classification task.
- **2022-03-22 update**: Updated focal loss; tested successfully on a classification task.
- **2021-11-04 update**: Basic tests; added an MLM prediction example.
- **2021-11-03 update**:
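  This update is an almost complete refactor of the code. The goal is to let the rest of the BERT family (e.g. ALBERT, T5, NEZHA, ELECTRA) be implemented in a single model file while keeping the code clean. The backbone follows the code structure of bert4keras, so the library can be used in nearly the same API style as bert4keras. UniLM-style and GPT-style attention mask matrices are also implemented. Usage examples will be provided later.
  Other changes:
  1. Removed the ema file; exponential moving averaging of weights now lives in the optimization file.
  2. Added a complete classification example: finetuning on the tnews dataset from CLUE.
- **2021-09-06 update**:
  1. Removed the file_utils file. This simplifies the pretrained-model loading code and drops the dependency on network request libraries. As a result, models can only be loaded locally after the model files have been downloaded; they are available at https://huggingface.co/models
  2. Added special layers and losses: a CRF layer, plus the focal_loss and LabelSmoothingCrossEntropy losses. More will be added over time.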
- **2021-08-27 update**: Thanks for all the stars! Some users recently reported a few small bugs, which I have noticed as well. Unfortunately work has been very busy this month, so updates have lagged. Around mid-September I plan to release a version that is fully usable after a simple pip install, together with some key comments.
  I will also add adversarial training and a complete finetuning example.



# Background

The most popular PyTorch BERT framework is undoubtedly Hugging Face's Transformers project. As the project keeps growing, however, it has become heavyweight. Beginners, and even people with some NLP background, find it very hard to follow its code and understand BERT in depth.

Modifying the low-level code of Transformers is also quite difficult, which makes it hard to customize models.

This project condenses the whole BERT architecture **into a few files** (mainly adapted from the open-source Transformers project). It removes a large amount of non-essential code and adds some features, such as EMA and a warmup schedule. The core parts carry **extensive Chinese comments** that aim to answer questions readers may have while using it.

The core of the project is just three files: modeling, tokenization and optimization, each **a few hundred lines long**. Together with the extensive Chinese comments, they let you understand BERT thoroughly in no time.

## Features

### Implemented

- Finetuning from pretrained BERT and RoBERTa-wwm-ext weights
- An optimizer with warmup
- Exponential moving average (EMA) of model weights

### Planned

- ALBERT, GPT, XLNet, Conformer and other architectures
- Various tricks (e.g. adversarial training, EMA) and dedicated layers and losses, for easy extension
- Many complete, ready-to-run NLP and speech recognition examples with Chinese comments, to flatten the learning curve


## Usage

### Install with pip
```bash
pip install bert4pytorch==0.1.3
```
pip currently installs an older version; for the latest version, install from source.

### Install from source
```bash
pip install git+https://github.com/MuQiuJun-AI/bert4pytorch.git
```

## Weights
Supported pretrained weights:

- PyTorch version of Google's original BERT (a script is needed to convert the checkpoint to PyTorch): https://github.com/google-research/bert
- HIT's RoBERTa: https://github.com/ymcui/Chinese-BERT-wwm


## Miscellaneous

I originally put this project together just for my own convenience. During this time I often read Su Jianlin's blog. His posts are remarkably incisive, and he often comes up with pretty good tricks entirely on his own. All I can say is: hats off.

That is why this project is named after bert4keras, as a thank-you to Su Jianlin for generously sharing his work.

Later I gradually decided to open-source the small results of my learning. Examples will be added over time. Early on, I will borrow examples from Su's bert4keras and reimplement them in PyTorch. Questions and discussion are welcome. If this project is useful to you, please give it a star!

--------------------------------------------------------------------------------
/bert4pytorch/__init__.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-


__version__ = '0.1.3'
--------------------------------------------------------------------------------
/bert4pytorch/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


def gelu(x):
    """GELU activation.
    GPT-style architectures use an approximation of GELU:
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    This function uses the exact form given in the original paper instead.
    Paper: https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


activations = {"gelu": gelu, "relu": F.relu, "swish": swish}


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12, conditional=False):
        """LayerNorm layer. It is implemented by hand so that it can also act as a conditional
        LayerNorm, which enables tasks such as conditional text generation and conditional
        classification.
        Conditional LayerNorm is an idea from Su Jianlin; details: https://spaces.ac.cn/archives/7124
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps
        self.conditional = conditional
        if conditional:
            # Conditional LayerNorm, used for conditional text generation.
            # The projections are zero-initialized so that, at the start, they do not
            # disturb the pretrained weights.
            self.dense1 = nn.Linear(2 * hidden_size, hidden_size, bias=False)
            nn.init.zeros_(self.dense1.weight)
            self.dense2 = nn.Linear(2 * hidden_size, hidden_size, bias=False)
            nn.init.zeros_(self.dense2.weight)

    def forward(self, x):
        if self.conditional:
            inputs = x[0]
            cond = x[1]
            # Insert singleton dimensions so that cond broadcasts over the sequence dimension(s)
            for _ in range(len(inputs.shape) - len(cond.shape)):
                cond = cond.unsqueeze(dim=1)
            u = inputs.mean(-1, keepdim=True)
            s = (inputs - u).pow(2).mean(-1, keepdim=True)
            x = (inputs - u) / torch.sqrt(s + self.eps)
            return (self.weight + self.dense1(cond)) * x + (self.bias + self.dense2(cond))
        else:
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            return self.weight * x + self.bias


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, dropout_rate, attention_scale=True,
                 return_attention_scores=False):
        super(MultiHeadAttentionLayer, self).__init__()

        assert hidden_size % num_attention_heads == 0

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.attention_scale = attention_scale
        self.return_attention_scores = return_attention_scores

        self.q = nn.Linear(hidden_size, hidden_size)
        self.k = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, hidden_size)

        self.o = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_rate)

    def transpose_for_scores(self, x):
        # [batch_size, seq_len, hidden_size] -> [batch_size, num_attention_heads, seq_len, attention_head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, attention_mask=None):

        # query shape: [batch_size, query_len, hidden_size]
        # key shape: [batch_size, key_len, hidden_size]
        # value shape: [batch_size, value_len, hidden_size]
        # Usually query_len, key_len and value_len are all equal

        mixed_query_layer = self.q(query)
        mixed_key_layer = self.k(key)
        mixed_value_layer = self.v(value)

        # mixed_query_layer shape: [batch_size, query_len, hidden_size]
        # mixed_key_layer shape: [batch_size, key_len, hidden_size]
        # mixed_value_layer shape: [batch_size, value_len, hidden_size]

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # query_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]
        # key_layer shape: [batch_size, num_attention_heads, key_len, attention_head_size]
        # value_layer shape: [batch_size, num_attention_heads, value_len, attention_head_size]

        # Swap the last two dimensions of k, then take the q·k dot product to get the attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # attention_scores shape: [batch_size, num_attention_heads, query_len, key_len]

        # Optionally scale the attention scores by 1/sqrt(attention_head_size)
        if self.attention_scale:
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask. Positions where the mask is 0 get -10000 added to their score,
        # so after softmax their attention_probs are almost 0 and they are effectively not attended to
        if attention_mask is not None:
            # attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e10)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores into probabilities in [0, 1]
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        # context_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]

        # After transpose/permute the tensor is no longer contiguous in memory, but view() requires
        # contiguous memory, so contiguous() is called first to obtain a contiguous copy
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # context_layer shape: [batch_size, query_len, num_attention_heads, attention_head_size]

        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Optionally also return the attention scores
        if self.return_attention_scores:
            # The returned attention_scores are pre-softmax; normalize them outside if needed
            return self.o(context_layer), attention_scores
        else:
            return self.o(context_layer)


class PositionWiseFeedForward(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dropout_rate=0.5, hidden_act='gelu', is_dropout=True):
        # The original TF BERT has no dropout after the activation. Google AI's bert-pytorch
        # project adds one, and so does PyTorch's official TransformerEncoderLayer:
        # self.linear2(self.dropout(self.activation(self.linear1(src)))).
        # The reason for this inconsistency is unclear; whether the layer is present probably
        # makes little difference.

        # is_dropout and dropout_rate control this dropout. For the original Transformer, keep the
        # defaults; for BERT, set is_dropout=False, in which case dropout_rate is ignored.
        super(PositionWiseFeedForward, self).__init__()

        self.is_dropout = is_dropout
        self.intermediate_act_fn = activations[hidden_act]
        self.intermediateDense = nn.Linear(hidden_size, intermediate_size)
        self.outputDense = nn.Linear(intermediate_size, hidden_size)
        if self.is_dropout:
            self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # x shape: (batch size, seq len, hidden_size)
        if self.is_dropout:
            x = self.dropout(self.intermediate_act_fn(self.intermediateDense(x)))
        else:
            x = self.intermediate_act_fn(self.intermediateDense(x))

        # x shape: (batch size, seq len, intermediate_size)
        x = self.outputDense(x)

        # x shape: (batch size, seq len, hidden_size)
        return x

--------------------------------------------------------------------------------
/bert4pytorch/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, weight=None, reduction: str = 'mean') -> None:
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        ce_loss = F.cross_entropy(inputs,
                                  targets, weight=self.weight, reduction="none")
        # Without class weights, ce_loss = -log(p_t), so p_t = exp(-ce_loss) is the true-class probability
        p_t = torch.exp(-ce_loss)
        loss = (1 - p_t) ** self.gamma * ce_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
                                                                 ignore_index=self.ignore_index)
--------------------------------------------------------------------------------
/bert4pytorch/modeling.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import copy
import json
from bert4pytorch.layers import LayerNorm, MultiHeadAttentionLayer, PositionWiseFeedForward, activations


class Transformer(nn.Module):
    """Base class for models
    """

    def __init__(
            self,
            vocab_size,  # vocabulary size
            hidden_size,  # hidden (encoding) dimension
            num_hidden_layers,  # total number of Transformer layers
            num_attention_heads,  # number of attention heads
            intermediate_size,  # hidden dimension of the FeedForward layer
            hidden_act,  # activation of the FeedForward hidden layer
            dropout_rate,  # dropout rate
            embedding_size=None,  # embedding size; falls back to hidden_size if not given
            attention_head_size=None,  # head_size of V in attention
            attention_key_size=None,  # head_size of Q and K in attention
            sequence_length=None,  # fixed sequence length, if any
            keep_tokens=None,  # list of token IDs to keep
            compound_tokens=None,  # extended embeddings
            residual_attention_scores=False,  # add a residual to the attention matrix
            ignore_invalid_weights=False,  # allow skipping weights that do not exist
            **kwargs
    ):
        super(Transformer, self).__init__()
        if keep_tokens is not None:
            vocab_size = len(keep_tokens)
        if compound_tokens is not None:
            vocab_size += len(compound_tokens)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size or self.hidden_size // self.num_attention_heads
        self.attention_key_size = attention_key_size or self.attention_head_size
        self.intermediate_size = intermediate_size
        self.dropout_rate = dropout_rate or 0
        self.hidden_act = hidden_act
        self.embedding_size = embedding_size or hidden_size
        self.sequence_length = sequence_length
        self.keep_tokens = keep_tokens
        self.compound_tokens = compound_tokens
        self.attention_bias = None
        self.position_bias = None
        self.attention_scores = None
        self.residual_attention_scores = residual_attention_scores
        self.ignore_invalid_weights = ignore_invalid_weights

    def init_model_weights(self, module):
        raise NotImplementedError

    def variable_mapping(self):
        """Build the mapping between PyTorch layer names and checkpoint variable names
        """
        return {}

    def load_weights_from_pytorch_checkpoint(self, checkpoint, mapping=None):
        """Load weights from a checkpoint according to mapping
        """
        # model = self
        state_dict = torch.load(checkpoint, map_location='cpu')
        mapping = mapping or self.variable_mapping()

        for new_key, old_key in mapping.items():
            if old_key in state_dict.keys():
                state_dict[new_key] = state_dict.pop(old_key)
        # Note: strict is set directly to ignore_invalid_weights, so with the default (False)
        # loading is non-strict and missing or unexpected keys are silently skipped
        self.load_state_dict(state_dict, strict=self.ignore_invalid_weights)


def lm_mask(segment_ids):
    """Lower-triangular attention mask (for language models)
    """
    idxs = torch.arange(0, segment_ids.shape[1], device=segment_ids.device)
    mask = (idxs.unsqueeze(0) <=
            idxs.unsqueeze(1)).unsqueeze(0).unsqueeze(0).to(dtype=torch.float32)
    return mask


def unilm_mask(tokens_ids, segment_ids):
    """UniLM attention mask (for Seq2Seq models).
    The source/target split is given by segment_ids.
    UniLM: https://arxiv.org/abs/1905.03197

    tokens_ids: tensor of shape (batch_size, seq_length)
    segment_ids: tensor of shape (batch_size, seq_length)
    """

    # Set segment_ids to 1 on padding positions: padding is ignored at first, and its mask
    # values are zeroed out at the end
    ids = segment_ids + (tokens_ids <= 0).long()
    # Cumulative sum along the sequence dimension
    idxs = torch.cumsum(ids, dim=1)
    # Build a 0/1 padding mask from tokens_ids: [batch_size, 1, seq_length, 1]
    extended_mask = (tokens_ids > 0).to(dtype=torch.float32).unsqueeze(1).unsqueeze(3)
    # Build the UniLM mask with shape [batch_size, 1, from_seq_length, to_seq_length],
    # which broadcasts over the num_heads dimension
    mask = (idxs.unsqueeze(1) <= idxs.unsqueeze(2)).unsqueeze(1).to(dtype=torch.float32)
    # Zero out the mask values of padding positions
    mask *= extended_mask
    return mask


####################################################################################
#                                      bert                                        #
####################################################################################


class BertEmbeddings(nn.Module):
    """
    Embeddings layer.
    Builds the word, position and token_type embeddings.
    """
    def __init__(self, vocab_size, hidden_size, max_position, segment_vocab_size, drop_rate):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)
        self.segment_embeddings = nn.Embedding(segment_vocab_size, hidden_size)

        self.layerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, token_ids, segment_ids=None):
        seq_length = token_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
        if segment_ids is None:
            segment_ids = torch.zeros_like(token_ids)

        words_embeddings = self.word_embeddings(token_ids)
        position_embeddings = self.position_embeddings(position_ids)
        segment_embeddings = self.segment_embeddings(segment_ids)

        embeddings = words_embeddings + position_embeddings + segment_embeddings
        embeddings = self.layerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertLayer(nn.Module):
    """
    Transformer layer.
    Order: Attention --> Add --> LayerNorm --> Feed Forward --> Add --> LayerNorm

    Notes: 1. Dropout layers are left out of the order above, but that does not mean there is no
              dropout; each sublayer uses dropout slightly differently, so keep them apart.
           2. The Feed Forward block of the original Transformer encoder has two linear layers.
              intermediate_size is both the output size of the first linear layer and the input
              size of the second.
    """
    def __init__(self, hidden_size, num_attention_heads, dropout_rate, intermediate_size, hidden_act, is_dropout=False):
        super(BertLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttentionLayer(hidden_size, num_attention_heads, dropout_rate)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.layerNorm1 = LayerNorm(hidden_size, eps=1e-12)
        self.feedForward = PositionWiseFeedForward(hidden_size, intermediate_size, dropout_rate, hidden_act,
                                                   is_dropout=is_dropout)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.layerNorm2 = LayerNorm(hidden_size, eps=1e-12)

    def forward(self, hidden_states, attention_mask):
        self_attn_output = self.multiHeadAttention(hidden_states, hidden_states, hidden_states, attention_mask)
        hidden_states = hidden_states + self.dropout1(self_attn_output)
        hidden_states = self.layerNorm1(hidden_states)
        self_attn_output2 = self.feedForward(hidden_states)
        hidden_states = hidden_states + self.dropout2(self_attn_output2)
        hidden_states = self.layerNorm2(hidden_states)
        return hidden_states


class BERT(Transformer):
    """Build the BERT model
    """

    def __init__(
            self,
            max_position,  # maximum sequence length
            segment_vocab_size=2,  # number of segment types
            initializer_range=0.02,  # standard deviation for weight initialization
            with_pool=False,  # whether to include the pooler
            with_nsp=False,  # whether to include the NSP head
            with_mlm=False,  # whether to include the MLM head
            hierarchical_position=None,  # whether to use hierarchically decomposed position embeddings
            custom_position_ids=False,  # whether position ids are passed in by the caller
            **kwargs
    ):
        self.max_position = max_position
        self.segment_vocab_size = segment_vocab_size
        self.initializer_range = initializer_range
        self.with_pool = with_pool
        self.with_nsp = with_nsp
        self.with_mlm = with_mlm
        self.hierarchical_position = hierarchical_position
        self.custom_position_ids = custom_position_ids
        if self.with_nsp and not self.with_pool:
            self.with_pool = True

        super(BERT, self).__init__(**kwargs)

        self.embeddings = BertEmbeddings(self.vocab_size, self.hidden_size, self.max_position, self.segment_vocab_size, self.dropout_rate)
        layer = BertLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.intermediate_size, self.hidden_act, is_dropout=False)
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) for _ in range(self.num_hidden_layers)])
        if self.with_pool:
            # Pooler (extracts the [CLS] vector)
            self.pooler = nn.Linear(self.hidden_size, self.hidden_size)
            self.pooler_activation = nn.Tanh()
            if self.with_nsp:
                # Next Sentence Prediction head.
                # Its input is pooled_output, so with_pool=True is a prerequisite for NSP
                self.nsp = nn.Linear(self.hidden_size, 2)
        else:
            self.pooler = None
            self.pooler_activation = None
        if self.with_mlm:
            self.mlmDecoder = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
            # Whether this weight tying is needed remains to be verified
            # self.mlmDecoder.weight = self.embeddings.word_embeddings.weight
            self.mlmBias = nn.Parameter(torch.zeros(self.vocab_size))
            self.mlmDecoder.bias = self.mlmBias
            self.mlmDense = nn.Linear(self.hidden_size, self.hidden_size)
            self.transform_act_fn = activations[self.hidden_act]
            self.mlmLayerNorm = LayerNorm(self.hidden_size, eps=1e-12)
        self.apply(self.init_model_weights)

    def init_model_weights(self, module):
        """Initialize weights
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # BERT parameter initialization. The TF version initializes Linear and Embedding layers
            # from a truncated normal distribution; a plain normal distribution is used here
            # (recent PyTorch versions also provide torch.nn.init.trunc_normal_).
            # This makes no difference when finetuning from pretrained weights.
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, token_ids, segment_ids=None, attention_mask=None, output_all_encoded_layers=False):
        """
        token_ids: ids of the tokens in the vocab
        segment_ids: sentence id of each token, 0 or 1 (0 = the token belongs to the first sentence,
                     1 = the second). When the task has a single input sentence, every value is 0
                     and the argument can be omitted
        attention_mask: 0/1 values that prevent attending to padding tokens; 1 = attend, 0 = do not attend

        All three arguments are tensors of shape (batch_size, sequence_length)
        """

        if attention_mask is None:
            # Build a 4D attention mask from token_ids with shape [batch_size, 1, 1, to_seq_length],
            # so that it broadcasts to [batch_size, num_heads, from_seq_length, to_seq_length]
            # in multi-head attention
            attention_mask = (token_ids != 0).long().unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 2:
            # A (batch_size, seq_length) mask passed in by the caller is expanded the same way
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        if segment_ids is None:
            segment_ids = torch.zeros_like(token_ids)

        # Cast the mask to the model's dtype (fp16 compatibility)
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
        # Turning mask entries of 0 into large negative numbers (so that masked positions score
        # close to 0 after softmax) is done inside MultiHeadAttentionLayer
        # attention_mask = (1.0 - attention_mask) * -10000.0
        # Embedding
        hidden_states = self.embeddings(token_ids, segment_ids)
        # Encoder
        encoded_layers = [hidden_states]  # include the embedding output
        for layer_module in self.encoderLayer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                encoded_layers.append(hidden_states)
        if not output_all_encoded_layers:
            encoded_layers.append(hidden_states)

        # Output of the last hidden layer
        sequence_output = encoded_layers[-1]
        # Keep only the last layer's output unless all layers were requested
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        # Optional pooler
        if self.with_pool:
            pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
        else:
            pooled_output = None
        # Optional NSP head
        if self.with_pool and self.with_nsp:
            nsp_scores = self.nsp(pooled_output)
        else:
            nsp_scores = None
        # Optional MLM head
        if self.with_mlm:
            mlm_hidden_state = self.mlmDense(sequence_output)
            mlm_hidden_state = self.transform_act_fn(mlm_hidden_state)
            mlm_hidden_state = self.mlmLayerNorm(mlm_hidden_state)
            mlm_scores = self.mlmDecoder(mlm_hidden_state)
        else:
            mlm_scores = None
        # The return value depends on which heads are enabled
        if mlm_scores is None and nsp_scores is None:
            return encoded_layers, pooled_output
        elif mlm_scores is not None and nsp_scores is not None:
            return mlm_scores, nsp_scores
        elif mlm_scores is not None:
            return mlm_scores
        else:
            return nsp_scores

    def variable_mapping(self):
        """Map this implementation's parameter names to the variable names in the checkpoint
        """
        mapping = {
            'embeddings.word_embeddings.weight': 'bert.embeddings.word_embeddings.weight',
            'embeddings.position_embeddings.weight': 'bert.embeddings.position_embeddings.weight',
            'embeddings.segment_embeddings.weight': 'bert.embeddings.token_type_embeddings.weight',
            'embeddings.layerNorm.weight': 'bert.embeddings.LayerNorm.weight',
            'embeddings.layerNorm.bias': 'bert.embeddings.LayerNorm.bias',
            'pooler.weight': 'bert.pooler.dense.weight',
            'pooler.bias': 'bert.pooler.dense.bias',
            'nsp.weight': 'cls.seq_relationship.weight',
            'nsp.bias': 'cls.seq_relationship.bias',
            'mlmDense.weight': 'cls.predictions.transform.dense.weight',
            'mlmDense.bias': 'cls.predictions.transform.dense.bias',
            'mlmLayerNorm.weight': 'cls.predictions.transform.LayerNorm.weight',
            'mlmLayerNorm.bias': 'cls.predictions.transform.LayerNorm.bias',
            'mlmBias': 'cls.predictions.bias',
            'mlmDecoder.weight': 'cls.predictions.decoder.weight'
        }
        for i in range(self.num_hidden_layers):
            prefix = 'bert.encoder.layer.%d.' % i
            mapping.update({
                'encoderLayer.%d.multiHeadAttention.q.weight' % i: prefix + 'attention.self.query.weight',
                'encoderLayer.%d.multiHeadAttention.q.bias' % i: prefix + 'attention.self.query.bias',
                'encoderLayer.%d.multiHeadAttention.k.weight' % i: prefix + 'attention.self.key.weight',
                'encoderLayer.%d.multiHeadAttention.k.bias' % i: prefix + 'attention.self.key.bias',
                'encoderLayer.%d.multiHeadAttention.v.weight' % i: prefix + 'attention.self.value.weight',
                'encoderLayer.%d.multiHeadAttention.v.bias' % i: prefix + 'attention.self.value.bias',
                'encoderLayer.%d.multiHeadAttention.o.weight' % i: prefix + 'attention.output.dense.weight',
                'encoderLayer.%d.multiHeadAttention.o.bias' % i: prefix + 'attention.output.dense.bias',
                'encoderLayer.%d.layerNorm1.weight' % i: prefix + 'attention.output.LayerNorm.weight',
                'encoderLayer.%d.layerNorm1.bias' % i: prefix + 'attention.output.LayerNorm.bias',
                'encoderLayer.%d.feedForward.intermediateDense.weight' % i: prefix + 'intermediate.dense.weight',
                'encoderLayer.%d.feedForward.intermediateDense.bias' % i: prefix + 'intermediate.dense.bias',
                'encoderLayer.%d.feedForward.outputDense.weight' % i: prefix + 'output.dense.weight',
                'encoderLayer.%d.feedForward.outputDense.bias' % i: prefix + 'output.dense.bias',
                'encoderLayer.%d.layerNorm2.weight' % i: prefix + 'output.LayerNorm.weight',
                'encoderLayer.%d.layerNorm2.bias' % i: prefix + 'output.LayerNorm.bias'
            })

        return mapping


def build_transformer_model(
        config_path=None,
        checkpoint_path=None,
        model='bert',
        application='encoder',
        **kwargs
):
    """Build a model from a config file, optionally loading checkpoint weights
    """
    configs = {}
    if config_path is not None:
        with open(config_path, encoding='utf-8') as f:
            configs.update(json.load(f))
    configs.update(kwargs)
    if 'max_position' not in configs:
        configs['max_position'] = configs.get('max_position_embeddings', 512)
    if 'dropout_rate' not in configs:
        configs['dropout_rate'] = configs.get('hidden_dropout_prob')
    if 'segment_vocab_size' not in configs:
        configs['segment_vocab_size'] = configs.get('type_vocab_size', 2)
    models = {
        'bert': BERT,
        'roberta': BERT
    }

    my_model = models[model]
    transformer = my_model(**configs)
    if checkpoint_path is not None:
        transformer.load_weights_from_pytorch_checkpoint(checkpoint_path)
    return transformer

--------------------------------------------------------------------------------
/bert4pytorch/optimization.py:
--------------------------------------------------------------------------------
import math
from typing import Callable, Iterable, Optional, Tuple, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):

    """
    Linear schedule with warmup. The learning rate rises linearly from 0 to its initial value over
    num_warmup_steps, then decays linearly to 0 at num_training_steps.

    Args
        num_warmup_steps:
            number of warmup steps, usually num_training_steps * warmup_proportion
            (recommended warmup proportion: 0.05-0.15)

        num_training_steps:
            total number of training steps, usually train_batches * num_epoch
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


class AdamW(Optimizer):
    """
    Adam with decoupled weight decay.

    Args:
        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
            iterable of parameters to optimize, or dicts defining parameter groups
        lr (:obj:`float`, `optional`, defaults to 1e-3):
            learning rate.
        betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2)
        eps (:obj:`float`, `optional`, defaults to 1e-6):
            Adam's epsilon, for numerical stability
        weight_decay (:obj:`float`, `optional`, defaults to 0):
            weight decay coefficient
        correct_bias (:obj:`bool`, `optional`, defaults to `True`):
            whether to apply Adam's bias correction. The original TF BERT does not correct the bias
            (i.e. False), but True is worth trying and may converge more stably
    Example:
        param_optimizer = list(model.named_parameters())
        # LayerNorm modules in this repo are named layerNorm / layerNorm1 / layerNorm2 / mlmLayerNorm
        no_decay = ['bias', 'LayerNorm', 'layerNorm']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5, correct_bias=False)

    """

    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        """
        Perform a single optimization step.

        Args:
            closure (:obj:`Callable`, `optional`):
                a closure that re-evaluates the model and returns the loss
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of the gradient (accumulated first moment)
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of the squared gradient (accumulated second moment)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running averages by beta1/beta2 and update them
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                # Bias correction; the original BERT does not perform this step
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay. With adaptive optimizers such as Adam, L2 regularization
                # performs poorly because of the interaction between m and v. Decaying the weights
                # directly shrinks every weight by the same proportion (equivalent to L2
                # regularization under plain SGD).
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss


class FGM():
    '''
    FGM adversarial training
    Example:
        # Initialization
        fgm = FGM(model)
        ...
the intermediate training code is omitted here 147 | # after computing the loss and calling backward(), call attack() to add a perturbation to the word embeddings 148 | loss = criterion(outputs, labels) 149 | loss.backward() 150 | fgm.attack() 151 | # optimizer.zero_grad() # uncomment if you do not want to accumulate gradients (usually left commented out) 152 | # feed the same inputs through the model again and backpropagate, accumulating the adversarial gradients 153 | loss_adv = criterion(model(token_ids, segment_ids), labels) 154 | loss_adv.backward() 155 | # restore the embedding parameters 156 | fgm.restore() 157 | # gradient descent: update the parameters 158 | optimizer.step() 159 | optimizer.zero_grad() 160 | 161 | ''' 162 | def __init__(self, model): 163 | self.model = model 164 | self.backup = {} 165 | 166 | def attack(self, epsilon=1., emb_name='word_embeddings'): 167 | for name, param in self.model.named_parameters(): 168 | if param.requires_grad and emb_name in name: 169 | self.backup[name] = param.data.clone() 170 | norm = torch.norm(param.grad) 171 | if norm != 0 and not torch.isnan(norm): 172 | r_at = epsilon * param.grad / norm 173 | param.data.add_(r_at) 174 | 175 | def restore(self, emb_name='word_embeddings'): 176 | for name, param in self.model.named_parameters(): 177 | if param.requires_grad and emb_name in name: 178 | assert name in self.backup 179 | param.data = self.backup[name] 180 | self.backup = {} 181 | 182 | 183 | class ExponentialMovingAverage(): 184 | ''' 185 | Exponential moving average (EMA) of the model weights. 186 | Note: this is entirely different from the exponential moving averages of the first and second gradient moments kept by adaptive optimizers such as Adam. 187 | 188 | Example: 189 | # initialization 190 | ema = ExponentialMovingAverage(model, 0.999) 191 | 192 | # during training, call update() right after each parameter update to refresh ema_weights 193 | def train(): 194 | optimizer.step() 195 | ema.update() 196 | 197 | # before evaluation, call apply_ema_weights(); after evaluation, restore the original model weights 198 | def evaluate(): 199 | ema.apply_ema_weights() 200 | # evaluate 201 | # to save the EMA model, call torch.save() before reset_old_weights() 202 | ema.reset_old_weights() 203 | ''' 204 | def __init__(self, model, decay): 205 | self.model = model 206 | self.decay = decay 207 | # EMA weights (the moving average of every layer's weights as of the current step) 208 | self.ema_weights = {} 209 | # the original model weights, saved during evaluation so they can be restored from the EMA weights afterwards 210 | self.model_weights = {} 211 | 212 | #
initialize ema_weights with a copy of the current model weights 213 | for name, param in self.model.named_parameters(): 214 | if param.requires_grad: 215 | self.ema_weights[name] = param.data.clone() 216 | 217 | def update(self): 218 | for name, param in self.model.named_parameters(): 219 | if param.requires_grad: 220 | assert name in self.ema_weights 221 | new_average = (1.0 - self.decay) * param.data + self.decay * self.ema_weights[name] 222 | self.ema_weights[name] = new_average.clone() 223 | 224 | def apply_ema_weights(self): 225 | for name, param in self.model.named_parameters(): 226 | if param.requires_grad: 227 | assert name in self.ema_weights 228 | self.model_weights[name] = param.data 229 | param.data = self.ema_weights[name] 230 | 231 | def reset_old_weights(self): 232 | for name, param in self.model.named_parameters(): 233 | if param.requires_grad: 234 | assert name in self.model_weights 235 | param.data = self.model_weights[name] 236 | self.model_weights = {} 237 | 238 | 239 | # def extend_with_exponential_moving_average(BaseOptimizer, model): 240 | 241 | # class EmaOptimizer(BaseOptimizer): 242 | 243 | # # @insert_arguments(ema_momentum=0.999) 244 | # def __init__(self, model, *args, **kwargs): 245 | # super(EmaOptimizer, self).__init__(*args, **kwargs) 246 | # self.model = model 247 | # # EMA weights (the moving average of every layer's weights as of the current step) 248 | # self.ema_weights = {} 249 | # # the original model weights, saved during evaluation so they can be restored from the EMA weights afterwards 250 | # self.model_weights = {} 251 | 252 | # # initialize ema_weights with a copy of the current model weights 253 | # for name, param in self.model.named_parameters(): 254 | # if param.requires_grad: 255 | # self.ema_weights[name] = param.data.clone() 256 | # def step(self, closure: Callable = None): 257 | # """ 258 | # Perform a single optimization step. 259 | 260 | # Args: 261 | # closure (:obj:`Callable`, `optional`): 262 | # A closure that re-evaluates the model and returns the loss. 263 | # """ 264 | # loss = None 265 | # if closure is not None: 266 | # loss = closure() 267 | # loss = super(EmaOptimizer, self).step() 268 | # self.update() 269 | # return loss 270 | 271 | # def
update(self): 272 | # for name, param in self.model.named_parameters(): 273 | # if param.requires_grad: 274 | # assert name in self.ema_weights 275 | # new_average = (1.0 - self.decay) * param.data + self.decay * self.ema_weights[name] 276 | # self.ema_weights[name] = new_average.clone() 277 | 278 | # def apply_ema_weights(self): 279 | # for name, param in self.model.named_parameters(): 280 | # if param.requires_grad: 281 | # assert name in self.ema_weights 282 | # self.model_weights[name] = param.data 283 | # param.data = self.ema_weights[name] 284 | 285 | # def reset_old_weights(self): 286 | # for name, param in self.model.named_parameters(): 287 | # if param.requires_grad: 288 | # assert name in self.model_weights 289 | # param.data = self.model_weights[name] 290 | # self.model_weights = {} 291 | # return EmaOptimizer -------------------------------------------------------------------------------- /bert4pytorch/snippets.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # Miscellaneous helper code 3 | import logging, os 4 | import numpy as np 5 | import torch 6 | def truncate_sequences(maxlen, indices, *sequences): 7 | """Truncate the sequences in place until their total length is at most maxlen, each time popping the element at indices[i] from the currently longest sequence i 8 | """ 9 | sequences = [s for s in sequences if s] 10 | if not isinstance(indices, (list, tuple)): 11 | indices = [indices] * len(sequences) 12 | 13 | while True: 14 | lengths = [len(s) for s in sequences] 15 | if sum(lengths) > maxlen: 16 | i = np.argmax(lengths) 17 | sequences[i].pop(indices[i]) 18 | else: 19 | return sequences 20 | 21 | 22 | def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'): 23 | """NumPy function: pad (and, if needed, truncate) the sequences to the same length 24 | """ 25 | if length is None: 26 | length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0) 27 | elif not hasattr(length, '__getitem__'): 28 | length = [length] 29 | 30 | slices = [np.s_[:length[i]] for i in range(seq_dims)] 31 | slices = tuple(slices) if len(slices) > 1 else slices[0] 32 | pad_width = [(0, 0) for _ in np.shape(inputs[0])] 33 | 34 | outputs = [] 35 | for x in inputs: 36 | x = x[slices] 37 | for i in range(seq_dims): 38 | if mode == 'post': 39 | pad_width[i] = (0, length[i] - np.shape(x)[i]) 40 | elif mode == 'pre': 41 | pad_width[i] = (length[i] - np.shape(x)[i], 0) 42 | else: 43 | raise ValueError('"mode" argument must be "post" or "pre".') 44 | x = np.pad(x, pad_width, 'constant', constant_values=value) 45 | outputs.append(x) 46 | 47 | return np.array(outputs) 48 | 49 | 50 | def insert_arguments(**arguments): 51 | """ 52 | Decorator that adds arguments to a class method 53 | (mainly used for a class's __init__ method) 54 | """ 55 | def actual_decorator(func): 56 | def new_func(self, *args, **kwargs): 57 | for k, v in arguments.items(): 58 | if k in kwargs: 59 | v = kwargs.pop(k) 60 | setattr(self, k, v) 61 | return func(self, *args, **kwargs) 62 | 63 | return new_func 64 | 65 | return actual_decorator 66 | logger = logging.getLogger(__name__) 67 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 68 | """Load a TF checkpoint into a PyTorch model.""" 69 | # requires TensorFlow; please install it yourself 70 | try: 71 | import re 72 | 73 | import numpy as np 74
| import tensorflow as tf 75 | except ImportError: 76 | logger.error( 77 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 78 | "https://www.tensorflow.org/install/ for installation instructions." 79 | ) 80 | raise 81 | tf_path = os.path.abspath(tf_checkpoint_path) 82 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 83 | # Load weights from TF model 84 | init_vars = tf.train.list_variables(tf_path) 85 | names = [] 86 | arrays = [] 87 | for name, shape in init_vars: 88 | logger.info(f"Loading TF weight {name} with shape {shape}") 89 | array = tf.train.load_variable(tf_path, name) 90 | names.append(name) 91 | arrays.append(array) 92 | 93 | for name, array in zip(names, arrays): 94 | name = name.split("/") 95 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v 96 | # which are not required for using a pretrained model 97 | if any( 98 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 99 | for n in name 100 | ): 101 | logger.info(f"Skipping {'/'.join(name)}") 102 | continue 103 | pointer = model 104 | for m_name in name: 105 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 106 | scope_names = re.split(r"_(\d+)", m_name) 107 | else: 108 | scope_names = [m_name] 109 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 110 | pointer = getattr(pointer, "weight") 111 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 112 | pointer = getattr(pointer, "bias") 113 | elif scope_names[0] == "output_weights": 114 | pointer = getattr(pointer, "weight") 115 | elif scope_names[0] == "squad": 116 | pointer = getattr(pointer, "classifier") 117 | else: 118 | try: 119 | pointer = getattr(pointer, scope_names[0]) 120 | except AttributeError: 121 | logger.info(f"Skipping {'/'.join(name)}") 122 | continue 123 | if len(scope_names) >= 2: 124 | num = int(scope_names[1]) 125 | pointer = pointer[num] 126 | if m_name[-11:] ==
"_embeddings": 127 | pointer = getattr(pointer, "weight") 128 | elif m_name == "kernel": 129 | array = np.transpose(array) 130 | try: 131 | assert ( 132 | pointer.shape == array.shape 133 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 134 | except AssertionError as e: 135 | e.args += (pointer.shape, array.shape) 136 | raise 137 | logger.info(f"Initialize PyTorch weight {name}") 138 | pointer.data = torch.from_numpy(array) 139 | return model 140 | 141 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 142 | """Convert a TF BERT checkpoint into a PyTorch state dict.""" 143 | # initialize the PyTorch model (BertConfig / BertForPreTraining are not defined in this module and must be imported, e.g. from Hugging Face transformers) 144 | config = BertConfig.from_json_file(bert_config_file) 145 | print("Building PyTorch model from configuration: {}".format(str(config))) 146 | model = BertForPreTraining(config) 147 | 148 | # load the weights from the TF checkpoint 149 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 150 | 151 | # save the PyTorch model 152 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 153 | torch.save(model.state_dict(), pytorch_dump_path) 154 | -------------------------------------------------------------------------------- /bert4pytorch/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Tokenization classes.""" 4 | 5 | from __future__ import absolute_import, division, print_function, unicode_literals 6 | 7 | import collections 8 | import logging 9 | import os 10 | import unicodedata 11 | from io import open 12 | from bert4pytorch.snippets import truncate_sequences 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def load_vocab(vocab_file): 18 | """Load a vocabulary file into an ordered dict mapping token -> index.""" 19 | vocab = collections.OrderedDict() 20 | index = 0 21 | with open(vocab_file, "r", encoding="utf-8") as reader: 22 | while True: 23 | token = reader.readline() 24 | if not token: 25 | break 26 | token = token.strip() 27 | vocab[token] = index 28 | index += 1 29 | return vocab 30 | 31 | 32 | def
whitespace_tokenize(text): 33 | """Strip the text and split it on whitespace.""" 34 | text = text.strip() 35 | if not text: 36 | return [] 37 | tokens = text.split() 38 | return tokens 39 | 40 | 41 | class Tokenizer(object): 42 | 43 | def __init__( 44 | self, 45 | vocab_file, 46 | do_lower_case=True, 47 | do_basic_tokenize=True, 48 | unk_token="[UNK]", 49 | sep_token="[SEP]", 50 | pad_token="[PAD]", 51 | cls_token="[CLS]", 52 | mask_token="[MASK]"): 53 | """ 54 | 55 | Args: 56 | vocab_file: 57 | Path to the vocabulary file 58 | do_lower_case: 59 | Whether to lowercase the input 60 | do_basic_tokenize: 61 | Whether to run basic tokenization (cleanup, CJK and punctuation splitting) before WordPiece 62 | unk_token: 63 | Token for unknown words 64 | sep_token: 65 | Separator token. With a single sentence it only marks the end of the input; with two sentences it separates them and also ends the last one 66 | pad_token: 67 | Padding token 68 | cls_token: 69 | Classification token, placed at the start of the whole sequence 70 | mask_token: 71 | Mask token 72 | 73 | """ 74 | if not os.path.isfile(vocab_file): 75 | raise ValueError( 76 | "Can't find a vocabulary file at path '{}'.".format(vocab_file)) 77 | self.vocab = load_vocab(vocab_file) 78 | self.ids_to_tokens = collections.OrderedDict( 79 | [(ids, tok) for tok, ids in self.vocab.items()]) 80 | self.do_basic_tokenize = do_basic_tokenize 81 | if do_basic_tokenize: 82 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 83 | never_split=(unk_token, sep_token, pad_token, cls_token, mask_token)) 84 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=unk_token) 85 | self.unk_token = unk_token 86 | self.sep_token = sep_token 87 | self.pad_token = pad_token 88 | self.cls_token = cls_token 89 | self.mask_token = mask_token 90 | 91 | def tokenize(self, text): 92 | split_tokens = [] 93 | if self.do_basic_tokenize: 94 | for token in self.basic_tokenizer.tokenize(text): 95 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 96 | split_tokens.append(sub_token) 97 | else: 98 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 99 | if self.cls_token is not None: 100 | split_tokens.insert(0, self.cls_token) 101 | if self.sep_token is not None: 102 | split_tokens.append(self.sep_token) 103 | return
split_tokens 104 | 105 | def convert_tokens_to_ids(self, tokens): 106 | """Convert tokens to their ids in the vocabulary.""" 107 | ids = [] 108 | for token in tokens: 109 | ids.append(self.vocab[token]) 110 | return ids 111 | 112 | def convert_ids_to_tokens(self, ids): 113 | """Convert ids back to vocabulary tokens.""" 114 | tokens = [] 115 | for i in ids: 116 | tokens.append(self.ids_to_tokens[i]) 117 | return tokens 118 | 119 | def encode( 120 | self, 121 | first_text, 122 | second_text=None, 123 | max_len=None, 124 | truncate_from='right' 125 | ): 126 | """Return the token ids and segment ids for a text or a text pair (the leading [CLS] of the second text is dropped) 127 | """ 128 | if isinstance(first_text, str): 129 | first_tokens = self.tokenize(first_text) 130 | else: 131 | first_tokens = first_text 132 | 133 | if second_text is None: 134 | second_tokens = None 135 | elif isinstance(second_text, str): 136 | second_tokens = self.tokenize(second_text) 137 | else: 138 | second_tokens = second_text 139 | 140 | if max_len is not None: 141 | if truncate_from == 'right': 142 | index = -2 143 | elif truncate_from == 'left': 144 | index = 1 145 | else: 146 | index = truncate_from 147 | if second_text is not None: 148 | max_len += 1 149 | truncate_sequences(max_len, index, first_tokens, second_tokens) 150 | 151 | first_token_ids = self.convert_tokens_to_ids(first_tokens) 152 | first_segment_ids = [0] * len(first_token_ids) 153 | 154 | if second_text is not None: 155 | idx = int(bool(self.cls_token)) 156 | second_tokens = second_tokens[idx:] 157 | second_token_ids = self.convert_tokens_to_ids(second_tokens) 158 | second_segment_ids = [1] * len(second_token_ids) 159 | first_token_ids.extend(second_token_ids) 160 | first_segment_ids.extend(second_segment_ids) 161 | 162 | return first_token_ids, first_segment_ids 163 | 164 | 165 | 166 | class BasicTokenizer(object): 167 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 168 | 169 | def __init__(self, 170 | do_lower_case=True, 171 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 172 | """Constructs a BasicTokenizer.
173 | 174 | Args: 175 | do_lower_case: Whether to lower case the input (tokens in never_split are left unchanged). 176 | """ 177 | self.do_lower_case = do_lower_case 178 | self.never_split = never_split 179 | 180 | def tokenize(self, text): 181 | """Split text into tokens.""" 182 | text = self._clean_text(text) 183 | text = self._tokenize_chinese_chars(text) 184 | orig_tokens = whitespace_tokenize(text) 185 | split_tokens = [] 186 | for token in orig_tokens: 187 | if self.do_lower_case and token not in self.never_split: 188 | token = token.lower() 189 | token = self._run_strip_accents(token) 190 | split_tokens.extend(self._run_split_on_punc(token)) 191 | 192 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 193 | return output_tokens 194 | 195 | def _run_strip_accents(self, text): 196 | """Strips accents from a piece of text.""" 197 | text = unicodedata.normalize("NFD", text) 198 | output = [] 199 | for char in text: 200 | cat = unicodedata.category(char) 201 | if cat == "Mn": 202 | continue 203 | output.append(char) 204 | return "".join(output) 205 | 206 | def _run_split_on_punc(self, text): 207 | """Splits punctuation on a piece of text.""" 208 | if text in self.never_split: 209 | return [text] 210 | chars = list(text) 211 | i = 0 212 | start_new_word = True 213 | output = [] 214 | while i < len(chars): 215 | char = chars[i] 216 | if _is_punctuation(char): 217 | output.append([char]) 218 | start_new_word = True 219 | else: 220 | if start_new_word: 221 | output.append([]) 222 | start_new_word = False 223 | output[-1].append(char) 224 | i += 1 225 | 226 | return ["".join(x) for x in output] 227 | 228 | def _tokenize_chinese_chars(self, text): 229 | """Adds whitespace around any CJK character.""" 230 | output = [] 231 | for char in text: 232 | cp = ord(char) 233 | if self._is_chinese_char(cp): 234 | output.append(" ") 235 | output.append(char) 236 | output.append(" ") 237 | else: 238 | output.append(char) 239 | return "".join(output) 240 | 241 | def _is_chinese_char(self, cp): 242 | """Checks whether CP is
the codepoint of a CJK character.""" 243 | # This defines a "chinese character" as anything in the CJK Unicode block: 244 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 245 | # 246 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 247 | # despite its name. The modern Korean Hangul alphabet is a different block, 248 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 249 | # space-separated words, so they are not treated specially and handled 250 | # like all of the other languages. 251 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 252 | (cp >= 0x3400 and cp <= 0x4DBF) or # 253 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 254 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 255 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 256 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 257 | (cp >= 0xF900 and cp <= 0xFAFF) or # 258 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 259 | return True 260 | 261 | return False 262 | 263 | def _clean_text(self, text): 264 | """Performs invalid character removal and whitespace cleanup on text.""" 265 | output = [] 266 | for char in text: 267 | cp = ord(char) 268 | if cp == 0 or cp == 0xfffd or _is_control(char): 269 | continue 270 | if _is_whitespace(char): 271 | output.append(" ") 272 | else: 273 | output.append(char) 274 | return "".join(output) 275 | 276 | 277 | class WordpieceTokenizer(object): 278 | """Runs WordPiece tokenization.""" 279 | 280 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 281 | self.vocab = vocab 282 | self.unk_token = unk_token 283 | self.max_input_chars_per_word = max_input_chars_per_word 284 | 285 | def tokenize(self, text): 286 | """Tokenizes a piece of text into its word pieces. 287 | 288 | This uses a greedy longest-match-first algorithm to perform tokenization 289 | using the given vocabulary.
290 | 291 | For example: 292 | input = "unaffable" 293 | output = ["un", "##aff", "##able"] 294 | 295 | Args: 296 | text: A single token or whitespace separated tokens. This should have 297 | already been passed through `BasicTokenizer`. 298 | 299 | Returns: 300 | A list of wordpiece tokens. 301 | """ 302 | 303 | output_tokens = [] 304 | for token in whitespace_tokenize(text): 305 | chars = list(token) 306 | if len(chars) > self.max_input_chars_per_word: 307 | output_tokens.append(self.unk_token) 308 | continue 309 | 310 | is_bad = False 311 | start = 0 312 | sub_tokens = [] 313 | while start < len(chars): 314 | end = len(chars) 315 | cur_substr = None 316 | while start < end: 317 | substr = "".join(chars[start:end]) 318 | if start > 0: 319 | substr = "##" + substr 320 | if substr in self.vocab: 321 | cur_substr = substr 322 | break 323 | end -= 1 324 | if cur_substr is None: 325 | is_bad = True 326 | break 327 | sub_tokens.append(cur_substr) 328 | start = end 329 | 330 | if is_bad: 331 | output_tokens.append(self.unk_token) 332 | else: 333 | output_tokens.extend(sub_tokens) 334 | return output_tokens 335 | 336 | 337 | def _is_whitespace(char): 338 | """Checks whether `char` is a whitespace character.""" 339 | # \t, \n, and \r are technically control characters but we treat them 340 | # as whitespace since they are generally considered as such. 341 | if char == " " or char == "\t" or char == "\n" or char == "\r": 342 | return True 343 | cat = unicodedata.category(char) 344 | if cat == "Zs": 345 | return True 346 | return False 347 | 348 | 349 | def _is_control(char): 350 | """Checks whether `char` is a control character.""" 351 | # These are technically control characters but we count them as whitespace 352 | # characters.
353 | if char == "\t" or char == "\n" or char == "\r": 354 | return False 355 | cat = unicodedata.category(char) 356 | if cat.startswith("C"): 357 | return True 358 | return False 359 | 360 | 361 | def _is_punctuation(char): 362 | """Checks whether `char` is a punctuation character.""" 363 | cp = ord(char) 364 | # We treat all non-letter/number ASCII as punctuation. 365 | # Characters such as "^", "$", and "`" are not in the Unicode 366 | # Punctuation class but we treat them as punctuation anyway, for 367 | # consistency. 368 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 369 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 370 | return True 371 | cat = unicodedata.category(char) 372 | if cat.startswith("P"): 373 | return True 374 | return False 375 | 376 | def convert_to_unicode(text): 377 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 378 | if isinstance(text, str): 379 | return text 380 | elif isinstance(text, bytes): 381 | return text.decode("utf-8", "ignore") 382 | else: 383 | raise ValueError("Unsupported string type: %s" % (type(text))) -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | - [basic_language_model.py](./basic_language_model.py): basic test that checks masked language model (MLM) predictions 3 | - [tnews_classify_finetune.py](./tnews_classify_finetune.py): task example that fine-tunes BERT for classification on the CLUE TNEWS dataset. -------------------------------------------------------------------------------- /examples/basic_language_model.py: -------------------------------------------------------------------------------- 1 | #!
-*- coding: utf-8 -*- 2 | # Basic test: masked language model (MLM) prediction 3 | 4 | from bert4pytorch.modeling import build_transformer_model 5 | from bert4pytorch.tokenization import Tokenizer 6 | import torch 7 | 8 | # Load the model; replace with your own path 9 | root_model_path = "D:/vscodeworkspace/pythonCode/my_project/my_pytorch_bert/pytorch_bert_pretrain_model" 10 | vocab_path = root_model_path + "/vocab.txt" 11 | config_path = root_model_path + "/config.json" 12 | checkpoint_path = root_model_path + '/pytorch_model.bin' 13 | 14 | 15 | # Build the tokenizer 16 | tokenizer = Tokenizer(vocab_path) 17 | sentence = "北京[MASK]安门" 18 | 19 | 20 | tokens_ids, segments_ids = tokenizer.encode(sentence) 21 | mask_position = tokens_ids.index(tokenizer.vocab[tokenizer.mask_token]) 22 | 23 | tokens_ids_tensor = torch.tensor([tokens_ids]) 24 | segment_ids_tensor = torch.tensor([segments_ids]) 25 | 26 | # with_mlm=True is required to get the MLM head outputs 27 | model = build_transformer_model(config_path, checkpoint_path, with_mlm=True) 28 | model.eval() 29 | output = model(tokens_ids_tensor, segment_ids_tensor) 30 | 31 | result = torch.argmax(output[0, mask_position]).item() 32 | 33 | print(tokenizer.convert_ids_to_tokens([result])[0]) 34 | -------------------------------------------------------------------------------- /examples/tnews_classify_finetune.py: -------------------------------------------------------------------------------- 1 | # import sys 2 | # sys.path.insert(0, 'D:/vscodeworkspace/pythonCode/git_workspace/') 3 | # sys.path.insert(0, 'D:/vscodeworkspace/pythonCode/git_workspace/bert4pytorch/') 4 | # sys.path.insert(0, 'D:/vscodeworkspace/pythonCode/git_workspace/bert4pytorch/bert4pytorch/') 5 | # sys.path.insert(0, 'D:/vscodeworkspace/pythonCode/git_workspace/bert4pytorch/examples/') 6 | # print(sys.path) 7 | 8 | #!
-*- coding: utf-8 -*- 9 | # Basic test: BERT classification (15 classes) 10 | # Dataset: CLUE benchmark (TNEWS) 11 | # Download: https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip 12 | 13 | from torch.utils.data import Dataset, DataLoader 14 | from bert4pytorch.modeling import build_transformer_model 15 | from bert4pytorch.tokenization import Tokenizer 16 | from bert4pytorch.optimization import AdamW, get_linear_schedule_with_warmup 17 | import torch 18 | import torch.nn as nn 19 | import json 20 | import time 21 | 22 | SEED = 100 23 | torch.manual_seed(SEED) 24 | # torch.cuda.manual_seed_all(SEED) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | learning_rate = 2e-5 29 | epochs = 50 30 | max_len = 32 31 | batch_size = 16 32 | 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | 35 | 36 | # Load the model; replace with your own path 37 | root_model_path = "D:/vscodeworkspace/pythonCode/my_project/my_pytorch_bert/pytorch_bert_pretrain_model" 38 | vocab_path = root_model_path + "/vocab.txt" 39 | config_path = root_model_path + "/config.json" 40 | checkpoint_path = root_model_path + '/pytorch_model.bin' 41 | 42 | 43 | def load_data(filename): 44 | """Load the data 45 | Each sample: (text, label id) 46 | """ 47 | texts = [] 48 | labels = [] 49 | with open(filename, encoding='utf8') as f: 50 | for i, l in enumerate(f): 51 | l = json.loads(l) 52 | text = l['sentence'] 53 | label = l['label'] 54 | texts.append(text) 55 | label_int = int(label) 56 | # map the dataset labels to the contiguous range [0, num_classes - 1] expected by PyTorch 57 | if label_int <= 104: 58 | labels.append(label_int - 100) 59 | elif 104 < label_int <= 110: 60 | labels.append(label_int - 101) 61 | else: 62 | labels.append(label_int - 102) 63 | return texts, labels 64 | 65 | 66 | # Load the datasets; replace with your own paths 67 | X_train, y_train = load_data('D:/vscodeworkspace/pythonCode/git_workspace/clue_dataset/tnews_public/train.json') 68 | X_test, y_test = load_data('D:/vscodeworkspace/pythonCode/git_workspace/clue_dataset/tnews_public/dev.json') 69 | 70 | 71 | # Build the tokenizer
72 | tokenizer = Tokenizer(vocab_path) 73 | 74 | 75 | class MyDataset(Dataset): 76 | def __init__(self, X, y): 77 | self.X = X 78 | self.y = y 79 | 80 | def __len__(self): 81 | return len(self.y) 82 | 83 | def __getitem__(self, index): 84 | sentence = self.X[index] 85 | label = self.y[index] 86 | tokens_ids, segments_ids = tokenizer.encode(sentence, max_len=max_len) 87 | tokens_ids = tokens_ids + (max_len - len(tokens_ids)) * [0] 88 | segments_ids = segments_ids + (max_len - len(segments_ids)) * [0] 89 | tokens_ids_tensor = torch.tensor(tokens_ids) 90 | segment_ids_tensor = torch.tensor(segments_ids) 91 | return tokens_ids_tensor, segment_ids_tensor, label 92 | 93 | 94 | class Model(nn.Module): 95 | 96 | def __init__(self, config, checkpoint): 97 | super(Model, self).__init__() 98 | self.model = build_transformer_model(config, checkpoint, with_pool=True) 99 | '''Train all layers''' 100 | for param in self.model.parameters(): 101 | param.requires_grad = True 102 | self.dropout = nn.Dropout(p=0.5) 103 | self.fc = nn.Linear(768, 15) 104 | 105 | def forward(self, token_ids, segment_ids): 106 | encoded_layers, pooled_output = self.model(token_ids, segment_ids) 107 | # take the first ([CLS]) position of the last encoder layer 108 | cls_rep = self.dropout(encoded_layers[:, 0]) 109 | out = self.fc(cls_rep) 110 | return out 111 | 112 | # Build the datasets 113 | train_dataset = MyDataset(X_train, y_train) 114 | test_dataset = MyDataset(X_test, y_test) 115 | # Build the dataloaders 116 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 117 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 118 | 119 | # Instantiate the model 120 | model = Model(config_path, checkpoint_path).to(device) 121 | # Define the loss function 122 | critertion = nn.CrossEntropyLoss() 123 | # Weight decay: apply it to all parameters except biases and LayerNorm weights 124 | param_optimizer = list(model.named_parameters()) 125 | no_decay = ['bias', 'layerNorm'] 126 | optimizer_grouped_parameters = [ 127 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01}, 128 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 129 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate) 130 | # Use learning-rate warmup over the first 5% of the training steps 131 | num_training_steps = (len(train_dataloader) + 1) * epochs 132 | num_warmup_steps = num_training_steps * 0.05 133 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps) 134 | 135 | total_step = len(train_dataloader) 136 | loss_list = [] 137 | test_acc_list = [] 138 | 139 | 140 | best_acc = 0.0 141 | model.train() 142 | for epoch in range(epochs): 143 | start = time.time() 144 | for i, (token_ids, segment_ids, labels) in enumerate(train_dataloader): 145 | optimizer.zero_grad() 146 | token_ids = token_ids.to(device) 147 | segment_ids = segment_ids.to(device) 148 | labels = labels.to(device) 149 | outputs = model(token_ids, segment_ids) 150 | loss = critertion(outputs, labels) 151 | loss.backward() 152 | optimizer.step() 153 | scheduler.step() 154 | loss_list.append(loss.item()) 155 | 156 | if (i % 100) == 0: 157 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, elapsed: {:.4f}s' 158 | .format(epoch + 1, epochs, i + 1, total_step, loss.item(), time.time() - start)) 159 | start = time.time() 160 | model.eval() 161 | with torch.no_grad(): 162 | correct = 0 163 | total = 0 164 | for i, (token_ids, segment_ids, labels) in enumerate(test_dataloader): 165 | token_ids = token_ids.to(device) 166 | segment_ids = segment_ids.to(device) 167 | labels = labels.to(device) 168 | outputs = model(token_ids, segment_ids) 169 | _, predicted = torch.max(outputs.data, 1) 170 | total += labels.size(0) 171 | correct += (predicted == labels).sum().item() 172 | 173 | test_acc = correct / total 174 | test_acc_list.append(test_acc) 175 | 176 | print('Epoch [{}/{}], test_acc: {:.6f}' 177 | .format(epoch + 1, epochs, test_acc)) 178 | model.train() 179 | -------------------------------------------------------------------------------- 
/setup.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='bert4pytorch', 7 | version='0.1.3', 8 | description='an elegant bert4pytorch', 9 | long_description='bert4pytorch: ', 10 | license='Apache License 2.0', 11 | url='https://github.com/MuQiuJun-AI/bert4pytorch', 12 | author='MuQiuJun', 13 | install_requires=['torch>1.0', 'numpy>=1.17'], 14 | packages=find_packages() 15 | ) --------------------------------------------------------------------------------
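`AdamW.step` in `bert4pytorch/optimization.py` does three things in order: it computes the bias-corrected step size, applies the Adam update, and then applies decoupled weight decay. That order can be checked on plain numbers. The sketch below assumes lists of floats instead of tensors, and `adamw_step` is a hypothetical helper, not part of bert4pytorch.

```python
import math


def adamw_step(params, grads, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
               weight_decay=0.0, correct_bias=True):
    """Hypothetical scalar AdamW step following the same order as AdamW.step.

    params, grads: lists of floats; state: dict updated in place across calls.
    Returns the updated parameter list.
    """
    beta1, beta2 = betas
    if not state:  # lazy initialization, like `if len(state) == 0`
        state["step"] = 0
        state["exp_avg"] = [0.0] * len(params)     # first moment m
        state["exp_avg_sq"] = [0.0] * len(params)  # second moment v
    state["step"] += 1

    step_size = lr
    if correct_bias:
        bias_correction1 = 1.0 - beta1 ** state["step"]
        bias_correction2 = 1.0 - beta2 ** state["step"]
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

    updated = []
    for i, (p, g) in enumerate(zip(params, grads)):
        m = beta1 * state["exp_avg"][i] + (1.0 - beta1) * g
        v = beta2 * state["exp_avg_sq"][i] + (1.0 - beta2) * g * g
        state["exp_avg"][i], state["exp_avg_sq"][i] = m, v
        p = p - step_size * m / (math.sqrt(v) + eps)  # denom = sqrt(v) + eps
        # decoupled weight decay: shrink the weight itself, after the Adam update
        if weight_decay > 0.0:
            p = p - lr * weight_decay * p
        updated.append(p)
    return updated


# With bias correction the first step moves the weight by roughly lr.
print(adamw_step([1.0], [0.5], {}, lr=0.1))
# Without it (the original TF BERT setting) the first step is noticeably larger.
print(adamw_step([1.0], [0.5], {}, lr=0.1, correct_bias=False))
```

With `correct_bias=False`, the update at step t is scaled by (1 - beta1^t) / sqrt(1 - beta2^t) relative to the corrected one. With the default betas that is about 3.16x at t = 1, and the factor approaches 1 as training proceeds.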
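`FGM.attack` shifts every parameter whose name contains `emb_name` by r_adv = epsilon * g / ||g||_2. Before doing so it stores a backup, and `FGM.restore` writes that backup back. The bookkeeping can be reproduced without PyTorch. The sketch below assumes parameters and gradients are dicts mapping names to lists of floats, and `ToyFGM` is a hypothetical stand-in for the real class.

```python
import math


class ToyFGM:
    """Hypothetical pure-Python stand-in for FGM (dicts of float lists instead of a Module)."""

    def __init__(self, params, grads):
        self.params = params  # name -> list of floats (mutated in place)
        self.grads = grads    # name -> list of floats
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name="word_embeddings"):
        for name, values in self.params.items():
            if emb_name in name:
                self.backup[name] = list(values)
                g = self.grads[name]
                norm = math.sqrt(sum(x * x for x in g))  # L2 norm of the gradient
                if norm != 0 and not math.isnan(norm):
                    # move along the normalized gradient direction
                    self.params[name] = [v + epsilon * x / norm for v, x in zip(values, g)]

    def restore(self, emb_name="word_embeddings"):
        for name in self.params:
            if emb_name in name:
                assert name in self.backup
                self.params[name] = self.backup[name]
        self.backup = {}


params = {"bert.embeddings.word_embeddings.weight": [0.0, 0.0], "fc.weight": [1.0]}
grads = {"bert.embeddings.word_embeddings.weight": [3.0, 4.0], "fc.weight": [9.0]}
fgm = ToyFGM(params, grads)
fgm.attack(epsilon=0.5)
print(params)  # only the embedding entry is perturbed
fgm.restore()
print(params)  # original values are back
```

The perturbation always has length epsilon, whatever the gradient's magnitude. That is why `epsilon` is the only knob in `attack`.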
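`ExponentialMovingAverage` keeps a shadow copy of each trainable weight, updated as shadow = decay * shadow + (1 - decay) * weight. It swaps the shadow weights in for evaluation and swaps the live weights back afterwards. The sketch below reproduces that cycle on a dict of floats; `ToyEMA` is hypothetical and only illustrates the arithmetic, not the library class.

```python
class ToyEMA:
    """Hypothetical EMA over a dict name -> float, mirroring update/apply/reset."""

    def __init__(self, params, decay):
        self.params = params          # live weights, mutated by the "optimizer"
        self.decay = decay
        self.shadow = dict(params)    # EMA weights start as a copy of the live weights
        self.backup = {}

    def update(self):
        for name, value in self.params.items():
            self.shadow[name] = (1.0 - self.decay) * value + self.decay * self.shadow[name]

    def apply_ema_weights(self):
        for name in self.params:
            self.backup[name] = self.params[name]
            self.params[name] = self.shadow[name]

    def reset_old_weights(self):
        for name in self.params:
            self.params[name] = self.backup[name]
        self.backup = {}


weights = {"w": 0.0}
ema = ToyEMA(weights, decay=0.5)
for _ in range(2):
    weights["w"] = 1.0   # pretend an optimizer step set the weight to 1.0
    ema.update()
ema.apply_ema_weights()
print(weights["w"])      # 0.75 (evaluate with the averaged weight)
ema.reset_old_weights()
print(weights["w"])      # 1.0 (training resumes from the live weight)
```

The EMA value lags behind the live weight, so a small `decay` (0.5 here) is only for illustration. The docstring example uses 0.999.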
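`truncate_sequences` in `snippets.py` repeatedly pops from whichever sequence is currently longest. `Tokenizer.encode` calls it with index `-2` for `truncate_from='right'`, which removes the token just before `[SEP]`. For a text pair it also passes `max_len + 1`, because the second sequence's `[CLS]` is dropped afterwards. The copy below swaps `np.argmax` for `list.index(max(...))`. Both pick the first longest sequence on ties, so it runs without NumPy.

```python
def truncate_sequences(maxlen, indices, *sequences):
    """Stdlib copy of the snippets.py helper (np.argmax replaced by list.index(max(...)))."""
    sequences = [s for s in sequences if s]  # drops None, e.g. a missing second text
    if not isinstance(indices, (list, tuple)):
        indices = [indices] * len(sequences)
    while True:
        lengths = [len(s) for s in sequences]
        if sum(lengths) > maxlen:
            i = lengths.index(max(lengths))   # first longest sequence
            sequences[i].pop(indices[i])      # mutates the caller's list in place
        else:
            return sequences


first = ["[CLS]", "a", "b", "c", "d", "[SEP]"]
second = ["[CLS]", "x", "y", "[SEP]"]
max_len = 7
# encode() would pass max_len + 1 for a pair, then drop second[0] ([CLS])
truncate_sequences(max_len + 1, -2, first, second)
print(first)   # ['[CLS]', 'a', 'b', '[SEP]']
print(second)  # ['[CLS]', 'x', 'y', '[SEP]']
```

After `encode` drops the second `[CLS]`, the pair occupies 4 + 3 = 7 positions, exactly `max_len`.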
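For the common 1-D case (`seq_dims=1`), `sequence_padding` does two things to each sequence: it truncates it to `length` and then pads it with `value`, at the end (`'post'`) or the front (`'pre'`). `MyDataset.__getitem__` in the TNEWS example does the same `'post'` padding by hand with `value=0` and `length=max_len`. The stdlib sketch below is a hypothetical simplification (`sequence_padding_1d` returns nested lists rather than a NumPy array).

```python
def sequence_padding_1d(inputs, length=None, value=0, mode="post"):
    """Hypothetical stdlib version of sequence_padding for seq_dims=1 (returns lists)."""
    if length is None:
        length = max(len(x) for x in inputs)
    outputs = []
    for x in inputs:
        x = list(x[:length])                 # truncate first, like x[slices]
        pad = [value] * (length - len(x))
        if mode == "post":
            outputs.append(x + pad)
        elif mode == "pre":
            outputs.append(pad + x)
        else:
            raise ValueError('"mode" argument must be "post" or "pre".')
    return outputs


batch = [[1, 2, 3], [4]]
print(sequence_padding_1d(batch))              # [[1, 2, 3], [4, 0, 0]]
print(sequence_padding_1d(batch, mode="pre"))  # [[1, 2, 3], [0, 0, 4]]
print(sequence_padding_1d(batch, length=2))    # [[1, 2], [4, 0]]
```

Padding with 0 works for token ids because `[PAD]` is conventionally id 0 in BERT vocabularies. It also works for segment ids, where 0 is the first-segment id.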
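`WordpieceTokenizer.tokenize` is greedy longest-match-first. From the current start position it tries the longest remaining substring and shrinks it until it finds a vocabulary entry; non-initial pieces carry a `##` prefix. If any position has no match, the whole word becomes the unknown token. The single-word sketch below follows the same loop. The toy vocabulary is made up for illustration.

```python
def wordpiece(token, vocab, unk_token="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece for one whitespace-free token."""
    chars = list(token)
    if len(chars) > max_chars:
        return [unk_token]
    start, sub_tokens = 0, []
    while start < len(chars):
        end = len(chars)
        cur = None
        while start < end:                  # shrink the candidate from the right
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr      # continuation pieces are prefixed
            if substr in vocab:
                cur = substr
                break
            end -= 1
        if cur is None:                     # no piece matches: whole word is unknown
            return [unk_token]
        sub_tokens.append(cur)
        start = end
    return sub_tokens


toy_vocab = {"un", "una", "##aff", "##able", "##ffable"}
print(wordpiece("unaffable", {"un", "##aff", "##able"}))  # ['un', '##aff', '##able']
print(wordpiece("unaffable", toy_vocab))                  # ['una', '##ffable'] (longest first piece wins)
print(wordpiece("xyz", toy_vocab))                        # ['[UNK]']
```

The second call shows the greedy behavior: "una" is preferred over "un" because it is longer, even though "un" + "##aff" + "##able" would also cover the word.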
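`load_data` in the TNEWS example shifts the raw labels piecewise. The classifier head has 15 outputs (`nn.Linear(768, 15)`), and the shifts only yield a contiguous 0-14 range if the raw labels are 100-116 with 105 and 111 absent. That matches the 15 TNEWS categories. The sketch below isolates the mapping as a hypothetical helper and checks that it is a bijection onto `range(15)` under that assumption.

```python
def tnews_label_to_index(label):
    """Same piecewise shift as load_data; assumes raw labels 100-116 without 105 and 111."""
    label = int(label)
    if label <= 104:
        return label - 100   # 100..104 -> 0..4
    elif label <= 110:
        return label - 101   # 106..110 -> 5..9
    else:
        return label - 102   # 112..116 -> 10..14


raw_labels = [l for l in range(100, 117) if l not in (105, 111)]
print([tnews_label_to_index(l) for l in raw_labels])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
```

If a dataset variant did contain 105 or 111, the mapping would collide with a neighbouring class. A lookup table built from the sorted set of observed labels would be a more defensive choice.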
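The TNEWS example wraps `AdamW` with `get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)`, and sets the warmup to 5% of the steps. That function is defined elsewhere in `optimization.py` and is not shown in this section. The sketch below assumes it follows the Hugging Face Transformers formula: the learning-rate multiplier rises linearly from 0 to 1 during warmup, then decays linearly to 0 at `num_training_steps`. `linear_warmup_multiplier` is a hypothetical name.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Assumed lr multiplier: linear warmup to 1.0, then linear decay to 0.0."""
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - step) /
               float(max(1, num_training_steps - num_warmup_steps)))


# 100 total steps with 10 warmup steps: peak at step 10, half-way down at step 55
for step in (0, 5, 10, 55, 100):
    print(step, linear_warmup_multiplier(step, 10, 100))
```

A LambdaLR-style scheduler multiplies the base `lr` (2e-5 in the example) by this value after each `scheduler.step()`. The fractional `num_warmup_steps` produced by `num_training_steps * 0.05` works here because the formula only divides by it.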