├── README.md
├── convert_bert_torch_to_tf.py
├── convert_ernie_paddle_to_tf_bert.py
├── run_classifier.py
├── run_dureader2021.py
├── run_element_extract.py
├── run_ner.py
├── run_pretraining.py
├── run_semeval2010_re.py
├── simple_trainer.py
└── tfbert
    ├── __init__.py
    ├── adversarial.py
    ├── config
    │   ├── __init__.py
    │   ├── base.py
    │   └── ptm.py
    ├── data
    │   ├── __init__.py
    │   ├── classification.py
    │   ├── dataset.py
    │   ├── mrc.py
    │   ├── ner.py
    │   └── pretrain.py
    ├── metric
    │   ├── __init__.py
    │   ├── dureader2021.py
    │   ├── multi_label.py
    │   └── ner.py
    ├── models
    │   ├── __init__.py
    │   ├── activations.py
    │   ├── albert.py
    │   ├── base.py
    │   ├── bert.py
    │   ├── crf.py
    │   ├── electra.py
    │   ├── embeddings.py
    │   ├── for_task.py
    │   ├── glyce_bert.py
    │   ├── layers.py
    │   ├── loss.py
    │   ├── model_utils.py
    │   └── nezha.py
    ├── optimization
    │   ├── __init__.py
    │   ├── adamw.py
    │   ├── create_optimizer.py
    │   ├── lamb.py
    │   └── schedule.py
    ├── serving.py
    ├── tokenizer
    │   ├── __init__.py
    │   ├── albert.py
    │   ├── bert.py
    │   ├── glyce_bert.py
    │   ├── tokenization_base.py
    │   └── wobert.py
    ├── trainer.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # tfbert
2 | - A toolkit for BERT-family pretrained models based on TensorFlow 1.x.
3 | - Supports multi-GPU training, gradient accumulation, and pb model export; Adam parameters are automatically stripped from saved weights.
4 | - Combines tf.data datasets with string handles for flexible training, evaluation, and testing; during training you can evaluate on the dev set and save checkpoints according to the validation results.
5 |
6 |
7 | ## About
8 |
9 |
10 | The config and tokenizer implementations follow transformers.
11 |
12 | A custom Trainer is built in, so TensorFlow 1.14 can be used in a PyTorch-like way; usage is described below.
13 |
14 | Built-in examples currently cover [text classification](run_classifier.py), [multi-label text classification](run_element_extract.py), and [named entity recognition](run_ner.py).
15 |
16 | The data-processing code of the built-in examples supports multiprocessing, following the approach used in transformers.
17 |
18 | Datasets for the built-in examples: [Baidu Netdisk, extraction code rhxk](https://pan.baidu.com/s/1lYy7BJdadT0LJfMSsKz6AA)
19 | ## Supported models
20 |
21 | bert, electra, albert, nezha, wobert, ChineseBert (GlyceBert)
22 |
23 | ## requirements
24 | ```
25 | tensorflow==1.x
26 | tqdm
27 | jieba
28 | ```
29 | The project is implemented and tested under TensorFlow 1.x; version 1.14 or later is recommended, because internally TensorFlow is always imported as
30 |
31 | import tensorflow.compat.v1 as tf
32 |
33 | ## **Usage**
34 | #### **Config and Tokenizer**
35 | Used in the same way as in transformers:
36 | ```python
37 | from tfbert import BertTokenizer, BertConfig
38 |
39 | config = BertConfig.from_pretrained('config_path')
40 | tokenizer = BertTokenizer.from_pretrained('vocab_path', do_lower_case=True)
41 |
42 | inputs = tokenizer.encode_plus(
43 | '测试样例', text_pair=None, max_length=128, padding="max_length", add_special_tokens=True)
44 |
45 | config.save_pretrained("save_path")
46 | tokenizer.save_pretrained("save_path")
47 |
48 | ```
49 | For multi-GPU runs, set the CUDA_VISIBLE_DEVICES environment variable; the built-in trainer reads it:
50 | ```
51 | CUDA_VISIBLE_DEVICES=1,2 python run.py
52 | ```
53 | See the example scripts for details; a condensed Trainer sketch follows below.
54 |
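The Trainer workflow, condensed from [run_classifier.py](run_classifier.py) into a minimal sketch. Paths such as `model_path`/`output` are placeholders, and `train_dataset`/`dev_dataset` are tfbert `Dataset` objects built as in that script:

```python
import numpy as np
from sklearn.metrics import accuracy_score
from tfbert import Trainer, SequenceClassification, CONFIGS, TOKENIZERS

config = CONFIGS['bert'].from_pretrained('model_path')
tokenizer = TOKENIZERS['bert'].from_pretrained('model_path', do_lower_case=True)
num_classes = 5  # number of labels in the example classification dataset


def model_fn(inputs, is_training):
    model = SequenceClassification(
        model_type='bert', config=config, num_classes=num_classes,
        is_training=is_training, **inputs)
    outputs = {'outputs': {'logits': model.logits, 'label_ids': inputs['label_ids']}}
    if model.loss is not None:
        outputs['loss'] = model.loss
    return outputs


def metric_fn(outputs):
    predictions = np.argmax(outputs['logits'], -1)
    return {'accuracy': accuracy_score(outputs['label_ids'], predictions)}


output_types, output_shapes = train_dataset.output_types_and_shapes()
trainer = Trainer(
    train_dataset=train_dataset, eval_dataset=dev_dataset,
    output_types=output_types, output_shapes=output_shapes,
    metric_fn=metric_fn, optimizer_type='adamw', learning_rate=2e-5,
    num_train_epochs=3, gradient_accumulation_steps=1, max_checkpoints=1,
    max_grad=1.0, warmup_proportion=0.1, use_xla=False, mixed_precision=False,
    single_device=False, logging=True)
trainer.build_model(model_fn=model_fn)
trainer.compile()                      # compile the optimizer before initializing weights
trainer.from_pretrained('model_path')  # load the pretrained checkpoint
trainer.train(
    output_dir='output', evaluate_during_training=True,
    logging_steps=1000, saving_steps=1000,
    greater_is_better=True, load_best_model=True, metric_for_best_model='accuracy')
print(trainer.evaluate())
```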
55 | ## **XLA and mixed-precision training speed test**
56 |
57 | The comparison uses HIT's rbt3 weights on the text classification dataset from the examples.
58 | With XLA and mixed precision enabled, training first needs some time for graph optimization, so the first epoch is relatively slow;
59 | once warmed up, training becomes much faster. Max input length 32, batch size 32, 3 training epochs;
60 | test environment: TensorFlow 1.14 on a 2080 Ti GPU.
61 |
62 | | use_xla | mixed_precision | first epoch (s/epoch) | second epoch (s/epoch) | eval accuracy |
63 | | :------: | :------: | :------: | :------: | :------: |
64 | | False | False | 76 | 61 | 0.9570 |
65 | | True | False | 73 | 42 | 0.9584 |
66 | | True | True | 85 | 37 | 0.9582 |
67 |
68 | Enabling mixed precision has a slow start (about one to two minutes for a base-size model), but training keeps getting faster afterwards. For short runs, enabling XLA alone is enough; for longer runs,
69 | it is best to enable both XLA and mixed precision (provided your GPU supports fp16).
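In the example scripts both options are exposed as command-line switches (flag names as used in [run_classifier.py](run_classifier.py); other arguments omitted here):
```
python run_classifier.py --do_train --do_eval --use_xla --mixed_precision
```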
70 |
71 | ## Loadable Chinese pretrained weights
72 | | Model | Download link |
73 | | :------- | :--------- |
74 | | **`BERT-wwm series`** | **[Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm)**|
75 | | **`BERT-base, Chinese` (Google)** | [Google Cloud](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) |
76 | | **`ALBERT-base, Chinese` (Google)** | [google-research/albert](https://github.com/google-research/albert) |
77 | | **`MacBERT, Chinese`** | **[MacBERT](https://github.com/ymcui/MacBERT)**|
78 | | **`ELECTRA, Chinese`** | **[Chinese-ELECTRA](https://github.com/ymcui/Chinese-ELECTRA)**|
79 | | **`ERNIE 1.0.1, Chinese`** | **[Baidu Netdisk (xrku)](https://pan.baidu.com/s/13eRD6uVnr4xeUfYXk8XKIw)**|
80 | | **`ERNIE-Gram base, Chinese`** | **[Baidu Netdisk (7xet)](https://pan.baidu.com/s/1qzIuduI2ZRJDZSnNqTfscw)**|
81 | | **`ChineseBert, Chinese`** | **[base(sxhj)](https://pan.baidu.com/s/1ehO52PQd6TFVhOu5RiRtZA)** **[large(zi0r)](https://pan.baidu.com/s/1IifQuRFhpwWzLJHvMR9gOQ)**|
82 |
83 |
84 | ## **Changelog**
85 | - 2021/7/31 Added ShannonAI's open-source ChineseBert to the built-in models, see [glyce_bert](tfbert/models/glyce_bert.py); officially only a torch version is available.
86 | The model adds glyph and pinyin features to the embedding representation and achieves results close to MacBERT; see the official [ChineseBert](https://github.com/ShannonAI/ChineseBert) repo.
87 | The tf weights have already been converted and can be downloaded from the links above.
88 |
89 | - 2021/5/19 Added a machine reading comprehension example based on the DuReader 2021 competition data; it should be compatible with most SQuAD-format datasets.
90 | Also updated the tokenizer code to follow the transformers interface; most of it is integrated directly from the transformers tokenizers.
91 |
92 | - 2021/5/9 Added FGM, PGD, and FreeLB interfaces, see [adversarial.py](tfbert/adversarial.py).
93 | Usage: pass adversarial_type to the trainer's build_model. No GPU or suitable dataset was at hand at the time, so the feature is still untested.
94 |
95 | - 2021/4/18 Spent a day reworking the Trainer and added a new Dataset class. The update is large and not yet fully commented, apologies. Details:
96 |   1. The trainer now wraps train, evaluate, and predict methods; see the new usage examples.
97 |   2. Added a Dataset class that supports simple data wrapping and can also export a tf dataset directly; see [dataset.py](tfbert/data/dataset.py).
98 |   3. Removed the old requirement to define shapes and types by hand (the old data code has not been deleted yet); both can now be obtained via the new Dataset class's methods.
99 |
100 |
101 | - 2021/4/17 Added SimplerTrainer, which is driven by feed_dict: simple to use, but much slower than the Trainer's dataset approach.
102 | A quick example is in [simple_trainer.py](simple_trainer.py), to be polished when time allows.
103 | - tf.layers.dropout only decides the mode from tf.keras.backend.learning_phase() when training is set to None.
104 | Previously training defaulted to False, so dropout never took effect; sincere apologies.
105 | - Added a resize_word_embeddings method that resizes the vocabulary of the embedding part of an already saved checkpoint.
106 | See [resize_word_embeddings](tfbert/utils.py).
107 | - Adversarial training is temporarily unusable... the implementation was incorrect.
108 | - 2021/2/22 Added the FGM adversarial training mode: set use_fgm to True in trainer.build_model()
109 | to enable it. Its effect has not been tested yet.
110 |
111 | - 2021/2/8 Finished the thesis and took some time for a major update: the existing code was reorganized to further improve training speed. Gradient accumulation now follows NVIDIA's approach; the backward and
112 | zero_grad methods were removed and training is done through train_step only. To use gradient accumulation, just pass the number of accumulation steps when configuring the optimization op.
113 | Finally, the code adds XLA acceleration and mixed-precision training; mixed precision is only supported on some GPUs, check whether yours supports it.
114 | See and compare the usage examples for details.
115 |
116 | - 2020/11/14 Added an XLA acceleration module, toggled via the use_xla argument of the trainer; enabling it speeds up training. The backward/zero_grad/train_step mode can now be switched on and off:
117 | set use_torch_mode in the trainer to disable it; with it disabled, gradient accumulation is not supported and train_step is called directly,
118 | which speeds up training.
119 |
120 | - 2020/9/23 Added gradient accumulation, training via trainer.backward(), trainer.zero_grad(), and trainer.train_step() together, following the PyTorch training style.
121 | - 2020/9/21 First upload; supports bert, albert, electra, nezha, and wobert.
122 |
123 | **Reference**
124 | 1. [Transformers: State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch. ](https://github.com/huggingface/transformers)
125 | 2. [TensorFlow code and pre-trained models for BERT](https://github.com/google-research/bert)
126 | 3. [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://github.com/google-research/albert)
127 | 4. [NEZHA-TensorFlow](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-TensorFlow)
128 | 5. [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://github.com/google-research/electra)
129 | 6. [WoBERT: word-level Chinese BERT](https://github.com/ZhuiyiTechnology/WoBERT)
130 | 7. [NVIDIA BERT training recipes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT)
131 |
--------------------------------------------------------------------------------
/convert_bert_torch_to_tf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # author : 'huanghui'
3 | # date : '2021/5/29 8:23'
4 | # project: 'tfbert'
5 | import os
6 | import argparse
7 | import tensorflow.compat.v1 as tf
8 | import numpy as np
9 | import shutil
10 |
11 | import torch
12 |
13 |
14 | def convert_pytorch_checkpoint_to_tf(pt_weight_file, pt_config_file, pt_vocab_file, save_dir: str):
15 | tensors_to_transpose = (
16 | "dense.weight", "attention.self.query", "attention.self.key", "attention.self.value", "glyph_map.weight",
17 | "map_fc.weight")
18 | glyce_bert_conv_tensors = ("conv.weight",)
19 |
20 | var_map = (
21 | ("layer.", "layer_"),
22 | ("word_embeddings.weight", "word_embeddings"),
23 | ("position_embeddings.weight", "position_embeddings"),
24 | ("token_type_embeddings.weight", "token_type_embeddings"),
25 | ("pinyin_embeddings.embedding.weight", "pinyin_embeddings/embeddings"),
26 | ("glyph_embeddings.embedding.weight", "glyph_embeddings/embeddings"),
27 | (".", "/"),
28 | ("LayerNorm/weight", "LayerNorm/gamma"),
29 | ("LayerNorm/bias", "LayerNorm/beta"),
30 | ("weight", "kernel"),
31 | )
32 |
33 | if not os.path.isdir(save_dir):
34 | os.makedirs(save_dir)
35 |
36 | state_dict = torch.load(pt_weight_file, map_location='cpu')
37 |
38 | def to_tf_var_name(name: str):
39 | for patt, repl in iter(var_map):
40 | name = name.replace(patt, repl)
41 | return f"{name}"
42 |
43 | def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
44 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
45 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
46 | session.run(tf.variables_initializer([tf_var]))
47 | session.run(tf_var)
48 | return tf_var
49 |
50 | tf.reset_default_graph()
51 | with tf.Session() as session:
52 | for var_name in state_dict:
53 | tf_name = to_tf_var_name(var_name)
54 | torch_tensor = state_dict[var_name].numpy()
55 | if any([x in var_name for x in tensors_to_transpose]):
56 | torch_tensor = torch_tensor.T
57 | if any([x in var_name for x in glyce_bert_conv_tensors]):
58 | torch_tensor = torch_tensor.T
59 | torch_tensor = np.expand_dims(torch_tensor, axis=2)
60 |
61 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
62 | tf.keras.backend.set_value(tf_var, torch_tensor)
63 | tf_weight = session.run(tf_var)
64 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
65 |
66 | saver = tf.train.Saver(tf.trainable_variables())
67 | saver.save(session, os.path.join(save_dir, "model.ckpt"))
68 | if os.path.exists(os.path.join(save_dir, 'checkpoint')):
69 | try:
70 | os.remove(os.path.join(save_dir, 'checkpoint'))
71 | print(
72 | "We will delete the checkpoint file to avoid errors in loading weights "
73 | "using tf.train.latest_checkpoint api.")
74 | except:
75 | pass
76 | if pt_config_file is not None and os.path.exists(pt_config_file):
77 | shutil.copyfile(pt_config_file, os.path.join(save_dir, 'config.json'))
78 | if pt_vocab_file is not None and os.path.exists(pt_vocab_file):
79 | shutil.copyfile(pt_vocab_file, os.path.join(save_dir, 'vocab.txt'))
80 |
81 | config_path = os.path.join(os.path.split(pt_config_file)[0], 'config')
82 | target_dir = os.path.join(save_dir, 'config')
83 | if os.path.isdir(config_path) and not os.path.exists(target_dir):
 84 |             # shutil.copytree creates target_dir itself and raises if it already exists
 85 |             shutil.copytree(config_path, target_dir)
86 |
87 |
88 | def main():
89 | parser = argparse.ArgumentParser()
 90 |     parser.add_argument("--pytorch_model_dir", type=str, default=None, help="directory containing the pytorch model files")
 91 |     parser.add_argument("--pt_weight_file", type=str, default=None, help="pytorch weight file")
 92 |     parser.add_argument("--pt_config_file", type=str, default=None, help="pytorch config file")
 93 |     parser.add_argument("--pt_vocab_file", type=str, default=None, help="pytorch vocab file")
 94 |     parser.add_argument(
 95 |         "--save_dir", type=str, default=None, required=True, help="directory to save the converted weights"
 96 |     )
97 | args = parser.parse_args()
98 | if args.pytorch_model_dir is not None:
99 | args.pt_weight_file = os.path.join(args.pytorch_model_dir, 'pytorch_model.bin')
100 | args.pt_config_file = os.path.join(args.pytorch_model_dir, 'config.json')
101 | args.pt_vocab_file = os.path.join(args.pytorch_model_dir, 'vocab.txt')
102 | convert_pytorch_checkpoint_to_tf(
103 | args.pt_weight_file, args.pt_config_file, args.pt_vocab_file,
104 | args.save_dir
105 | )
106 |
107 |
108 | if __name__ == '__main__':
109 | main()
110 |
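# A hypothetical example invocation; the directory names are placeholders.
# --pytorch_model_dir is expected to contain pytorch_model.bin, config.json and vocab.txt:
#   python convert_bert_torch_to_tf.py --pytorch_model_dir ./chinese-bert-torch --save_dir ./chinese-bert-tf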
--------------------------------------------------------------------------------
/convert_ernie_paddle_to_tf_bert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # author : 'huanghui'
3 | # date : '2021/5/25 22:05'
4 | # project: 'tfbert'
5 | import json
6 | import os
7 | import paddle
8 | import collections
9 | import numpy as np
10 | import argparse
11 | import tensorflow.compat.v1 as tf
12 | from tfbert import BertConfig
13 |
14 |
15 | def build_params_map_to_pt(num_layers=12):
16 | """
17 | build params map from paddle-paddle's ERNIE to transformer's BERT
18 | :return:
19 | """
20 | weight_map = collections.OrderedDict({
21 | 'word_emb.weight': "bert.embeddings.word_embeddings.weight",
22 | 'pos_emb.weight': "bert.embeddings.position_embeddings.weight",
23 | 'sent_emb.weight': "bert.embeddings.token_type_embeddings.weight",
24 | 'ln.weight': 'bert.embeddings.LayerNorm.gamma',
25 | 'ln.bias': 'bert.embeddings.LayerNorm.beta',
26 | })
27 | for i in range(num_layers):
28 | weight_map[f'encoder_stack.block.{i}.attn.q.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
29 | weight_map[f'encoder_stack.block.{i}.attn.q.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
30 | weight_map[f'encoder_stack.block.{i}.attn.k.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
31 | weight_map[f'encoder_stack.block.{i}.attn.k.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
32 | weight_map[f'encoder_stack.block.{i}.attn.v.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
33 | weight_map[f'encoder_stack.block.{i}.attn.v.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
34 | weight_map[f'encoder_stack.block.{i}.attn.o.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
35 | weight_map[f'encoder_stack.block.{i}.attn.o.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
36 | weight_map[f'encoder_stack.block.{i}.ln1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.gamma'
37 | weight_map[f'encoder_stack.block.{i}.ln1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.beta'
38 | weight_map[f'encoder_stack.block.{i}.ffn.i.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
39 | weight_map[f'encoder_stack.block.{i}.ffn.i.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
40 | weight_map[f'encoder_stack.block.{i}.ffn.o.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
41 | weight_map[f'encoder_stack.block.{i}.ffn.o.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
42 | weight_map[f'encoder_stack.block.{i}.ln2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.gamma'
43 | weight_map[f'encoder_stack.block.{i}.ln2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.beta'
44 | # add pooler
45 | weight_map.update(
46 | {
47 | 'pooler.weight': 'bert.pooler.dense.weight',
48 | 'pooler.bias': 'bert.pooler.dense.bias',
49 | 'mlm.weight': 'cls.predictions.transform.dense.weight',
50 | 'mlm.bias': 'cls.predictions.transform.dense.bias',
51 | 'mlm_ln.weight': 'cls.predictions.transform.LayerNorm.gamma',
52 | 'mlm_ln.bias': 'cls.predictions.transform.LayerNorm.beta',
53 | 'mlm_bias': 'cls.predictions.bias'
54 | }
55 | )
56 |
57 | return weight_map
58 |
59 |
60 | def build_config(paddle_config_file):
61 | ernie_config = json.load(open(paddle_config_file, 'r', encoding='utf-8'))
62 | if 'sent_type_vocab_size' in ernie_config:
63 | ernie_config['type_vocab_size'] = ernie_config['sent_type_vocab_size']
64 | config = BertConfig(
65 | **ernie_config
66 | )
67 | return config
68 |
69 |
70 | def convert_paddle_checkpoint_to_tf(
71 | paddle_weight_file, paddle_config_file, paddle_vocab_file, save_dir):
72 | params = paddle.load(paddle_weight_file)
73 | config = build_config(paddle_config_file)
74 | weight_map = build_params_map_to_pt(config.num_hidden_layers)
75 |
76 | var_map = (
77 | ("layer.", "layer_"),
78 | ("word_embeddings.weight", "word_embeddings"),
79 | ("position_embeddings.weight", "position_embeddings"),
80 | ("token_type_embeddings.weight", "token_type_embeddings"),
81 | (".", "/"),
82 | ("LayerNorm/weight", "LayerNorm/gamma"),
83 | ("LayerNorm/bias", "LayerNorm/beta"),
84 | ("weight", "kernel"),
85 | )
86 |
87 | if not os.path.isdir(save_dir):
88 | os.makedirs(save_dir)
89 |
90 | def to_tf_var_name(name: str):
91 | for patt, repl in iter(var_map):
92 | name = name.replace(patt, repl)
93 | return name
94 |
95 | def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
96 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
97 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
98 | session.run(tf.variables_initializer([tf_var]))
99 | session.run(tf_var)
100 | return tf_var
101 |
102 | tf.reset_default_graph()
103 | with tf.Session() as session:
104 | for pd_name, pd_var in params.items():
105 | tf_name = to_tf_var_name(weight_map[pd_name])
106 | pd_tensor = pd_var.numpy()
107 | tf_var = create_tf_var(tensor=pd_tensor, name=tf_name, session=session)
108 | tf.keras.backend.set_value(tf_var, pd_tensor)
109 | tf_weight = session.run(tf_var)
110 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, pd_tensor)))
111 |
112 | saver = tf.train.Saver(tf.trainable_variables())
113 | saver.save(session, os.path.join(save_dir, "model.ckpt"))
114 |
115 | if os.path.exists(os.path.join(save_dir, 'checkpoint')):
116 | try:
117 | os.remove(os.path.join(save_dir, 'checkpoint'))
118 | print(
119 | "We will delete the checkpoint file to avoid errors in loading weights "
120 | "using tf.train.latest_checkpoint api.")
121 | except:
122 | pass
123 | config.save_pretrained(save_dir)
124 |     # in the ernie-gram vocab file each line is "token\tid"
125 | with open(paddle_vocab_file, 'r', encoding='utf-8') as f, \
126 | open(os.path.join(save_dir, "vocab.txt"), 'w', encoding='utf-8') as w:
127 | for line in f:
128 | line = line.strip()
129 | if '\t' in line:
130 | line = line.split('\t')[0]
131 | if line:
132 | w.write(line + '\n')
133 |
134 |
135 | def main():
136 | parser = argparse.ArgumentParser()
137 |     parser.add_argument("--paddle_weight_file", type=str, required=True, help="paddle weight file; only dynamic-graph (dygraph) weights are supported")
138 |     parser.add_argument("--paddle_config_file", type=str, required=True, help="paddle config file")
139 |     parser.add_argument("--paddle_vocab_file", type=str, required=True, help="paddle vocab file")
140 |     parser.add_argument(
141 |         "--save_dir", type=str, default=None, required=True, help="directory to save the converted weights"
142 |     )
143 | args = parser.parse_args()
144 | convert_paddle_checkpoint_to_tf(
145 | args.paddle_weight_file, args.paddle_config_file, args.paddle_vocab_file,
146 | args.save_dir
147 | )
148 |
149 |
150 | if __name__ == '__main__':
151 | main()
152 |
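# A hypothetical example invocation; all paths are placeholders (only dygraph weights are supported):
#   python convert_ernie_paddle_to_tf_bert.py \
#       --paddle_weight_file ./ernie/ernie.pdparams \
#       --paddle_config_file ./ernie/ernie_config.json \
#       --paddle_vocab_file ./ernie/vocab.txt \
#       --save_dir ./ernie-tf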
--------------------------------------------------------------------------------
/run_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | __author__ = 'huanghui'
3 | __date__ = '2021/4/18 15:06'
4 | __project__ = 'tfbert'
5 |
6 | import json
7 | import os
8 | import argparse
9 | import tensorflow.compat.v1 as tf
10 | from tfbert import (
11 | Trainer, Dataset,
12 | SequenceClassification,
13 | CONFIGS, TOKENIZERS, devices, set_seed)
14 | from tfbert.data.classification import convert_examples_to_features, InputExample
15 | from sklearn.metrics import accuracy_score
16 | import pandas as pd
17 | from typing import Dict
18 | import numpy as np
19 |
20 |
21 | def create_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--model_type', default='bert', type=str, choices=CONFIGS.keys())
 24 |     parser.add_argument('--optimizer_type', default='adamw', type=str, help="optimizer type")
 25 |     parser.add_argument('--model_dir', default='model_path', type=str,
 26 |                         help="directory of the pretrained model; it should contain the checkpoint as model.ckpt, "
 27 |                              "the config as config.json and the vocabulary as vocab.txt")
 28 |
 29 |     parser.add_argument('--config_path', default=None, type=str, help="set this if the config file name is not the default")
 30 |     parser.add_argument('--vocab_path', default=None, type=str, help="set this if the vocab file name is not the default")
 31 |     parser.add_argument('--pretrained_checkpoint_path', default=None, type=str, help="set this if the checkpoint file name is not the default")
32 | parser.add_argument('--output_dir', default='output/classification', type=str, help="")
33 | parser.add_argument('--export_dir', default='output/classification/pb', type=str, help="")
34 |
 35 |     parser.add_argument('--labels', default='体育,娱乐,家居,房产,教育', type=str, help="comma-separated text classification labels")
36 | parser.add_argument('--train_file', default='data/classification/train.csv', type=str, help="")
37 | parser.add_argument('--dev_file', default='data/classification/dev.csv', type=str, help="")
38 | parser.add_argument('--test_file', default='data/classification/test.csv', type=str, help="")
39 |
 40 |     parser.add_argument("--num_train_epochs", default=3, type=int, help="number of training epochs")
 41 |     parser.add_argument("--max_seq_length", default=32, type=int, help="maximum sequence length")
 42 |     parser.add_argument("--batch_size", default=32, type=int, help="training batch size")
 43 |     parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
 44 |     parser.add_argument("--learning_rate", default=2e-5, type=float, help="learning rate")
45 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
46 | help="Proportion of training to perform linear learning rate warmup for.")
47 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
48 |
49 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
50 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
51 | parser.add_argument("--do_predict", action="store_true", help="Whether to run test on the test set.")
 52 |     parser.add_argument("--evaluate_during_training", action="store_true", help="whether to evaluate while training")
 53 |     parser.add_argument("--do_export", action="store_true", help="export the model in pb format.")
54 |
 55 |     parser.add_argument("--logging_steps", default=1000, type=int, help="evaluate every this many steps during training")
 56 |     parser.add_argument("--saving_steps", default=1000, type=int, help="save every this many steps during training")
 57 |     parser.add_argument("--random_seed", default=42, type=int, help="random seed")
 58 |     parser.add_argument("--threads", default=8, type=int, help="number of data-processing workers")
 59 |     parser.add_argument("--max_checkpoints", default=1, type=int, help="maximum number of checkpoints to keep; by default only one is kept")
 60 |     parser.add_argument("--single_device", action="store_true", help="use a single device only; by default all devices are used for training")
 61 |     parser.add_argument("--use_xla", action="store_true", help="whether to enable XLA acceleration")
 62 |     parser.add_argument(
 63 |         "--mixed_precision", action="store_true",
 64 |         help="mixed-precision training; under TF it only speeds things up when combined with XLA, and the initial compilation is slow")
65 | args = parser.parse_args()
66 |
67 | if not os.path.exists(args.output_dir):
68 | os.makedirs(args.output_dir)
69 |
70 | if not args.single_device:
71 | args.batch_size = args.batch_size * len(devices())
72 |
73 | args.labels = args.labels.split(',')
74 | return args
75 |
76 |
77 | def create_dataset(set_type, tokenizer, args):
78 | filename_map = {
79 | 'train': args.train_file, 'dev': args.dev_file, 'test': args.test_file
80 | }
81 | examples = []
82 | datas = pd.read_csv(filename_map[set_type], encoding='utf-8', sep='\t').values.tolist()
83 | for data in datas:
84 | examples.append(InputExample(
85 | guid=0, text_a=data[1], label=data[0]
86 | ))
87 | features = convert_examples_to_features(
88 | examples, tokenizer,
89 | max_length=args.max_seq_length, set_type=set_type,
90 | label_list=args.labels, threads=args.threads)
91 | dataset = Dataset(features,
92 | is_training=bool(set_type == 'train'),
93 | batch_size=args.batch_size,
94 | drop_last=bool(set_type == 'train'),
95 | buffer_size=len(features),
96 | max_length=args.max_seq_length)
97 | columns = ['input_ids', 'attention_mask', 'token_type_ids', 'label_ids']
98 | if "pinyin_ids" in features[0] and features[0]['pinyin_ids'] is not None:
99 | columns = ['input_ids', 'attention_mask', 'token_type_ids', 'pinyin_ids', 'label_ids']
100 | dataset.format_as(columns)
101 | return dataset
102 |
103 |
104 | def get_model_fn(config, args):
105 | def model_fn(inputs, is_training):
106 | model = SequenceClassification(
107 | model_type=args.model_type, config=config,
108 | num_classes=len(args.labels), is_training=is_training,
109 | **inputs)
110 |
111 | outputs = {'outputs': {'logits': model.logits, 'label_ids': inputs['label_ids']}}
112 | if model.loss is not None:
113 | loss = model.loss / args.gradient_accumulation_steps
114 | outputs['loss'] = loss
115 | return outputs
116 |
117 | return model_fn
118 |
119 |
120 | def get_serving_fn(config, args):
121 | def serving_fn():
122 | input_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='input_ids')
123 | attention_mask = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='attention_mask')
124 | token_type_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='token_type_ids')
125 | if args.model_type == 'glyce_bert':
126 | pinyin_ids = tf.placeholder(shape=[None, args.max_seq_length, 8], dtype=tf.int64, name='pinyin_ids')
127 | else:
128 | pinyin_ids = None
129 | model = SequenceClassification(
130 | model_type=args.model_type, config=config,
131 | num_classes=len(args.labels), is_training=False,
132 | input_ids=input_ids,
133 | pinyin_ids=pinyin_ids,
134 | attention_mask=attention_mask,
135 | token_type_ids=token_type_ids
136 | )
137 | inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
138 | if pinyin_ids is not None:
139 | inputs['pinyin_ids'] = pinyin_ids
140 | outputs = {'logits': model.logits}
141 | return inputs, outputs
142 |
143 | return serving_fn
144 |
145 |
146 | def metric_fn(outputs: Dict) -> Dict:
147 | """
148 |     The evaluation function is defined here.
149 |     :param outputs: predictions returned by trainer.evaluate; it contains whatever fields the model_fn's outputs dict contains
150 |     :return: must return a dict of metric results
151 | """
152 | predictions = np.argmax(outputs['logits'], -1)
153 | score = accuracy_score(outputs['label_ids'], predictions)
154 | return {'accuracy': score}
155 |
156 |
157 | def main():
158 | args = create_args()
159 | set_seed(args.random_seed)
160 |
161 | config = CONFIGS[args.model_type].from_pretrained(
162 | args.model_dir if args.config_path is None else args.config_path)
163 |
164 | tokenizer = TOKENIZERS[args.model_type].from_pretrained(
165 | args.model_dir if args.vocab_path is None else args.vocab_path, do_lower_case=True)
166 |
167 | train_dataset, dev_dataset, predict_dataset = None, None, None
168 | if args.do_train:
169 | train_dataset = create_dataset('train', tokenizer, args)
170 |
171 | if args.do_eval:
172 | dev_dataset = create_dataset('dev', tokenizer, args)
173 |
174 | if args.do_predict:
175 | predict_dataset = create_dataset('test', tokenizer, args)
176 |
177 | output_types, output_shapes = (train_dataset or dev_dataset or predict_dataset).output_types_and_shapes()
178 | trainer = Trainer(
179 | train_dataset=train_dataset,
180 | eval_dataset=dev_dataset,
181 | output_types=output_types,
182 | output_shapes=output_shapes,
183 | metric_fn=metric_fn,
184 | use_xla=args.use_xla,
185 | optimizer_type=args.optimizer_type,
186 | learning_rate=args.learning_rate,
187 | num_train_epochs=args.num_train_epochs,
188 | gradient_accumulation_steps=args.gradient_accumulation_steps,
189 | max_checkpoints=1,
190 | max_grad=1.0,
191 | warmup_proportion=args.warmup_proportion,
192 | mixed_precision=args.mixed_precision,
193 | single_device=args.single_device,
194 | logging=True
195 | )
196 | trainer.build_model(model_fn=get_model_fn(config, args))
197 | if args.do_train and train_dataset is not None:
198 |         # during training, compile the optimizer before initializing weights,
199 |         # because adam itself has parameters (slot variables)
200 | trainer.compile()
201 |
202 | trainer.from_pretrained(
203 | args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)
204 | if args.do_train and train_dataset is not None:
205 | trainer.train(
206 | output_dir=args.output_dir,
207 | evaluate_during_training=args.evaluate_during_training,
208 | logging_steps=args.logging_steps,
209 | saving_steps=args.saving_steps,
210 | greater_is_better=True,
211 | load_best_model=True,
212 | metric_for_best_model='accuracy')
213 | config.save_pretrained(args.output_dir)
214 | tokenizer.save_pretrained(args.output_dir)
215 |
216 | if args.do_eval and dev_dataset is not None:
217 | eval_outputs = trainer.evaluate()
218 | print(json.dumps(
219 | eval_outputs, ensure_ascii=False, indent=4
220 | ))
221 |
222 | if args.do_predict and predict_dataset is not None:
223 | outputs = trainer.predict('test', ['logits'], dataset=predict_dataset)
224 | label_ids = np.argmax(outputs['logits'], axis=-1)
225 | labels = list(map(lambda x: args.labels[x], label_ids))
226 | open(
227 | os.path.join(args.output_dir, 'prediction.txt'), 'w', encoding='utf-8'
228 | ).write("\n".join(labels))
229 |
230 | if args.do_export:
231 | trainer.export(
232 | get_serving_fn(config, args),
233 | args.output_dir,
234 | args.export_dir
235 | )
236 |
237 |
238 | if __name__ == '__main__':
239 | main()
240 |
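# A hypothetical example invocation; model and data paths are placeholders. The csv files are read
# as tab-separated with a label column followed by a text column (see create_dataset above):
#   CUDA_VISIBLE_DEVICES=0,1 python run_classifier.py \
#       --model_dir ./rbt3 \
#       --train_file data/classification/train.csv --dev_file data/classification/dev.csv \
#       --do_train --do_eval --evaluate_during_training --use_xla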
--------------------------------------------------------------------------------
/run_pretraining.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # author : 'huanghui'
3 | # date : '2021/5/21 9:14'
4 | # project: 'tfbert'
5 | """
6 | Dynamic masking is implemented with a generator-based dataset; for very large corpora this may not hold up and TFRecord should be used instead (a serialization sketch follows create_dataset below).
7 | """
8 |
9 | import json
10 | import os
11 | import argparse
12 | import random
13 | import tensorflow.compat.v1 as tf
14 | from tfbert import (
15 | Trainer, MaskedLM,
16 | CONFIGS, TOKENIZERS, devices, set_seed,
17 | compute_types, compute_shapes, process_dataset)
18 | from tfbert.data.pretrain import create_masked_lm_predictions, convert_to_unicode
19 | from tfbert.tokenizer.tokenization_base import PTMTokenizer
20 | from typing import Dict, List
21 | from sklearn.metrics import accuracy_score
22 |
23 |
24 | def create_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--model_type', default='bert', type=str, choices=CONFIGS.keys())
 27 |     parser.add_argument('--optimizer_type', default='adamw', type=str, help="optimizer type")
 28 |     parser.add_argument('--model_dir', default='model_path', type=str,
 29 |                         help="directory of the pretrained model; it should contain the checkpoint as model.ckpt, "
 30 |                              "the config as config.json and the vocabulary as vocab.txt")
 31 |
 32 |     parser.add_argument('--config_path', default=None, type=str, help="set this if the config file name is not the default")
 33 |     parser.add_argument('--vocab_path', default=None, type=str, help="set this if the vocab file name is not the default")
 34 |     parser.add_argument('--pretrained_checkpoint_path', default=None, type=str, help="set this if the checkpoint file name is not the default")
35 | parser.add_argument('--output_dir', default='output/pretrain', type=str, help="")
36 | parser.add_argument('--export_dir', default='output/pretrain/pb', type=str, help="")
37 |
 38 |     parser.add_argument('--train_dir', default='data/pretrain/train', type=str, help="directory containing the training files")
 39 |     parser.add_argument('--dev_dir', default='data/pretrain/dev', type=str, help="directory containing the dev files")
40 |
 41 |     parser.add_argument("--num_train_epochs", default=10, type=int, help="number of training epochs")
 42 |     parser.add_argument("--max_seq_length", default=128, type=int, help="maximum sequence length")
 43 |     parser.add_argument("--batch_size", default=64, type=int, help="training batch size")
 44 |     parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
 45 |     parser.add_argument("--learning_rate", default=5e-5, type=float, help="learning rate")
46 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
47 | help="Proportion of training to perform linear learning rate warmup for.")
48 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
49 |
 50 |     parser.add_argument("--masked_lm_prob", default=0.15, type=float, help="masking probability.")
 51 |     parser.add_argument("--max_predictions_per_seq", default=20, type=int, help="maximum number of masked positions.")
 52 |     parser.add_argument("--ngram", default=4, type=int, help="maximum n-gram length for n-gram masking.")
53 |
54 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
55 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
 56 |     parser.add_argument("--evaluate_during_training", action="store_true", help="whether to evaluate while training")
 57 |     parser.add_argument("--do_export", action="store_true", help="export the model in pb format.")
 58 |     parser.add_argument("--do_whole_word_mask", action="store_true", help="whole word masking.")
59 |
 60 |     parser.add_argument("--logging_steps", default=1000, type=int, help="evaluate every this many steps during training")
 61 |     parser.add_argument("--saving_steps", default=1000, type=int, help="save every this many steps during training")
 62 |     parser.add_argument("--random_seed", default=42, type=int, help="random seed")
 63 |     parser.add_argument("--max_checkpoints", default=1, type=int, help="maximum number of checkpoints to keep; by default only one is kept")
 64 |     parser.add_argument("--single_device", action="store_true", help="use a single device only; by default all devices are used for training")
 65 |     parser.add_argument("--use_xla", action="store_true", help="whether to enable XLA acceleration")
 66 |     parser.add_argument(
 67 |         "--mixed_precision", action="store_true",
 68 |         help="mixed-precision training; under TF it only speeds things up when combined with XLA, and the initial compilation is slow")
69 | args = parser.parse_args()
70 |
71 | if not os.path.exists(args.output_dir):
72 | os.makedirs(args.output_dir)
73 |
74 | if not args.single_device:
75 | args.batch_size = args.batch_size * len(devices())
76 |
77 | def find_files(dir_or_file):
78 | if os.path.isdir(dir_or_file):
79 | files = os.listdir(dir_or_file)
80 | files = [os.path.join(dir_or_file, file_) for file_ in files]
81 | elif isinstance(dir_or_file, str):
82 | files = [dir_or_file]
83 | else:
84 | files = []
85 | return files
86 |
87 | args.train_files = find_files(args.train_dir)
88 | args.dev_files = find_files(args.dev_dir)
89 |
90 | if len(args.dev_files) == 0:
91 | args.do_eval = False
92 | args.evaluate_during_training = False
93 |
94 | if len(args.train_files) == 0 and args.do_train:
95 | args.do_train = False
96 | tf.logging.warn("If you need to perform training, please ensure that the training file is not empty")
97 | return args
98 |
99 |
100 | def create_dataset(args, input_files, tokenizer: PTMTokenizer, set_type):
101 | if not isinstance(input_files, List):
102 | input_files = [input_files]
103 | all_tokens = []
104 | for input_file in input_files:
105 | with open(input_file, 'r', encoding='utf-8') as reader:
106 | for line in reader:
107 | line = convert_to_unicode(line)
108 | if not line:
109 | break
110 | line = line.strip()
111 | if not line:
112 | continue
113 | tokens = tokenizer.tokenize(line)
114 | tokens = tokens[:args.max_seq_length - 2]
115 | tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
116 |
117 | all_tokens.append(tokens)
118 |     # shuffle the samples
119 | random.shuffle(all_tokens)
120 |
121 |     # define a generator that provides dynamic masking
122 | def dynamic_mask_gen():
123 | for tokens in all_tokens:
124 | output_tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
125 | tokens, args.masked_lm_prob,
126 | args.max_predictions_per_seq,
127 | list(tokenizer.vocab.keys()),
128 | do_whole_word_mask=args.do_whole_word_mask,
129 | favor_shorter_ngram=True,
130 | ngram=args.ngram
131 | )
132 |
133 | encoded = tokenizer.encode_plus(
134 | output_tokens, add_special_tokens=False, padding="max_length",
135 | truncation=True, max_length=args.max_seq_length
136 | )
137 | masked_lm_positions = list(masked_lm_positions)
138 | masked_lm_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)
139 | masked_lm_weights = [1.0] * len(masked_lm_ids)
140 |
141 | while len(masked_lm_positions) < args.max_predictions_per_seq:
142 | masked_lm_positions.append(0)
143 | masked_lm_ids.append(0)
144 | masked_lm_weights.append(0.0)
145 | encoded.update(
146 | {'masked_lm_ids': masked_lm_ids,
147 | 'masked_lm_weights': masked_lm_weights,
148 | 'masked_lm_positions': masked_lm_positions}
149 | )
150 | yield encoded
151 |
152 | sample_example = {
153 | 'input_ids': [0], 'token_type_ids': [0], 'attention_mask': [0],
154 | 'masked_lm_ids': [0], 'masked_lm_weights': [0.0], 'masked_lm_positions': [0]
155 | }
156 | types = compute_types(sample_example)
157 | shapes = compute_shapes(sample_example)
158 | dataset = tf.data.Dataset.from_generator(
159 | dynamic_mask_gen, types, shapes
160 | )
161 | dataset, steps = process_dataset(
162 | dataset, args.batch_size, len(all_tokens), set_type, buffer_size=100)
163 | return dataset, steps
164 |
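# A minimal sketch (not part of the original training flow) of the TFRecord alternative mentioned in
# the module docstring: for very large corpora the masked-LM features yielded by the generator could
# be serialized once and streamed from disk instead of being regenerated in memory. Feature names
# follow `sample_example` in create_dataset; the output path is a placeholder.
def write_features_to_tfrecord(feature_dicts, output_file="pretrain.tfrecord"):
    def int64_list(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    def float_list(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

    with tf.io.TFRecordWriter(output_file) as writer:
        for feature in feature_dicts:
            example = tf.train.Example(features=tf.train.Features(feature={
                'input_ids': int64_list(feature['input_ids']),
                'token_type_ids': int64_list(feature['token_type_ids']),
                'attention_mask': int64_list(feature['attention_mask']),
                'masked_lm_positions': int64_list(feature['masked_lm_positions']),
                'masked_lm_ids': int64_list(feature['masked_lm_ids']),
                'masked_lm_weights': float_list(feature['masked_lm_weights']),
            }))
            writer.write(example.SerializeToString())
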
165 |
166 | def get_model_fn(config, args):
167 | def model_fn(inputs, is_training):
168 | model = MaskedLM(
169 | model_type=args.model_type,
170 | config=config,
171 | is_training=is_training,
172 | **inputs)
173 |
174 | masked_lm_ids = tf.reshape(inputs['masked_lm_ids'], [-1])
175 | masked_lm_log_probs = tf.reshape(model.prediction_scores,
176 | [-1, model.prediction_scores.shape[-1]])
177 | masked_lm_predictions = tf.argmax(
178 | masked_lm_log_probs, axis=-1, output_type=tf.int32)
179 | masked_lm_weights = tf.reshape(inputs['masked_lm_weights'], [-1])
180 |
181 | outputs = {'outputs': {
182 | 'masked_lm_predictions': masked_lm_predictions,
183 | 'masked_lm_ids': masked_lm_ids,
184 | 'masked_lm_weights': masked_lm_weights
185 | }}
186 | if model.loss is not None:
187 | loss = model.loss / args.gradient_accumulation_steps
188 | outputs['loss'] = loss
189 | return outputs
190 |
191 | return model_fn
192 |
193 |
194 | def metric_fn(outputs: Dict) -> Dict:
195 | """
196 |     The evaluation function is defined here.
197 |     :param outputs: predictions returned by trainer.evaluate; it contains whatever fields the model_fn's outputs dict contains
198 |     :return: must return a dict of metric results
199 | """
200 | score = accuracy_score(outputs['masked_lm_ids'], outputs['masked_lm_predictions'],
201 | sample_weight=outputs['masked_lm_weights'])
202 | return {'accuracy': score}
203 |
204 |
205 | def main():
206 | args = create_args()
207 | set_seed(args.random_seed)
208 |
209 | config = CONFIGS[args.model_type].from_pretrained(
210 | args.model_dir if args.config_path is None else args.config_path)
211 |
212 | tokenizer = TOKENIZERS[args.model_type].from_pretrained(
213 | args.model_dir if args.vocab_path is None else args.vocab_path, do_lower_case=True)
214 |
215 |     # tf's native dataset offers no obvious way to get the number of steps per epoch automatically,
216 |     # so it is computed in advance and passed to the trainer
217 | train_dataset, train_steps, dev_dataset, dev_steps = None, 0, None, 0
218 | if args.do_train:
219 | train_dataset, train_steps = create_dataset(args, args.train_files, tokenizer, 'train')
220 |
221 | if args.do_eval:
222 | dev_dataset, dev_steps = create_dataset(args, args.dev_files, tokenizer, 'dev')
223 |
224 | trainer = Trainer(
225 | train_dataset=train_dataset,
226 | eval_dataset=dev_dataset,
227 | metric_fn=metric_fn,
228 | use_xla=args.use_xla,
229 | optimizer_type=args.optimizer_type,
230 | learning_rate=args.learning_rate,
231 | num_train_epochs=args.num_train_epochs,
232 | gradient_accumulation_steps=args.gradient_accumulation_steps,
233 | max_checkpoints=1,
234 | max_grad=1.0,
235 | warmup_proportion=args.warmup_proportion,
236 | mixed_precision=args.mixed_precision,
237 | single_device=args.single_device,
238 | logging=True
239 | )
240 | trainer.build_model(model_fn=get_model_fn(config, args))
241 | if args.do_train and train_dataset is not None:
242 |         # during training, compile the optimizer before initializing weights,
243 |         # because adam itself has parameters (slot variables)
244 | trainer.compile()
245 | trainer.from_pretrained(
246 | args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)
247 | if args.do_train and train_dataset is not None:
248 |
249 | trainer.train(
250 | output_dir=args.output_dir,
251 |             train_steps=train_steps,  # number of steps in one epoch
252 | eval_steps=dev_steps,
253 | evaluate_during_training=args.evaluate_during_training,
254 | logging_steps=args.logging_steps,
255 | saving_steps=args.saving_steps,
256 | greater_is_better=True,
257 | load_best_model=True,
258 | metric_for_best_model='accuracy')
259 | config.save_pretrained(args.output_dir)
260 | tokenizer.save_pretrained(args.output_dir)
261 |
262 | if args.do_eval and dev_dataset is not None:
263 | eval_outputs = trainer.evaluate(eval_steps=dev_steps)
264 | print(json.dumps(
265 | eval_outputs, ensure_ascii=False, indent=4
266 | ))
267 |
268 |
269 | if __name__ == '__main__':
270 | main()
271 |
--------------------------------------------------------------------------------
/run_semeval2010_re.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Author : Huanghui
3 | # Project : tfbert
4 | # File Name: run_semeval2010_re
5 | # Date : 2021/7/23
6 | __author__ = 'huanghui'
7 | __date__ = '2021/4/18 15:06'
8 | __project__ = 'tfbert'
9 |
10 | import json
11 | import os
12 | import csv
13 | import argparse
14 | import tensorflow.compat.v1 as tf
15 | from tfbert import (
16 | Trainer, Dataset,
17 | SequenceClassification,
18 | CONFIGS, TOKENIZERS, devices, set_seed)
19 | from tfbert.data.classification import convert_examples_to_features, InputExample
20 | from sklearn.metrics import f1_score
21 | from typing import Dict
22 | import numpy as np
23 |
24 |
25 | def create_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--model_type', default='bert', type=str, choices=CONFIGS.keys())
 28 |     parser.add_argument('--optimizer_type', default='adamw', type=str, help="optimizer type")
 29 |     parser.add_argument('--model_dir', default='model_path', type=str,
 30 |                         help="directory of the pretrained model; it should contain the checkpoint as model.ckpt, "
 31 |                              "the config as config.json and the vocabulary as vocab.txt")
 32 |
 33 |     parser.add_argument('--config_path', default=None, type=str, help="set this if the config file name is not the default")
 34 |     parser.add_argument('--vocab_path', default=None, type=str, help="set this if the vocab file name is not the default")
 35 |     parser.add_argument('--pretrained_checkpoint_path', default=None, type=str, help="set this if the checkpoint file name is not the default")
36 | parser.add_argument('--output_dir', default='output/semeval2010', type=str, help="")
37 | parser.add_argument('--export_dir', default='output/semeval2010/pb', type=str, help="")
38 |
 39 |     parser.add_argument('--label_file', default='data/semeval2010/label.txt', type=str, help="label file")
40 | parser.add_argument('--train_file', default='data/semeval2010/train.tsv', type=str, help="")
41 | parser.add_argument('--dev_file', default='data/semeval2010/test.tsv', type=str, help="")
42 | parser.add_argument('--test_file', default='data/semeval2010/test.tsv', type=str, help="")
43 |
 44 |     parser.add_argument("--num_train_epochs", default=5, type=int, help="number of training epochs")
 45 |     parser.add_argument("--max_seq_length", default=64, type=int, help="maximum sequence length")
 46 |     parser.add_argument("--batch_size", default=16, type=int, help="training batch size")
 47 |     parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
 48 |     parser.add_argument("--learning_rate", default=2e-5, type=float, help="learning rate")
49 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
50 | help="Proportion of training to perform linear learning rate warmup for.")
51 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
52 |
53 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
54 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
55 | parser.add_argument("--do_predict", action="store_true", help="Whether to run test on the test set.")
 56 |     parser.add_argument("--evaluate_during_training", action="store_true", help="whether to evaluate while training")
 57 |     parser.add_argument("--do_export", action="store_true", help="export the model in pb format.")
58 |
 59 |     parser.add_argument("--logging_steps", default=500, type=int, help="evaluate every this many steps during training")
 60 |     parser.add_argument("--saving_steps", default=-1, type=int, help="save every this many steps during training")
 61 |     parser.add_argument("--random_seed", default=42, type=int, help="random seed")
 62 |     parser.add_argument("--threads", default=8, type=int, help="number of data-processing workers")
 63 |     parser.add_argument("--max_checkpoints", default=1, type=int, help="maximum number of checkpoints to keep; by default only one is kept")
 64 |     parser.add_argument("--single_device", action="store_true", help="use a single device only; by default all devices are used for training")
 65 |     parser.add_argument("--use_xla", action="store_true", help="whether to enable XLA acceleration")
 66 |     parser.add_argument(
 67 |         "--mixed_precision", action="store_true",
 68 |         help="mixed-precision training; under TF it only speeds things up when combined with XLA, and the initial compilation is slow")
69 | args = parser.parse_args()
70 |
71 | if not os.path.exists(args.output_dir):
72 | os.makedirs(args.output_dir)
73 |
74 | if not args.single_device:
75 | args.batch_size = args.batch_size * len(devices())
76 |
77 | args.labels = open(args.label_file, 'r', encoding='utf-8').read().strip().split("\n")
78 | return args
79 |
80 |
81 | def create_dataset(set_type, tokenizer, args):
82 | filename_map = {
83 | 'train': args.train_file, 'dev': args.dev_file, 'test': args.test_file
84 | }
85 | examples = []
86 |
87 | with open(filename_map[set_type], "r", encoding="utf-8") as f:
88 | reader = csv.reader(f, delimiter="\t")
89 | for line in reader:
 90 |             # replace the entity boundary markers with BERT's reserved [unused] tokens
 91 |             text = line[1].replace("<e1>", "[unused1]").replace("</e1>", "[unused2]") \
 92 |                 .replace("<e2>", "[unused3]").replace("</e2>", "[unused4]")
93 | examples.append(InputExample(
94 | guid=0, text_a=text, label=line[0]
95 | ))
96 |
97 | features = convert_examples_to_features(
98 | examples, tokenizer,
99 | max_length=args.max_seq_length, set_type=set_type,
100 | label_list=args.labels, threads=args.threads)
101 |
102 | dataset = Dataset(features,
103 | is_training=bool(set_type == 'train'),
104 | batch_size=args.batch_size,
105 | drop_last=bool(set_type == 'train'),
106 | buffer_size=len(features),
107 | max_length=args.max_seq_length)
108 | columns = ['input_ids', 'attention_mask', 'token_type_ids', 'label_ids']
109 | if "pinyin_ids" in features[0] and features[0]['pinyin_ids'] is not None:
110 | columns = ['input_ids', 'attention_mask', 'token_type_ids', 'pinyin_ids', 'label_ids']
111 | dataset.format_as(columns)
112 | return dataset
113 |
114 |
115 | def get_model_fn(config, args):
116 | def model_fn(inputs, is_training):
117 | model = SequenceClassification(
118 | model_type=args.model_type, config=config,
119 | num_classes=len(args.labels), is_training=is_training,
120 | **inputs)
121 |
122 | outputs = {'outputs': {'logits': model.logits, 'label_ids': inputs['label_ids']}}
123 | if model.loss is not None:
124 | loss = model.loss / args.gradient_accumulation_steps
125 | outputs['loss'] = loss
126 | return outputs
127 |
128 | return model_fn
129 |
130 |
131 | def get_serving_fn(config, args):
132 | def serving_fn():
133 | input_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='input_ids')
134 | attention_mask = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='attention_mask')
135 | token_type_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='token_type_ids')
136 | if args.model_type == 'glyce_bert':
137 | pinyin_ids = tf.placeholder(shape=[None, args.max_seq_length, 8], dtype=tf.int64, name='pinyin_ids')
138 | else:
139 | pinyin_ids = None
140 | model = SequenceClassification(
141 | model_type=args.model_type, config=config,
142 | num_classes=len(args.labels), is_training=False,
143 | input_ids=input_ids,
144 | pinyin_ids=pinyin_ids,
145 | attention_mask=attention_mask,
146 | token_type_ids=token_type_ids
147 | )
148 | inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
149 | if pinyin_ids is not None:
150 | inputs['pinyin_ids'] = pinyin_ids
151 | outputs = {'logits': model.logits}
152 | return inputs, outputs
153 |
154 | return serving_fn
155 |
156 |
157 | def metric_fn(outputs: Dict) -> Dict:
158 | """
159 |     The evaluation function is defined here.
160 |     :param outputs: predictions returned by trainer.evaluate; it contains whatever fields the model_fn's outputs dict contains
161 |     :return: must return a dict of metric results
162 | """
163 | predictions = np.argmax(outputs['logits'], -1)
164 | score = f1_score(outputs['label_ids'], predictions, average='macro')
165 | return {'f1': score}
166 |
167 |
168 | def main():
169 | args = create_args()
170 | set_seed(args.random_seed)
171 |
172 | config = CONFIGS[args.model_type].from_pretrained(
173 | args.model_dir if args.config_path is None else args.config_path)
174 |
175 | tokenizer = TOKENIZERS[args.model_type].from_pretrained(
176 | args.model_dir if args.vocab_path is None else args.vocab_path, do_lower_case=True)
177 |
178 | tokenizer.additional_special_tokens = ["[unused1]", "[unused2]", "[unused3]", "[unused4]"]
179 | train_dataset, dev_dataset, predict_dataset = None, None, None
180 | if args.do_train:
181 | train_dataset = create_dataset('train', tokenizer, args)
182 |
183 | if args.do_eval:
184 | dev_dataset = create_dataset('dev', tokenizer, args)
185 |
186 | if args.do_predict:
187 | predict_dataset = create_dataset('test', tokenizer, args)
188 |
189 | output_types, output_shapes = (train_dataset or dev_dataset or predict_dataset).output_types_and_shapes()
190 | trainer = Trainer(
191 | train_dataset=train_dataset,
192 | eval_dataset=dev_dataset,
193 | output_types=output_types,
194 | output_shapes=output_shapes,
195 | metric_fn=metric_fn,
196 | use_xla=args.use_xla,
197 | optimizer_type=args.optimizer_type,
198 | learning_rate=args.learning_rate,
199 | num_train_epochs=args.num_train_epochs,
200 | gradient_accumulation_steps=args.gradient_accumulation_steps,
201 | max_checkpoints=1,
202 | max_grad=1.0,
203 | warmup_proportion=args.warmup_proportion,
204 | mixed_precision=args.mixed_precision,
205 | single_device=args.single_device,
206 | logging=True
207 | )
208 | trainer.build_model(model_fn=get_model_fn(config, args))
209 | if args.do_train and train_dataset is not None:
210 |         # during training, compile the optimizer before initializing weights,
211 |         # because adam itself has parameters (slot variables)
212 | trainer.compile()
213 | trainer.from_pretrained(
214 | args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)
215 | if args.do_train and train_dataset is not None:
216 | trainer.train(
217 | output_dir=args.output_dir,
218 | evaluate_during_training=args.evaluate_during_training,
219 | logging_steps=args.logging_steps,
220 | saving_steps=args.saving_steps,
221 | greater_is_better=True,
222 | load_best_model=True,
223 | metric_for_best_model='f1')
224 | config.save_pretrained(args.output_dir)
225 | tokenizer.save_pretrained(args.output_dir)
226 |
227 | if args.do_eval and dev_dataset is not None:
228 | eval_outputs = trainer.evaluate()
229 | print(json.dumps(
230 | eval_outputs, ensure_ascii=False, indent=4
231 | ))
232 |
233 | if args.do_predict and predict_dataset is not None:
234 | outputs = trainer.predict('test', ['logits'], dataset=predict_dataset)
235 | predictions = np.argmax(outputs['logits'], axis=-1)
236 | with open(os.path.join(args.output_dir, 'prediction.txt'), "w", encoding="utf-8") as f:
237 | for idx, pred in enumerate(predictions):
238 | f.write("{}\t{}\n".format(8001 + idx, args.labels[pred]))
239 |
240 | if args.do_export:
241 | trainer.export(
242 | get_serving_fn(config, args),
243 | args.output_dir,
244 | args.export_dir
245 | )
246 |
247 |
248 | if __name__ == '__main__':
249 | main()
250 |
--------------------------------------------------------------------------------
/simple_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :run_classifier.py
3 | # @Time :2021/4/14 19:45
4 | # @Author :huanghui
5 | import numpy as np
6 |
7 | from tfbert.data import Dataset
8 | from tfbert.models import create_word_embeddings, dropout, create_initializer
9 | from tfbert.models.loss import cross_entropy_loss
10 | from tfbert.models.layers import conv2d_layer, max_pooling_layer
11 | import tensorflow.compat.v1 as tf
12 | from tfbert import SimplerTrainer, ProgressBar, set_seed
13 | import pandas as pd
14 | from collections import Counter
15 | from tqdm import tqdm, trange
16 | import platform
17 | from sklearn.metrics import accuracy_score
18 |
19 | if platform.system() == 'Windows':
20 | bar_fn = ProgressBar
21 | else:
22 | bar_fn = tqdm
23 |
24 | set_seed(42)
25 |
26 |
27 | class TextCNN:
28 | def __init__(self,
29 | max_seq_length,
30 | vocab_size,
31 | is_training,
32 | input_ids,
33 | label_ids):
34 | embedding, _ = create_word_embeddings(
35 | input_ids=input_ids, vocab_size=vocab_size, embedding_size=300
36 | )
37 | embedding = tf.expand_dims(embedding, -1)
38 | pooled_outputs = []
39 | for i, filter_size in enumerate([2, 3, 5]):
40 | with tf.variable_scope("conv_{}".format(filter_size)):
41 | filter_shape = [filter_size, 300, 1, 128]
42 | h = conv2d_layer(embedding, filter_shape)
43 | pooled = max_pooling_layer(h, ksize=[1, max_seq_length - filter_size + 1, 1, 1])
44 | pooled_outputs.append(pooled)
45 | conv_output = tf.concat(pooled_outputs, 3)
46 | conv_output = tf.reshape(conv_output, [-1, 128 * 3])
47 |
48 | with tf.variable_scope("classifier"):
49 | # dropout = get_dropout_prob(is_training, dropout_prob=dropout)
50 | if is_training:
51 | conv_output = dropout(conv_output, dropout_prob=0.3)
52 | self.logits = tf.layers.dense(
53 | conv_output,
54 | 5,
55 | kernel_initializer=create_initializer(0.02),
56 | name='logits'
57 | )
58 | if label_ids is not None:
59 | self.loss = cross_entropy_loss(self.logits, label_ids, 5)
60 |
61 |
62 | def get_model_fn(is_training, vocab_size):
63 | def model_fn():
64 | input_ids = tf.placeholder(shape=[None, 32], dtype=tf.int64, name='input_ids')
65 | if is_training:
66 | label_ids = tf.placeholder(shape=[None], dtype=tf.int64, name='label_ids')
67 | else:
68 | label_ids = None
69 | model = TextCNN(
70 | 32, vocab_size,
71 | is_training=is_training,
72 | input_ids=input_ids,
73 | label_ids=label_ids)
74 | inputs = {'input_ids': input_ids}
75 |
76 | outputs = {"logits": model.logits}
77 | if is_training:
78 | outputs['loss'] = model.loss
79 | inputs['label_ids'] = label_ids
80 | outputs['label_ids'] = label_ids
81 | return inputs, outputs
82 |
83 | return model_fn
84 |
85 |
86 | def create_vocab(train_file, dev_file):
87 | datas = pd.read_csv(train_file, encoding='utf-8', sep='\t').values.tolist()
88 | datas.extend(
89 | pd.read_csv(dev_file, encoding='utf-8', sep='\t').values.tolist()
90 | )
91 | words = []
92 | for data in datas:
93 | words.extend(list(data[1]))
94 | words = [word.strip() for word in words if word.strip()]
95 | counter = Counter(words)
96 | words = counter.most_common(5000)
 97 |     vocabs = ["<pad>", "<unk>"] + [word[0] for word in words]  # index 0: padding token, index 1: unknown token
98 | vocab2id = dict(zip(vocabs, range(len(vocabs))))
99 | return vocab2id
100 |
101 |
102 | def load_dataset(filename, is_training, batch_size, max_seq_length, vocab2id, label2id):
103 | data = pd.read_csv(filename, encoding='utf-8', sep='\t').values.tolist()
104 | examples = []
105 | for d in data:
106 | label, text = d
107 |         id_ = list(map(lambda x: vocab2id[x] if x in vocab2id else vocab2id['<unk>'], list(text)))
108 |         id_ = id_[:max_seq_length]
109 |         id_ += [vocab2id["<pad>"]] * (max_seq_length - len(id_))
110 | examples.append({'input_ids': id_, 'label_ids': label2id[label]})
111 | dataset = Dataset(examples,
112 | is_training=is_training,
113 | batch_size=batch_size,
114 | drop_last=is_training,
115 | buffer_size=len(examples),
116 | max_length=max_seq_length)
117 | dataset.format_as(['input_ids', 'label_ids'])
118 | return dataset
119 |
120 |
121 | max_length = 32
122 | batch_size = 32
123 | data_dir = "D:/python/data/data/classification"
124 | vocab2id = create_vocab(data_dir + "/train.csv", data_dir + "/dev.csv")
125 | label2id = {'体育': 0, '娱乐': 1, '家居': 2, '房产': 3, '教育': 4}
126 |
127 | train_dataset = load_dataset(
128 | data_dir + "/train.csv",
129 | True, batch_size, max_length, vocab2id, label2id
130 | )
131 | dev_dataset = load_dataset(
132 | data_dir + "/dev.csv",
133 | False, batch_size, max_length, vocab2id, label2id
134 | )
135 |
136 | trainer = SimplerTrainer(
137 | optimizer_type='adamw',
138 | learning_rate=5e-5
139 | )
140 | trainer.build_model(model_fn=get_model_fn(True, len(vocab2id)))
141 | trainer.compile()
142 | trainer.init_variables()
143 | best_score = 0
144 | for epoch in trange(5):
145 | epoch_iter = bar_fn(train_dataset)
146 | for d in epoch_iter:
147 | loss = trainer.train_step(d)
148 | epoch_iter.set_description(desc='epoch {} ,loss {:.4f}'.format(epoch + 1, loss))
149 | epoch_iter.close()
150 | outputs = trainer.predict(dev_dataset.get_all_features(), output_names=['logits', 'label_ids'])
151 | y_true, y_pred = outputs['label_ids'], np.argmax(outputs['logits'], axis=-1)
152 | score = accuracy_score(y_true, y_pred)
153 | if score > best_score:
154 | best_score = score
155 | trainer.save_pretrained('output')
156 | print()
157 | print(score)
158 |
--------------------------------------------------------------------------------
/tfbert/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :__init__.py.py
3 | # @Time :2021/1/31 15:16
4 | # @Author :huanghui
5 |
6 | import tensorflow.compat.v1 as tf
7 | import os
8 |
9 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
10 | tf.disable_v2_behavior()
11 |
12 | from .models import (
13 | BertModel, ALBertModel, ElectraModel,
14 | NezhaModel, WoBertModel, GlyceBertModel,
15 | SequenceClassification, MODELS, crf,
16 | TokenClassification, MultiLabelClassification,
17 | MaskedLM, PretrainingLM, QuestionAnswering)
18 | from .config import (
19 | BaseConfig, BertConfig, ALBertConfig,
20 | ElectraConfig, NeZhaConfig, WoBertConfig, GlyceBertConfig, CONFIGS)
21 | from .tokenizer import (
22 | BasicTokenizer, BertTokenizer, WoBertTokenizer,
23 | ALBertTokenizer, ElectraTokenizer, NeZhaTokenizer,
24 | GlyceBertTokenizer, TOKENIZERS)
25 |
26 | from .utils import (
27 | devices, init_checkpoints,
28 | get_assignment_map_from_checkpoint, ProgressBar,
29 | clean_bert_model,
30 | set_seed)
31 |
32 | from .optimization import (
33 | AdamWeightDecayOptimizer, LAMBOptimizer,
34 | lr_schedule, create_optimizer, create_train_op)
35 | from .data import (Dataset, collate_batch, sequence_padding,
36 | single_example_to_features,
37 | multiple_convert_examples_to_features,
38 | compute_types, compute_shapes,
39 | process_dataset, compute_types_and_shapes_from_dataset)
40 | from .trainer import Trainer, SimplerTrainer
41 |
--------------------------------------------------------------------------------
/tfbert/adversarial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | __author__ = 'huanghui'
3 | __date__ = '2021/5/7 20:31'
4 | __project__ = 'tfbert'
5 |
6 | import tensorflow.compat.v1 as tf
7 | from . import utils
8 | import numpy as np
9 |
10 |
11 | class AdversarialOutput:
12 | def __init__(self, outputs: dict, grads_and_vars):
13 | self.outputs = outputs
14 | self.grads_and_vars = grads_and_vars
15 |
16 | def keys(self):
17 | return list(self.outputs.keys())
18 |
19 | def __getitem__(self, item):
20 | return self.outputs[item]
21 |
22 |
23 | def fgm(model_fn, inputs, optimizer=None, layer_name='word_embeddings', epsilon=0.5):
24 | """
25 | FGM对抗训练tensorflow1.x实现
26 | :param model_fn:
27 | :param inputs:
28 | :param optimizer: 优化器
29 | :param layer_name: 扰动的变量名
30 | :param epsilon: 扰动参数
31 | :return:
32 | """
33 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
34 | model_outputs = model_fn(inputs, True)
35 | grads_and_vars = utils.compute_gradients(model_outputs['loss'], optimizer)
36 | # loss对embedding的梯度
37 | embedding_gradients, embeddings = utils.find_grad_and_var(grads_and_vars, layer_name)
38 |
39 | r = tf.multiply(epsilon, embedding_gradients / (tf.norm(embedding_gradients) + 1e-9))
40 | attack_op = embeddings.assign(embeddings + r)
41 | # attack: re-run the forward pass on the perturbed embeddings
42 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE), tf.control_dependencies([attack_op]):
43 | adv_outputs = model_fn(inputs, True)
44 | attack_grad_and_vars = utils.compute_gradients(adv_outputs['loss'], optimizer)
45 | restore_op = embeddings.assign(embeddings - r)
46 |
47 | # sum up
48 | with tf.control_dependencies([restore_op]):
49 | grads_and_vars = utils.average_grads_and_vars([grads_and_vars, attack_grad_and_vars])
50 |
51 | return AdversarialOutput(model_outputs, grads_and_vars)
52 |
53 |
54 | def pgd(model_fn, inputs, optimizer=None, layer_name='word_embeddings', epsilon=0.05, n_loop=2):
55 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
56 | model_outputs = model_fn(inputs, True)
57 | grads_and_vars = utils.compute_gradients(model_outputs['loss'], optimizer)
58 | acc_r = 0.0
59 | attack_op = tf.no_op()
60 | for k in range(n_loop):
61 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE), tf.control_dependencies([attack_op]):
62 | adv_outputs = model_fn(inputs, True)
63 | attack_grad_and_vars = utils.compute_gradients(adv_outputs['loss'], optimizer)
64 | embedding_gradients, embeddings = utils.find_grad_and_var(attack_grad_and_vars, layer_name)
65 |
66 | tmp_r = tf.multiply(1 / n_loop, embedding_gradients / (tf.norm(embedding_gradients) + 1e-9))
67 |
68 | norm = tf.norm(acc_r + tmp_r)
69 | cur_r = tf.cond(norm > epsilon,
70 | lambda: (acc_r + tmp_r) * tf.divide(epsilon, norm),
71 | lambda: (acc_r + tmp_r))
72 | r = cur_r - acc_r # calculate current step
73 | attack_op = embeddings.assign(embeddings + r)
74 | acc_r = cur_r
75 |
76 | # restore
77 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE), tf.control_dependencies([attack_op]):
78 | attack_outputs = model_fn(inputs, True)
79 | attack_grad_and_vars = utils.compute_gradients(attack_outputs['loss'], optimizer)
80 | embedding_gradients, embeddings = utils.find_grad_and_var(attack_grad_and_vars, layer_name)
81 | restore_op = embeddings.assign(embeddings - acc_r)
82 | # sum up
83 | with tf.control_dependencies([restore_op]):
84 | grads_and_vars = utils.average_grads_and_vars([grads_and_vars, attack_grad_and_vars])
85 | return AdversarialOutput(model_outputs, grads_and_vars)
86 |
87 |
88 | def freelb(
89 | model_fn, inputs, batch_size, max_length,
90 | optimizer=None, layer_name='word_embeddings',
91 | epsilon=0.3, n_loop=3):
92 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
93 | model_outputs = model_fn(inputs, True)
94 | grads_and_vars = utils.compute_gradients(model_outputs['loss'], optimizer)
95 | # loss对embedding的梯度
96 | embedding_gradients, embeddings = utils.find_grad_and_var(grads_and_vars, layer_name)
97 | init_r = tf.get_variable(
98 | 'init_r',
99 | shape=[batch_size * max_length,
100 | embeddings.shape.as_list()[-1]],
101 | initializer=tf.random_uniform_initializer(
102 | minval=-epsilon, maxval=epsilon),
103 | trainable=False)
104 | init_op = tf.variables_initializer([init_r])
105 | with tf.control_dependencies([init_op]): # fix perturbation
106 | # Scale the randomly initialized perturbation, to make sure norm
107 | # of `r` is smaller than epsilon.
108 | r = tf.divide(init_r, tf.norm(init_r, np.inf))
109 | r = tf.IndexedSlices(values=r,
110 | indices=embedding_gradients.indices,
111 | dense_shape=embedding_gradients.dense_shape)
112 | attack_op = embeddings.assign(embeddings + r)
113 | # attack
114 | acc_r = r
115 | all_grads_and_vars = []
116 | for k in range(n_loop):
117 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE), tf.control_dependencies([attack_op]):
118 | adv_outputs = model_fn(inputs, True)
119 | attack_grad_and_vars = utils.compute_gradients(adv_outputs['loss'], optimizer)
120 | all_grads_and_vars.append(attack_grad_and_vars)
121 | gradients, _ = utils.find_grad_and_var(attack_grad_and_vars, layer_name)
122 | tmp_r = tf.multiply(1 / n_loop, gradients / (tf.norm(gradients) + 1e-9))
123 |
124 | # In order not to shuffle the distribution of gradient-
125 | # induced perturbation, we use norm to scale instead of
126 | # simply clip the values.
127 | norm = tf.norm(acc_r + tmp_r)
128 | cur_r = tf.cond(norm > epsilon,
129 | lambda: (acc_r + tmp_r) * tf.divide(epsilon, norm),
130 | lambda: (acc_r + tmp_r))
131 | r = cur_r - acc_r # calculate current step
132 | attack_op = embeddings.assign(embeddings + r)
133 | acc_r = cur_r
134 | # restore
135 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE), tf.control_dependencies([attack_op]):
136 | attack_outputs = model_fn(inputs, True)
137 | attack_grad_and_vars = utils.compute_gradients(attack_outputs['loss'], optimizer)
138 |
139 | all_grads_and_vars.append(attack_grad_and_vars)
140 | restore_op = embeddings.assign(embeddings - acc_r)
141 |
142 | # sum up
143 | with tf.control_dependencies([restore_op]):
144 | grads_and_vars = utils.average_grads_and_vars(all_grads_and_vars)
145 | return AdversarialOutput(model_outputs, grads_and_vars)
146 |
--------------------------------------------------------------------------------
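The three helpers above share one contract: `model_fn(inputs, is_training)` must return a dict containing a `'loss'` entry, and the graph must contain an embedding variable whose name matches `layer_name` (the built-in BERT-style models name their table `word_embeddings`, hence the default). Each returns an `AdversarialOutput` whose `grads_and_vars` can go straight into `apply_gradients`. Below is a minimal, illustrative sketch of wiring `fgm` up by hand; the toy model, shapes, and plain Adam optimizer are assumptions for demonstration, not how the repo's examples do it (those go through the Trainer):

```python
import tensorflow.compat.v1 as tf
from tfbert.adversarial import fgm

def model_fn(inputs, is_training):
    # Toy model_fn: fgm only needs a dict with a 'loss' key and a variable
    # whose name contains layer_name ("word_embeddings" here).
    embedding_table = tf.get_variable("word_embeddings", shape=[21128, 128])
    emb = tf.nn.embedding_lookup(embedding_table, inputs["input_ids"])
    pooled = tf.reduce_mean(emb, axis=1)
    w = tf.get_variable("classifier_w", shape=[128, 2])
    b = tf.get_variable("classifier_b", shape=[2], initializer=tf.zeros_initializer())
    logits = tf.matmul(pooled, w) + b
    loss = tf.losses.sparse_softmax_cross_entropy(labels=inputs["label_ids"], logits=logits)
    return {"loss": loss, "logits": logits}

inputs = {
    "input_ids": tf.placeholder(tf.int32, [None, 32], name="input_ids"),
    "label_ids": tf.placeholder(tf.int32, [None], name="label_ids"),
}
optimizer = tf.train.AdamOptimizer(learning_rate=5e-5)  # stand-in for the repo's create_optimizer
adv = fgm(model_fn, inputs, optimizer=optimizer, layer_name="word_embeddings", epsilon=0.5)
train_op = optimizer.apply_gradients(adv.grads_and_vars)
clean_loss = adv["loss"]  # AdversarialOutput proxies the first (clean) forward pass
```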
/tfbert/config/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: __init__.py.py
5 | @date: 2020/09/08
6 | """
7 |
8 | from .base import BaseConfig
9 | from .ptm import (
10 | BertConfig, ALBertConfig, ElectraConfig, GlyceBertConfig)
11 | from .ptm import BertConfig as NeZhaConfig
12 | from .ptm import BertConfig as WoBertConfig
13 |
14 | CONFIGS = {
15 | 'bert': BertConfig, 'albert': ALBertConfig,
16 | 'nezha': NeZhaConfig, 'electra': ElectraConfig,
17 | 'wobert': WoBertConfig, 'glyce_bert': GlyceBertConfig
18 | }
19 |
--------------------------------------------------------------------------------
/tfbert/config/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: tokenization_base.py
5 | @date: 2020/09/08
6 | """
7 | import os
8 | from typing import Dict
9 | import tensorflow.compat.v1 as tf
10 | import copy
11 | import json
12 |
13 |
14 | class BaseConfig(object):
15 | filename = "config.json"
16 |
17 | def __init__(self, **kwargs):
18 |
19 | self.output_attentions = kwargs.pop("output_attentions", False)
20 | self.output_hidden_states = kwargs.pop("output_hidden_states", False)
21 | self.use_one_hot_embeddings = kwargs.pop('use_one_hot_embeddings', False)
22 |
23 | for key, value in kwargs.items():
24 | try:
25 | setattr(self, key, value)
26 | except AttributeError as err:
27 | tf.logging.info("Can't set {} with value {} for {}".format(key, value, self))
28 | raise err
29 |
30 | def __eq__(self, other):
31 | return self.__dict__ == other.__dict__
32 |
33 | def __repr__(self):
34 | return "{} : \n{}".format(self.__class__.__name__, self.to_json_string())
35 |
36 | def to_dict(self):
37 | """
38 | 将config属性序列化成dict
39 | """
40 | output = copy.deepcopy(self.__dict__)
41 | return output
42 |
43 | def to_json_string(self):
44 | """
45 | 将config属性序列化成json字符串
46 | """
47 | return json.dumps(self.to_dict(), indent=4, sort_keys=True)
48 |
49 | def save_to_json_file(self, json_file_path):
50 | """
51 | 将config存入json文件
52 | """
53 | with open(json_file_path, "w", encoding="utf-8") as writer:
54 | writer.write(self.to_json_string())
55 |
56 | @classmethod
57 | def _dict_from_json_file(cls, json_file: str):
58 | '''
59 | 从json文件读取字典
60 | :param json_file:
61 | :return:
62 | '''
63 | with open(json_file, "r", encoding="utf-8") as reader:
64 | text = reader.read()
65 | return json.loads(text)
66 |
67 | @classmethod
68 | def from_json_file(cls, json_file: str) -> "BaseConfig":
69 | """
70 | 从json文件中加载config
71 | """
72 | config_dict = cls._dict_from_json_file(json_file)
73 | return cls(**config_dict)
74 |
75 | @classmethod
76 | def from_dict(cls, config_dict: Dict, **kwargs) -> "BaseConfig":
77 | """
78 | 从字典中加载config
79 | """
80 | config = cls(**config_dict)
81 |
82 | # Update config with kwargs if needed
83 | to_remove = []
84 | for key, value in kwargs.items():
85 | if hasattr(config, key):
86 | setattr(config, key, value)
87 | to_remove.append(key)
88 | for key in to_remove:
89 | kwargs.pop(key, None)
90 |
91 | return config
92 |
93 | def save_pretrained(self, save_dir_or_file):
94 | '''
95 | 保存config,如果save_dir_or_file是个文件夹,
96 | 则保存至默认文件名:save_dir_or_file + 'config.json'
97 | 如果是文件名,保存至该文件
98 | :param save_dir_or_file:
99 | :return:
100 | '''
101 | if os.path.isdir(save_dir_or_file):
102 | output_config_file = os.path.join(save_dir_or_file, self.filename)
103 | else:
104 | output_config_file = save_dir_or_file
105 |
106 | self.save_to_json_file(output_config_file)
107 | tf.logging.info(' Configuration saved in {}'.format(output_config_file))
108 | return output_config_file
109 |
110 | @classmethod
111 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
112 | '''
113 | 从文件夹或文件中加载config
114 | :param pretrained_model_name_or_path:
115 | :param kwargs:
116 | :return:
117 | '''
118 |
119 | if os.path.isdir(pretrained_model_name_or_path):
120 | config_file = os.path.join(pretrained_model_name_or_path, cls.filename)
121 | elif os.path.isfile(pretrained_model_name_or_path):
122 | config_file = pretrained_model_name_or_path
123 | else:
124 | raise ValueError('Config path should be a directory or file')
125 |
126 | config_dict = cls._dict_from_json_file(config_file)
127 | return cls.from_dict(config_dict, **kwargs)
128 |
--------------------------------------------------------------------------------
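`BaseConfig` is a thin attribute container: every keyword becomes an attribute, `to_dict`/`to_json_string` serialize it, and `from_dict` applies extra kwargs as overrides on top of a stored dict. A small round-trip sketch (the field names are arbitrary, not a real model config):

```python
from tfbert import BaseConfig

cfg = BaseConfig(hidden_size=256, num_layers=4)
print(cfg.to_json_string())  # pretty-printed, key-sorted JSON

# from_dict re-creates the config and lets keyword overrides win
cfg2 = BaseConfig.from_dict(cfg.to_dict(), num_layers=6)
assert cfg2.num_layers == 6 and cfg2.hidden_size == 256
```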
/tfbert/config/ptm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: tokenization_base.py
5 | @date: 2020/09/08
6 | """
7 | from . import BaseConfig
8 | import re
9 | import tensorflow.compat.v1 as tf
10 | import os
11 | import shutil
12 |
13 |
14 | class BertConfig(BaseConfig):
15 | def __init__(self,
16 | vocab_size,
17 | embedding_size=None,
18 | hidden_size=768,
19 | num_hidden_layers=12,
20 | num_attention_heads=12,
21 | intermediate_size=3072,
22 | hidden_act="gelu",
23 | hidden_dropout_prob=0.1,
24 | attention_probs_dropout_prob=0.1,
25 | max_position_embeddings=512,
26 | type_vocab_size=16,
27 | initializer_range=0.02,
28 | **kwargs
29 | ):
30 | super().__init__(**kwargs)
31 |
32 | self.vocab_size = vocab_size
33 | self.embedding_size = embedding_size if embedding_size is not None else hidden_size
34 | self.hidden_size = hidden_size
35 | self.num_hidden_layers = num_hidden_layers
36 | self.num_attention_heads = num_attention_heads
37 | self.hidden_act = hidden_act
38 | self.intermediate_size = intermediate_size
39 | self.hidden_dropout_prob = hidden_dropout_prob
40 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
41 | self.max_position_embeddings = max_position_embeddings
42 | self.type_vocab_size = type_vocab_size
43 | self.initializer_range = initializer_range
44 |
45 |
46 | class ALBertConfig(BaseConfig):
47 |
48 | def __init__(self,
49 | vocab_size,
50 | embedding_size=128,
51 | hidden_size=4096,
52 | num_hidden_layers=12,
53 | num_hidden_groups=1,
54 | num_attention_heads=64,
55 | intermediate_size=16384,
56 | inner_group_num=1,
57 | down_scale_factor=1,
58 | hidden_act="gelu",
59 | hidden_dropout_prob=0,
60 | attention_probs_dropout_prob=0,
61 | max_position_embeddings=512,
62 | type_vocab_size=2,
63 | initializer_range=0.02,
64 | **kwargs):
65 | super().__init__(**kwargs)
66 |
67 | self.vocab_size = vocab_size
68 | self.embedding_size = embedding_size
69 | self.hidden_size = hidden_size
70 | self.num_hidden_layers = num_hidden_layers
71 | self.num_hidden_groups = num_hidden_groups
72 | self.num_attention_heads = num_attention_heads
73 | self.inner_group_num = inner_group_num
74 | self.down_scale_factor = down_scale_factor
75 | self.hidden_act = hidden_act
76 | self.intermediate_size = intermediate_size
77 | self.hidden_dropout_prob = hidden_dropout_prob
78 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
79 | self.max_position_embeddings = max_position_embeddings
80 | self.type_vocab_size = type_vocab_size
81 | self.initializer_range = initializer_range
82 |
83 |
84 | class ElectraConfig(BaseConfig):
85 | """Configuration for `BertModel` (ELECTRA uses the same ptm as BERT)."""
86 |
87 | def __init__(self,
88 | vocab_size,
89 | hidden_size=768,
90 | num_hidden_layers=12,
91 | num_attention_heads=12,
92 | intermediate_size=3072,
93 | hidden_act="gelu",
94 | hidden_dropout_prob=0.1,
95 | attention_probs_dropout_prob=0.1,
96 | max_position_embeddings=512,
97 | type_vocab_size=2,
98 | initializer_range=0.02,
99 | **kwargs):
100 | super().__init__(**kwargs)
101 |
102 | self.vocab_size = vocab_size
103 | self.hidden_size = hidden_size
104 | self.num_hidden_layers = num_hidden_layers
105 | self.num_attention_heads = num_attention_heads
106 | self.hidden_act = hidden_act
107 | self.intermediate_size = intermediate_size
108 | self.hidden_dropout_prob = hidden_dropout_prob
109 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
110 | self.max_position_embeddings = max_position_embeddings
111 | self.type_vocab_size = type_vocab_size
112 | self.initializer_range = initializer_range
113 |
114 | @classmethod
115 | def from_checkpoint(cls, checkpoint_path, **kwargs):
116 | """
117 | 由于electra 没有给出config.json,所以构建一个方法,从checkpoint中读取配置信息。
118 | :param checkpoint_path: electra模型的checkpoint文件
119 | :param kwargs:
120 | :return:
121 | """
122 | # 参数映射,checkpoint变量名: (config配置参数,配置参数属于变量shape的哪个维度的大小)
123 | param_map = {
124 | 'electra/embeddings/word_embeddings': ('vocab_size', 0),
125 | 'electra/encoder/layer_0/attention/output/dense/bias': ('hidden_size', 0),
126 | 'electra/encoder/layer_0/intermediate/dense/bias': ('intermediate_size', 0),
127 | 'electra/embeddings/position_embeddings': ('max_position_embeddings', 0),
128 | 'electra/embeddings/token_type_embeddings': ('type_vocab_size', 0)
129 | }
130 | # 基本参数
131 | param = {
132 | 'hidden_dropout_prob': 0.1,
133 | 'attention_probs_dropout_prob': 0.1,
134 | 'hidden_act': 'gelu',
135 | 'initializer_range': 0.02
136 | }
137 |
138 | # 加载checkpoint,获取相应参数
139 | init_vars = tf.train.list_variables(checkpoint_path)
140 | num_hidden_layers = 0
141 | for x in init_vars:
142 | name, shape = x[0], x[1]
143 | if name in param_map:
144 | param[param_map[name][0]] = shape[param_map[name][1]]
145 |
146 | if 'layer_' in name:
147 | layer = re.match(".*?layer_(\\d+)/.*?", name).group(1)
148 | if int(layer) >= num_hidden_layers:
149 | num_hidden_layers = int(layer)
150 |
151 | param['num_hidden_layers'] = num_hidden_layers + 1
152 | param['num_attention_heads'] = max(1, param["hidden_size"] // 64)
153 |
154 | return cls(**param, **kwargs)
155 |
156 |
157 | class GlyceBertConfig(BaseConfig):
158 | def __init__(self,
159 | vocab_size,
160 | embedding_size=None,
161 | hidden_size=768,
162 | num_hidden_layers=12,
163 | num_attention_heads=12,
164 | intermediate_size=3072,
165 | hidden_act="gelu",
166 | hidden_dropout_prob=0.1,
167 | attention_probs_dropout_prob=0.1,
168 | max_position_embeddings=512,
169 | type_vocab_size=16,
170 | initializer_range=0.02,
171 | config_path="",
172 | **kwargs
173 | ):
174 | super().__init__(**kwargs)
175 |
176 | self.vocab_size = vocab_size
177 | self.embedding_size = embedding_size if embedding_size is not None else hidden_size
178 | self.hidden_size = hidden_size
179 | self.num_hidden_layers = num_hidden_layers
180 | self.num_attention_heads = num_attention_heads
181 | self.hidden_act = hidden_act
182 | self.intermediate_size = intermediate_size
183 | self.hidden_dropout_prob = hidden_dropout_prob
184 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
185 | self.max_position_embeddings = max_position_embeddings
186 | self.type_vocab_size = type_vocab_size
187 | self.initializer_range = initializer_range
188 | self.config_path = config_path
189 |
190 | @classmethod
191 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
192 | '''
193 | 从文件夹或文件中加载config
194 | :param pretrained_model_name_or_path:
195 | :param kwargs:
196 | :return:
197 | '''
198 |
199 | if os.path.isdir(pretrained_model_name_or_path):
200 | config_file = os.path.join(pretrained_model_name_or_path, cls.filename)
201 | config_path = os.path.join(pretrained_model_name_or_path, "config")
202 | elif os.path.isfile(pretrained_model_name_or_path):
203 | config_file = pretrained_model_name_or_path
204 | dir_ = os.path.split(config_file)[0]
205 | config_path = os.path.join(dir_, 'config')
206 | else:
207 | raise ValueError('Config path should be a directory or file')
208 |
209 | config_dict = cls._dict_from_json_file(config_file)
210 | kwargs['config_path'] = config_path
211 | return cls.from_dict(config_dict, **kwargs)
212 |
213 | def save_pretrained(self, save_dir_or_file):
214 | if os.path.isdir(save_dir_or_file):
215 | output_config_file = os.path.join(save_dir_or_file, self.filename)
216 | config_path = os.path.join(save_dir_or_file, 'config')
217 | else:
218 | output_config_file = save_dir_or_file
219 | config_path = os.path.join(os.path.split(save_dir_or_file)[0], "config")
220 | if not os.path.exists(config_path):
221 | os.makedirs(config_path)
222 |
223 | filenames = os.listdir(self.config_path)
224 | if len(filenames) > 0:
225 | for filename in filenames:
226 | if filename.endswith('.npy'):
227 | shutil.copyfile(
228 | os.path.join(self.config_path, filename), os.path.join(config_path, filename)
229 | )
230 | self.save_to_json_file(output_config_file)
231 | tf.logging.info(' Configuration saved in {}'.format(output_config_file))
232 | return output_config_file
233 |
--------------------------------------------------------------------------------
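Because the ELECTRA releases ship without a `config.json`, `ElectraConfig.from_checkpoint` reconstructs the configuration from variable shapes in the checkpoint (vocab size, hidden size, intermediate size, position/type embeddings, layer count, and heads as `hidden_size // 64`). A hedged usage sketch; the checkpoint prefix and directory are placeholders:

```python
from tfbert import ElectraConfig

ckpt = "electra_180g_small/electra_small"  # placeholder checkpoint prefix
config = ElectraConfig.from_checkpoint(ckpt)
print(config.num_hidden_layers, config.hidden_size, config.num_attention_heads)

# the result behaves like any other config, so it can be persisted for later runs
config.save_pretrained("electra_180g_small")  # writes config.json into that directory
```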
/tfbert/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: __init__.py.py
5 | @date: 2020/09/08
6 | """
7 | import json
8 | import copy
9 | from multiprocessing import cpu_count, Pool
10 | from tqdm import tqdm
11 | import numpy as np
12 | import tensorflow.compat.v1 as tf
13 |
14 |
15 | class BaseClass:
16 | def dict(self):
17 | output = copy.deepcopy(self.__dict__)
18 | return output
19 |
20 | def __str__(self):
21 | return "{} \n {}".format(
22 | self.__class__.__name__, json.dumps(self.dict(), ensure_ascii=False))
23 |
24 | def keys(self):
25 | return list(self.dict().keys())
26 |
27 | def __getitem__(self, item):
28 | return self.dict()[item]
29 |
30 | def __contains__(self, item):
31 | return item in self.dict()
32 |
33 |
34 | def single_example_to_features(
35 | examples, annotate_, desc='convert examples to feature'):
36 | features = []
37 | for example in tqdm(examples, desc=desc):
38 | features.append(annotate_(example))
39 | return features
40 |
41 |
42 | def multiple_convert_examples_to_features(
43 | examples,
44 | annotate_,
45 | initializer,
46 | initargs,
47 | threads,
48 | desc='convert examples to feature'):
49 | threads = min(cpu_count(), threads)
50 | features = []
51 | with Pool(threads, initializer=initializer, initargs=initargs) as p:
52 | features = list(tqdm(
53 | p.imap(annotate_, examples, chunksize=32),
54 | total=len(examples),
55 | desc=desc
56 | ))
57 | return features
58 |
59 |
60 | def process_dataset(dataset, batch_size, num_features, set_type, buffer_size=None):
61 | if set_type == 'train':
62 | if buffer_size is None:
63 | buffer_size = num_features
64 | dataset = dataset.repeat()
65 | dataset = dataset.shuffle(buffer_size=buffer_size)
66 | dataset = dataset.batch(batch_size=batch_size,
67 | drop_remainder=bool(set_type == 'train'))
68 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
69 |
70 | # 在train阶段,因为设置了drop_remainder,会舍弃不足batch size的一个batch,所以步数计算方式和验证测试不同
71 | if set_type == 'train':
72 | num_batch_per_epoch = num_features // batch_size
73 | else:
74 | num_batch_per_epoch = (num_features + batch_size - 1) // batch_size
75 | return dataset, num_batch_per_epoch
76 |
77 |
78 | def compute_types(example, columns=None):
79 | if columns is None:
80 | columns = list(example.keys())
81 |
82 | def fn(values):
83 | if isinstance(values, np.ndarray):
84 | if values.dtype == np.dtype(float):
85 | return tf.float32
86 | elif values.dtype == np.int64:
87 | return tf.int32 # 统一使用int32
88 | elif values.dtype == np.int32:
89 | return tf.int32
90 | else:
91 | raise ValueError(
92 | f"values={values} is an np.ndarray with items of dtype {values.dtype}, which cannot be supported"
93 | )
94 | # 支持到二维矩阵。。。
95 | elif (isinstance(values, list) and isinstance(values[0], float)) or isinstance(values, float):
96 | return tf.float32
97 | elif (isinstance(values, list) and isinstance(values[0], int)) or isinstance(values, int):
98 | return tf.int32
99 | elif (isinstance(values, list) and isinstance(values[0], str)) or isinstance(values, str):
100 | return tf.string
101 | elif isinstance(values, list) and isinstance(values[0], list):
102 | return fn(values[0])
103 | else:
104 | raise ValueError(f"values={values} of type {type(values)} cannot be supported")
105 |
106 | tf_types = {}
107 | for k in columns:
108 | if k in example:
109 | tf_types[k] = fn(example[k])
110 | return tf_types
111 |
112 |
113 | def compute_shapes(example, columns=None):
114 | if columns is None:
115 | columns = list(example.keys())
116 |
117 | def fn(array):
118 | np_shape = np.shape(array)
119 | return [None] * len(np_shape)
120 |
121 | tf_shapes = {}
122 | for k in columns:
123 | if k in example:
124 | tf_shapes[k] = fn(example[k])
125 | return tf_shapes
126 |
127 |
128 | def compute_types_and_shapes_from_dataset(dataset: tf.data.Dataset, use_none=False):
129 | """
130 | 根据 tf dataset,得到dataset的tensor shapes和types
131 | :param dataset:
132 | :param use_none: 是否将shapes全部设置为None,这样避免bs不统一
133 | :return:
134 | """
135 |
136 | def to_none(tensor_shape: tf.TensorShape):
137 | return tf.TensorShape([None] * len(tensor_shape.as_list()))
138 |
139 | output_types = tf.data.get_output_types(dataset)
140 | output_shapes = tf.data.get_output_shapes(dataset)
141 | if use_none:
142 | for k in output_shapes:
143 | output_shapes[k] = to_none(output_shapes[k])
144 | return output_types, output_shapes
145 |
146 |
147 | from .dataset import Dataset, collate_batch, sequence_padding
148 |
--------------------------------------------------------------------------------
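`compute_types` and `compute_shapes` look at a single example dict and derive the `tf.data` output types and shapes (integers are normalized to `tf.int32`, every known dimension becomes `None`), which `Dataset` later feeds to `tf.data.Dataset.from_generator`. A quick sketch with a made-up feature dict:

```python
from tfbert.data import compute_types, compute_shapes

example = {"input_ids": [101, 2769, 102], "label_ids": 1}
print(compute_types(example))   # both keys map to tf.int32
print(compute_shapes(example))  # {'input_ids': [None], 'label_ids': []}
```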
/tfbert/data/classification.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: classification.py
5 | @date: 2020/09/09
6 | """
7 | from functools import partial
8 | from . import BaseClass, multiple_convert_examples_to_features, single_example_to_features
9 |
10 |
11 | class InputExample(BaseClass):
12 | def __init__(self, guid, text_a, text_b=None, label=None):
13 | self.guid = guid
14 | self.text_a = text_a
15 | self.text_b = text_b
16 | self.label = label
17 |
18 |
19 | class InputFeature(BaseClass):
20 | """A single set of features of data."""
21 |
22 | def __init__(self,
23 | guid,
24 | input_ids,
25 | attention_mask=None,
26 | token_type_ids=None,
27 | pinyin_ids=None,
28 | label_ids=None,
29 | ex_id=None):
30 | self.guid = guid
31 | self.input_ids = input_ids
32 | self.attention_mask = attention_mask
33 | self.token_type_ids = token_type_ids
34 | self.pinyin_ids = pinyin_ids
35 | self.label_ids = label_ids
36 | self.ex_id = ex_id
37 |
38 |
39 | def convert_example_to_feature(example: InputExample,
40 | max_length=512,
41 | label_map=None,
42 | is_multi_label=False) -> InputFeature:
43 | """
44 | text,
45 | text_pair=None,
46 | max_length=512,
47 | pad_to_max_len=False,
48 | truncation_strategy="longest_first",
49 | return_position_ids=False,
50 | return_token_type_ids=True,
51 | return_attention_mask=True,
52 | return_length=False,
53 | return_overflowing_tokens=False,
54 | return_special_tokens_mask=False
55 | :param example:
56 | :param max_length:
57 | :param label_map:
58 | :param is_multi_label:
59 | :return:
60 | """
61 | inputs = tokenizer(
62 | example.text_a, # 传入句子 a
63 | text_pair=example.text_b, # 传入句子 b,可以为None
64 | max_length=max_length, # 最大长度
65 | padding="max_length", # 是否将句子padding到最大长度
66 | truncation=True
67 | )
68 | if example.label is not None:
69 | # 多标签分类的话,先将label设为one hot 类型
70 | if is_multi_label:
71 | label_id = [0] * len(label_map)
72 | for lb in example.label:
73 | label_id[label_map[lb]] = 1
74 | else:
75 | label_id = label_map[example.label]
76 | else:
77 | label_id = None
78 | return InputFeature(
79 | guid=0,
80 | input_ids=inputs['input_ids'],
81 | attention_mask=inputs['attention_mask'],
82 | token_type_ids=inputs['token_type_ids'],
83 | pinyin_ids=inputs['pinyin_ids'] if "pinyin_ids" in inputs else None,
84 | label_ids=label_id,
85 | ex_id=example.guid
86 | )
87 |
88 |
89 | def convert_example_to_feature_init(tokenizer_for_convert):
90 | global tokenizer
91 | tokenizer = tokenizer_for_convert
92 |
93 |
94 | def convert_examples_to_features(
95 | examples,
96 | tokenizer,
97 | max_length=512,
98 | label_list=None,
99 | set_type='train',
100 | is_multi_label=False,
101 | threads=1
102 | ):
103 | '''
104 | 将examples转为features, 适用于单句和双句分类任务
105 | :param examples:
106 | :param tokenizer: bert分词器
107 | :param max_length: 句子最大长度
108 | :param label_list: 标签
109 | :param set_type:
110 | :param is_multi_label: 是否是多标签分类
111 | :param threads:
112 | :return:
113 | '''
114 |
115 | label_map = None
116 | if label_list is not None:
117 | label_map = {label: i for i, label in enumerate(label_list)}
118 | annotate_ = partial(
119 | convert_example_to_feature,
120 | max_length=max_length,
121 | label_map=label_map,
122 | is_multi_label=is_multi_label
123 | )
124 | if threads > 1:
125 | features = multiple_convert_examples_to_features(
126 | examples,
127 | annotate_=annotate_,
128 | initializer=convert_example_to_feature_init,
129 | initargs=(tokenizer,),
130 | threads=threads
131 | )
132 | else:
133 | convert_example_to_feature_init(tokenizer)
134 | features = single_example_to_features(
135 | examples, annotate_=annotate_
136 | )
137 | new_features = []
138 | # use enumerate so each feature gets a unique sequential index
139 | for i, feature in enumerate(features):
140 | feature.guid = set_type + '-' + str(i)
141 | new_features.append(feature)
142 | return new_features
143 |
--------------------------------------------------------------------------------
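The classification pipeline is: build `InputExample`s, run `convert_examples_to_features` (which fans out across processes when `threads > 1`), then hand the `InputFeature`s to a `Dataset`. A minimal sketch; the tokenizer directory is a placeholder and the label list mirrors the README example:

```python
from tfbert import BertTokenizer
from tfbert.data.classification import InputExample, convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert_base_zh", do_lower_case=True)  # placeholder dir
examples = [
    InputExample(guid=0, text_a="这是一条体育新闻", label="体育"),
    InputExample(guid=1, text_a="房价又上涨了", label="房产"),
]
features = convert_examples_to_features(
    examples, tokenizer,
    max_length=32,
    label_list=["体育", "娱乐", "家居", "房产", "教育"],
    set_type="train",
    threads=1,
)
print(features[0].input_ids[:10], features[0].label_ids)
```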
/tfbert/data/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | __author__ = 'huanghui'
3 | __date__ = '2021/4/16 22:51'
4 | __project__ = 'tfbert'
5 |
6 | import copy
7 | from . import BaseClass, compute_shapes, compute_types, compute_types_and_shapes_from_dataset
8 | import numpy as np
9 | from typing import List, Dict, Optional, Union
10 | import random
11 | import tensorflow.compat.v1 as tf
12 |
13 |
14 | def sequence_padding(ids: List, max_length=None, pad_id=0, mode='post'):
15 | """
16 | copy的苏神sequence_padding代码
17 | :param ids:
18 | :param max_length:
19 | :param pad_id:
20 | :param mode:
21 | :return:
22 | """
23 | if not isinstance(ids[0], list):
24 | return ids
25 | if max_length is None:
26 | max_length = max([len(x) for x in ids])
27 |
28 | pad_width = [(0, 0) for _ in np.shape(ids[0])]
29 | outputs = []
30 | for id_ in ids:
31 | x = id_[:max_length]
32 | if mode == 'post':
33 | pad_width[0] = (0, max_length - len(x))
34 | elif mode == 'pre':
35 | pad_width[0] = (max_length - len(x), 0)
36 | else:
37 | raise ValueError('"mode" argument must be "post" or "pre".')
38 | x = np.pad(x, pad_width, 'constant', constant_values=pad_id)
39 | outputs.append(x)
40 |
41 | return np.array(outputs)
42 |
43 |
44 | def collate_batch(
45 | examples: Union[Dict, List[Dict]],
46 | max_length: Optional[Union[int, Dict]] = None,
47 | pad_id: Optional[Union[int, Dict]] = 0,
48 | mode='post'):
49 | """
50 | :param examples: 可以是一个二维列表,可以是一个元素为字典的列表, 也可以是一个字典
51 | :param max_length:单个int或者字典指定,字典指定的话会对每一个字典填充对应的长度
52 | :param pad_id: 单个id或者字典指定,字典指定的话会对每一个字典填充对应的id
53 | :param mode:
54 | :return:
55 | """
56 | if isinstance(examples, dict) or isinstance(examples[0], dict):
57 | if isinstance(examples, dict):
58 | result = examples
59 | else:
60 | result = {k: [] for k in examples[0]}
61 | for i in range(len(examples)):
62 | for k in result:
63 | result[k].append(examples[i][k])
64 | if not isinstance(pad_id, dict):
65 | pad_id = {k: pad_id for k in result}
66 | if not isinstance(max_length, dict):
67 | max_length = {k: max_length for k in result}
68 | for k, v in result.items():
69 | if isinstance(v[0], list):
70 | result[k] = sequence_padding(v, max_length[k], pad_id[k], mode)
71 | else:
72 | result[k] = v
73 | return result
74 | elif isinstance(examples[0], list):
75 | return sequence_padding(examples, max_length, pad_id, mode)
76 |
77 |
78 | class Dataset:
79 | def __init__(
80 | self,
81 | features: List[Union[Dict, BaseClass]],
82 | is_training=False,
83 | batch_size=1,
84 | drop_last=False,
85 | buffer_size=100,
86 | padding=False,
87 | max_length: Optional[Union[int, Dict]] = None,
88 | pad_id: Optional[Union[int, Dict]] = 0,
89 | pad_mode='post',
90 | padding_all=False):
91 | """
92 | 简单的dataset包装,未完成 map 方法
93 | :param features: 列表,元素为键值相同的字典
94 | :param is_training: 是否训练,训练模式会对数据shuffle
95 | :param batch_size: 批次大小
96 | :param drop_last: 舍弃最后一个不足批次大小的数据
97 | :param padding: 是否填充
98 | :param max_length: 填充的最大,为None的话补全到当前batch最大长度
99 | :param pad_id: 填充id,可以是单个id,也可以是字典类型,每个字典对应自己的填充id
100 | :param pad_mode: post 或 pre,即在后边填充和在前边填充
101 | :param padding_all:是否直接对所有features进行padding,这样会在迭代过程中减少padding的操作
102 | """
103 | if is_training:
104 | random.shuffle(features)
105 | self.is_training = is_training
106 | self.buffer_size = buffer_size
107 |
108 | self.features = {}
109 | self.columns = []
110 | for i in range(len(features)):
111 | feature = features[i]
112 | if isinstance(features[i], BaseClass):
113 | feature = features[i].dict()
114 | for k, v in feature.items():
115 | if i == 0:
116 | self.features[k] = []
117 | self.columns.append(k)
118 | self.features[k].append(v)
119 | self.back_columns = copy.deepcopy(self.columns)
120 | if not isinstance(pad_id, dict):
121 | pad_id = {k: pad_id for k in self.features}
122 |
123 | if not isinstance(max_length, dict):
124 | max_length = {k: max_length for k in self.features}
125 |
126 | self.batch_size = batch_size
127 | self.drop_last = drop_last
128 |
129 | self.idx = 0
130 | self.padding = padding
131 | self.pad_id = pad_id
132 | self.max_length = max_length
133 | self.pad_mode = pad_mode
134 |
135 | self.padded = False
136 | if padding_all:
137 | self.features = collate_batch(self.features, self.max_length, self.pad_id, self.pad_mode)
138 | self.padded = True
139 |
140 | self.output_types = {}
141 | self.output_shapes = {}
142 |
143 | def __num_batch__(self):
144 | num_features = len(self.features[self.columns[0]])
145 | if self.drop_last:
146 | num_batch = num_features // self.batch_size
147 | else:
148 | num_batch = (num_features + self.batch_size - 1) // self.batch_size
149 | return num_batch
150 |
151 | @property
152 | def num_batch(self):
153 | return self.__num_batch__()
154 |
155 | def __len__(self):
156 | return self.num_batch
157 |
158 | def __getitem__(self, item):
159 | if item in self.columns:
160 | return self.features[item]
161 | elif isinstance(item, int):
162 | data = {}
163 | for k in self.columns:
164 | data[k] = self.features[k][item]
165 | else:
166 | raise ValueError(f"type error of {item}")
167 | return data
168 |
169 | def __iter__(self):
170 | return self
171 |
172 | def __repr__(self):
173 | return f"Dataset({{\n features: {self.columns},\n num_batch: {self.num_batch}\n}})"
174 |
175 | def get_all_features(self):
176 | features = {}
177 | for k in self.columns:
178 | features[k] = self.features[k]
179 | return features
180 |
181 | def __next__(self):
182 | if self.idx < self.num_batch:
183 | start = self.idx * self.batch_size
184 | end = (self.idx + 1) * self.batch_size
185 | batch = {}
186 | for k in self.columns:
187 | if not self.padded and self.padding:
188 | batch[k] = sequence_padding(self.features[k][start: end], self.max_length[k], self.pad_id[k],
189 | self.pad_mode)
190 | else:
191 | batch[k] = self.features[k][start: end]
192 | self.idx += 1
193 | return batch
194 | else:
195 | self.idx = 0
196 | raise StopIteration
197 |
198 | def remove_columns(self, remove_columns):
199 | if isinstance(remove_columns, str):
200 | remove_columns = [remove_columns]
201 | elif not isinstance(remove_columns, list):
202 | remove_columns = []
203 | for remove_column in remove_columns:
204 | if remove_column in self.features:
205 | self.features.pop(remove_column)
206 | self.columns.remove(remove_column)
207 | self.back_columns.remove(remove_column)
208 |
209 | def format_as(self, columns):
210 | if not isinstance(columns, list):
211 | columns = [columns]
212 | new_columns = []
213 | for column in columns:
214 | if column in self.back_columns:
215 | new_columns.append(column)
216 | self.columns = new_columns
217 |
218 | def restore_columns(self):
219 | self.columns = copy.deepcopy(self.back_columns)
220 |
221 | def process_dataset(self, dataset: tf.data.Dataset):
222 | if self.is_training:
223 | dataset = dataset.repeat()
224 | dataset = dataset.shuffle(buffer_size=self.buffer_size)
225 | dataset = dataset.batch(batch_size=self.batch_size,
226 | drop_remainder=self.drop_last)
227 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
228 | self.output_types, self.output_shapes = self.get_output_types_and_shapes(dataset)
229 | return dataset
230 |
231 | def tf_gen_dataset(self):
232 | """
233 | 将dataset转成tf dataset,此方法使用生成器的方式进行
234 | :return:
235 | """
236 |
237 | def gen():
238 | for i in range(len(self.features[self.columns[0]])):
239 | data = {}
240 | for k in self.columns:
241 | data[k] = self.features[k][i]
242 | yield data
243 |
244 | shapes = compute_shapes(self[0], self.columns)
245 | types = compute_types(self[0], self.columns)
246 | return self.dataset_from_generator(gen, types, shapes)
247 |
248 | def dataset_from_generator(self, generator, types, shapes):
249 | dataset = tf.data.Dataset.from_generator(
250 | generator,
251 | types,
252 | shapes
253 | )
254 | return self.process_dataset(dataset)
255 |
256 | def tf_slice_dataset(self):
257 | """
258 | 对应slice类型的tf dataset
259 | :return:
260 | """
261 | dataset = {}
262 | types = compute_types(self[0], self.columns)
263 | for k in self.columns:
264 | dataset[k] = tf.constant(self.features[k], dtype=types[k])
265 | dataset = tf.data.Dataset.from_tensor_slices(dataset)
266 | return self.process_dataset(dataset)
267 |
268 | def format_to_tf_dataset(self, dataset_type='generator'):
269 | assert dataset_type in ['generator', 'slice']
270 | if dataset_type == 'generator':
271 | return self.tf_gen_dataset()
272 | return self.tf_slice_dataset()
273 |
274 | @classmethod
275 | def get_output_types_and_shapes(cls, dataset: tf.data.Dataset, use_none=False):
276 | """
277 | 根据 tf dataset,得到 dataset的 tensor shapes和 types
278 | :param dataset:
279 | :param use_none: 是否将shapes全部设置为None,这样避免bs不统一
280 | :return:
281 | """
282 | return compute_types_and_shapes_from_dataset(dataset, use_none)
283 |
284 | def output_types_and_shapes(self):
285 | shapes = compute_shapes(self.features, self.columns)
286 | types = compute_types(self[0], self.columns)
287 | return types, shapes
288 |
--------------------------------------------------------------------------------
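`sequence_padding` pads (or truncates) a batch of id lists and `collate_batch` applies it per key of a feature dict, which is what `Dataset` does on the fly when `padding=True`. A quick sketch:

```python
from tfbert.data import collate_batch, sequence_padding

batch = [
    {"input_ids": [101, 2769, 102], "label_ids": 0},
    {"input_ids": [101, 102], "label_ids": 1},
]
padded = collate_batch(batch, max_length=5, pad_id=0)
print(padded["input_ids"])   # 2 x 5 array, zero-padded on the right
print(padded["label_ids"])   # [0, 1]; scalar fields pass through untouched

print(sequence_padding([[1, 2, 3], [4]], max_length=4, mode="pre"))  # pads on the left instead
```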
/tfbert/data/ner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: ner.py
5 | @date: 2020/09/12
6 | """
7 | from functools import partial
8 | from . import BaseClass, multiple_convert_examples_to_features, single_example_to_features
9 | from ..tokenizer import GlyceBertTokenizer
10 | from typing import List
11 |
12 |
13 | class InputExample(BaseClass):
14 | def __init__(self, guid, words: List[str], tags: List[str] = None):
15 | self.guid = guid
16 | self.words = words
17 | self.tags = tags
18 |
19 |
20 | class InputFeature(BaseClass):
21 | """A single set of features of data."""
22 |
23 | def __init__(self,
24 | guid,
25 | input_ids: List[int],
26 | attention_mask: List[int] = None,
27 | token_type_ids: List[int] = None,
28 | pinyin_ids=None,
29 | label_ids: List[int] = None,
30 | tok_to_orig_index: List[int] = None,
31 | ex_id=None):
32 | self.guid = guid
33 | self.input_ids = input_ids
34 | self.attention_mask = attention_mask
35 | self.token_type_ids = token_type_ids
36 | self.pinyin_ids = pinyin_ids
37 | self.label_ids = label_ids
38 | self.ex_id = ex_id
39 | self.tok_to_orig_index = tok_to_orig_index
40 |
41 |
42 | def convert_example_to_feature(
43 | example: InputExample,
44 | max_length=512,
45 | label_map=None,
46 | pad_token_label_id=0
47 | ):
48 | has_label = bool(example.tags is not None)
49 | tokens = []
50 | label_ids = []
51 | tok_to_orig_index = [] # 用来存放token 和 原始words列表的位置对应关系,因为bert分词可能会将一个word分成多个token
52 | for i in range(len(example.words)):
53 | word = example.words[i]
54 | if has_label:
55 | label = example.tags[i]
56 | word_tokens = tokenizer.tokenize(word)
57 |
58 | if len(word_tokens) > 0:
59 | tok_to_orig_index.append(i)
60 | tokens.extend(word_tokens)
61 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens
62 | if has_label:
63 | label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))
64 |
65 | special_tokens_count = tokenizer.num_special_tokens
66 | if len(tokens) > max_length - special_tokens_count:
67 | tokens = tokens[: (max_length - special_tokens_count)]
68 | label_ids = label_ids[: (max_length - special_tokens_count)]
69 |
70 | tokens += [tokenizer.sep_token]
71 |
72 | if has_label:
73 | label_ids += [pad_token_label_id]
74 |
75 | token_type_ids = [0] * len(tokens)
76 |
77 | tokens = [tokenizer.cls_token] + tokens
78 | if has_label:
79 | label_ids = [pad_token_label_id] + label_ids
80 | token_type_ids = [0] + token_type_ids
81 |
82 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
83 |
84 | attention_mask = [1] * len(input_ids)
85 |
86 | # Zero-pad up to the sequence length.
87 | padding_length = max_length - len(input_ids)
88 |
89 | input_ids += [tokenizer.pad_token_id] * padding_length
90 | attention_mask += [0] * padding_length
91 | token_type_ids += [tokenizer.pad_token_type_id] * padding_length
92 | if has_label:
93 | label_ids += [pad_token_label_id] * padding_length
94 |
95 | if isinstance(tokenizer, GlyceBertTokenizer):
96 | pinyin_ids = tokenizer.convert_token_ids_to_pinyin_ids(input_ids)
97 | else:
98 | pinyin_ids = None
99 |
100 | assert len(input_ids) == max_length
101 | assert len(attention_mask) == max_length
102 | assert len(token_type_ids) == max_length
103 | if has_label:
104 | assert len(label_ids) == max_length
105 |
106 | return InputFeature(
107 | guid=str(0),
108 | input_ids=input_ids,
109 | attention_mask=attention_mask,
110 | token_type_ids=token_type_ids,
111 | pinyin_ids=pinyin_ids,
112 | label_ids=label_ids if has_label else None,
113 | tok_to_orig_index=tok_to_orig_index
114 | )
115 |
116 |
117 | def convert_example_to_feature_init(tokenizer_for_convert):
118 | global tokenizer
119 | tokenizer = tokenizer_for_convert
120 |
121 |
122 | def convert_examples_to_features(
123 | examples: List[InputExample],
124 | tokenizer,
125 | max_length=512,
126 | label_list=None,
127 | set_type='train',
128 | pad_token_label_id=0,
129 | threads=1
130 | ) -> List[InputFeature]:
131 | label_map = None
132 | if label_list is not None:
133 | label_map = {label: i for i, label in enumerate(label_list)}
134 | annotate_ = partial(
135 | convert_example_to_feature,
136 | max_length=max_length,
137 | label_map=label_map,
138 | pad_token_label_id=pad_token_label_id
139 | )
140 |
141 | if threads > 1:
142 | features = multiple_convert_examples_to_features(
143 | examples,
144 | annotate_=annotate_,
145 | initializer=convert_example_to_feature_init,
146 | initargs=(tokenizer,),
147 | threads=threads
148 | )
149 | else:
150 | convert_example_to_feature_init(tokenizer)
151 | features = single_example_to_features(
152 | examples, annotate_=annotate_
153 | )
154 | new_features = []
155 | # use enumerate so each feature gets a unique sequential index
156 | for i, feature in enumerate(features):
157 | feature.guid = set_type + '-' + str(i)
158 | new_features.append(feature)
159 | return new_features
160 |
--------------------------------------------------------------------------------
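For NER, an `InputExample` keeps parallel `words`/`tags` lists; during conversion only the first sub-token of each word keeps the real label (the rest get `pad_token_label_id`), and `tok_to_orig_index` records the token-to-word alignment for decoding. A minimal sketch with a placeholder vocab directory:

```python
from tfbert import BertTokenizer
from tfbert.data.ner import InputExample, convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert_base_zh", do_lower_case=True)  # placeholder dir
examples = [InputExample(guid=0, words=["我", "在", "北", "京"], tags=["O", "O", "B-LOC", "I-LOC"])]
features = convert_examples_to_features(
    examples, tokenizer,
    max_length=16,
    label_list=["O", "B-LOC", "I-LOC"],
    set_type="train",
    pad_token_label_id=0,
)
print(features[0].label_ids)          # sub-token aligned labels; CLS/SEP/padding fall back to 0
print(features[0].tok_to_orig_index)  # token-to-word alignment, here [0, 1, 2, 3]
```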
/tfbert/metric/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Project Name: my_project
4 | File Name: __init__.py
5 | date: 2020/9/13
6 | author: HuangHui
7 | """
8 |
--------------------------------------------------------------------------------
/tfbert/metric/dureader2021.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # __author__ = 'huanghui'
3 | # __date__ = '2021/5/16 22:16'
4 | # __project__ = 'tfbert'
5 |
6 | """
7 | dureader2021 的评估函数
8 | """
9 | from __future__ import print_function
10 | from collections import OrderedDict
11 | import io
12 | import json
13 | import six
14 | import sys
15 |
16 | if six.PY2:
17 | reload(sys)
18 | sys.setdefaultencoding('utf8')
19 | import argparse
20 |
21 |
22 | def _tokenize_chinese_chars(text):
23 | """
24 | :param text: input text, unicode string
25 | :return:
26 | tokenized text, list
27 | """
28 |
29 | def _is_chinese_char(cp):
30 | """Checks whether CP is the codepoint of a CJK character."""
31 | # This defines a "chinese character" as anything in the CJK Unicode block:
32 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
33 | #
34 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
35 | # despite its name. The modern Korean Hangul alphabet is a different block,
36 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
37 | # space-separated words, so they are not treated specially and handled
38 | # like the all of the other languages.
39 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
40 | (cp >= 0x3400 and cp <= 0x4DBF) or #
41 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
42 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
43 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
44 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
45 | (cp >= 0xF900 and cp <= 0xFAFF) or #
46 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
47 | return True
48 |
49 | return False
50 |
51 | output = []
52 | buff = ""
53 | for char in text:
54 | cp = ord(char)
55 | if _is_chinese_char(cp) or char == "=":
56 | if buff != "":
57 | output.append(buff)
58 | buff = ""
59 | output.append(char)
60 | else:
61 | buff += char
62 |
63 | if buff != "":
64 | output.append(buff)
65 |
66 | return output
67 |
68 |
69 | def _normalize(in_str):
70 | """
71 | normalize the input unicode string
72 | """
73 | in_str = in_str.lower()
74 | sp_char = [
75 | u':', u'_', u'`', u',', u'。', u':', u'?', u'!', u'(', u')',
76 | u'“', u'”', u';', u'’', u'《', u'》', u'……', u'·', u'、', u',',
77 | u'「', u'」', u'(', u')', u'-', u'~', u'『', u'』', '|'
78 | ]
79 | out_segs = []
80 | for char in in_str:
81 | if char in sp_char:
82 | continue
83 | else:
84 | out_segs.append(char)
85 | return ''.join(out_segs)
86 |
87 |
88 | def find_lcs(s1, s2):
89 | """find the longest common subsequence between s1 ans s2"""
90 | m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
91 | max_len = 0
92 | p = 0
93 | for i in range(len(s1)):
94 | for j in range(len(s2)):
95 | if s1[i] == s2[j]:
96 | m[i + 1][j + 1] = m[i][j] + 1
97 | if m[i + 1][j + 1] > max_len:
98 | max_len = m[i + 1][j + 1]
99 | p = i + 1
100 | return s1[p - max_len:p], max_len
101 |
102 |
103 | def evaluate(ref_ans, pred_ans, verbose=False):
104 | """
105 | ref_ans: reference answers, dict
106 | pred_ans: predicted answer, dict
107 | return:
108 | f1_score: averaged F1 score
109 | em_score: averaged EM score
110 | total_count: number of samples in the reference dataset
111 | skip_count: number of samples skipped in the calculation due to unknown errors
112 | """
113 | f1 = 0
114 | em = 0
115 | total_count = 0
116 | skip_count = 0
117 | for query_id, sample in ref_ans.items():
118 | total_count += 1
119 | para = sample['para']
120 | query_text = sample['question']
121 | title = sample['title']
122 | answers = sample['answers']
123 | is_impossible = sample['is_impossible']
124 | try:
125 | prediction = pred_ans[str(query_id)]
126 | except:
127 | skip_count += 1
128 | if verbose:
129 | print("para: {}".format(para))
130 | print("query: {}".format(query_text))
131 | print("ref: {}".format('#'.join(answers)))
132 | print("Skipped")
133 | print('----------------------------')
134 | continue
135 | if is_impossible:
136 | if prediction.lower() == 'no answer':
137 | _f1 = 1.0
138 | _em = 1.0
139 | else:
140 | _f1 = 0.0
141 | _em = 0.0
142 | else:
143 | _f1 = calc_f1_score(answers, prediction)
144 | _em = calc_em_score(answers, prediction)
145 | f1 += _f1
146 | em += _em
147 | if verbose:
148 | print("para: {}".format(para))
149 | print("query: {}".format(query_text))
150 | print("title: {}".format(title))
151 | print("ref: {}".format('#'.join(answers)))
152 | print("cand: {}".format(prediction))
153 | print("score: {}".format(_f1))
154 | print('----------------------------')
155 |
156 | f1_score = 100.0 * f1 / total_count
157 | em_score = 100.0 * em / total_count
158 | return f1_score, em_score, total_count, skip_count
159 |
160 |
161 | def calc_f1_score(answers, prediction):
162 | f1_scores = []
163 | for ans in answers:
164 | ans_segs = _tokenize_chinese_chars(_normalize(ans))
165 | prediction_segs = _tokenize_chinese_chars(_normalize(prediction))
166 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
167 | if lcs_len == 0:
168 | f1_scores.append(0)
169 | continue
170 | prec = 1.0 * lcs_len / len(prediction_segs)
171 | rec = 1.0 * lcs_len / len(ans_segs)
172 | f1 = (2 * prec * rec) / (prec + rec)
173 | f1_scores.append(f1)
174 | return max(f1_scores)
175 |
176 |
177 | def calc_em_score(answers, prediction):
178 | em = 0
179 | for ans in answers:
180 | ans_ = _normalize(ans)
181 | prediction_ = _normalize(prediction)
182 | if ans_ == prediction_:
183 | em = 1
184 | break
185 | return em
186 |
187 |
188 | def read_mrc_dataset(filename, tag=None):
189 | dataset = OrderedDict()
190 | with io.open(filename, encoding='utf-8') as fin:
191 | mrc_dataset = json.load(fin)
192 | for document in mrc_dataset['data']:
193 | for paragraph in document['paragraphs']:
194 | para = paragraph['context'].strip()
195 | title = ''
196 | if 'title' in paragraph:
197 | title = paragraph['title']
198 | for qa in (paragraph['qas']):
199 | query_id = qa['id']
200 | query_text = qa['question'].strip()
201 | answers = [a['text'] for a in qa['answers']]
202 | if tag is not None:
203 | if not qa['type'].startswith(tag):
204 | continue
205 | is_impossible = False
206 | if 'is_impossible' in qa:
207 | is_impossible = qa['is_impossible']
208 | if is_impossible:
209 | answers = ['no answer']
210 | dataset[query_id] = {
211 | 'answers': answers,
212 | 'question': query_text,
213 | 'para': para,
214 | 'is_impossible': is_impossible,
215 | 'title': title
216 | }
217 | return dataset
218 |
219 |
220 | def read_model_prediction(filename):
221 | with io.open(filename, encoding='utf-8') as fin:
222 | model_prediction = json.load(fin)
223 | return model_prediction
224 |
225 |
226 | def metric(predictions, gold_file: str, dict_report=False):
227 | """
228 | dureader使用的评估函数
229 | :param predictions: 预测结果(字典)或者预测文件地址
230 | :param gold_file: 标准答案数据文件
231 | :param dict_report: 是否返回字典形式结果
232 | :return:
233 | """
234 | ref_ans = read_mrc_dataset(gold_file, tag=None)
235 | if isinstance(predictions, str):
236 | pred_ans = json.load(io.open(predictions, encoding='utf-8'))
237 | elif isinstance(predictions, dict):
238 | pred_ans = predictions
239 | else:
240 | raise ValueError("Please input the file name or prediction result")
241 |
242 | F1, EM, TOTAL, SKIP = evaluate(ref_ans, pred_ans, False)
243 | output_result = OrderedDict()
244 | output_result['F1'] = F1
245 | output_result['EM'] = EM
246 | output_result['TOTAL'] = TOTAL
247 | output_result['SKIP'] = SKIP
248 | report = json.dumps(output_result, ensure_ascii=False, indent=4)
249 | if dict_report:
250 | return report, output_result
251 | return report
252 |
--------------------------------------------------------------------------------
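`metric` accepts either a prediction dict (question id to answer string, with the literal `'no answer'` for unanswerable questions) or the path to a JSON file of the same shape, plus the DuReader-style gold file, and reports F1/EM. A sketch with placeholder ids and paths:

```python
from tfbert.metric.dureader2021 import metric

predictions = {
    "q_0001": "北京",        # answerable question
    "q_0002": "no answer",   # predicted as unanswerable
}
report, result = metric(predictions, gold_file="dev.json", dict_report=True)  # dev.json is a placeholder
print(report)                     # JSON string with F1 / EM / TOTAL / SKIP
print(result["F1"], result["EM"])
```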
/tfbert/metric/multi_label.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: multi_label.py
5 | @date: 2020/09/15
6 | """
7 |
8 |
9 | def multi_label_metric(y_true, y_pred, label_list, dict_report=False):
10 | '''
11 | 多标签文本分类的评估函数
12 | :param y_true: 正确标签, one hot 类型,[[0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]
13 | :param y_pred: 预测标签,one hot 类型,同上
14 | :param label_list: 标签列表,顺序对应one hot 位置
15 | :param dict_report:
16 | :return:
17 | '''
18 |
19 | def get_value(res):
20 | if res["TP"] == 0:
21 | if res["FP"] == 0 and res["FN"] == 0:
22 | precision = 1.0
23 | recall = 1.0
24 | f1 = 1.0
25 | else:
26 | precision = 0.0
27 | recall = 0.0
28 | f1 = 0.0
29 | else:
30 | precision = 1.0 * res["TP"] / (res["TP"] + res["FP"])
31 | recall = 1.0 * res["TP"] / (res["TP"] + res["FN"])
32 | f1 = 2 * precision * recall / (precision + recall)
33 | return precision, recall, f1
34 |
35 | result = {}
36 | for i in range(len(label_list)):
37 | result[i] = {"TP": 0, "FP": 0, "TN": 0, "FN": 0}
38 |
39 | for true, pred in zip(y_true, y_pred):
40 | for a in range(len(label_list)):
41 | in1 = pred[a] == 1
42 | in2 = true[a] == 1
43 | if in1:
44 | if in2:
45 | result[a]["TP"] += 1
46 | else:
47 | result[a]["FP"] += 1
48 | else:
49 | if in2:
50 | result[a]["FN"] += 1
51 | else:
52 | result[a]["TN"] += 1
53 |
54 | final_result = {}
55 |
56 | # 格式化输出
57 | headers = ["precision", "recall", "f1-score"]
58 | target_names = ['%s' % l for l in label_list]
59 | head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers)
60 | longest_last_line_heading = 'micro macro avg'
61 | name_width = max(len(cn) for cn in target_names)
62 | width = max(name_width, len(longest_last_line_heading), 4)
63 | report = head_fmt.format('', *headers, width=width)
64 | report += '\n\n'
65 |
66 | y = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
67 | sumf = 0
68 | sump = 0
69 | sumr = 0
70 | row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + '\n'
71 | for i, label in enumerate(label_list):
72 | p, r, f = get_value(result[i])
73 | final_result[label] = {"precision": p, "recall": r, "f1-score": f} # 每个类别下的p r f值
74 | report += row_fmt.format(*[label, p, r, f], width=width, digits=4) # 每个类别的string 输出
75 |
76 | sumf += f
77 | sump += p
78 | sumr += r
79 | for z in result[i].keys():
80 | y[z] += result[i][z] # 累积总的tp、fp、fn、tn,计算微平均
81 |
82 | report += '\n'
83 |
84 | micro_p, micro_r, micro_f = get_value(y)
85 | macro_p = sump * 1.0 / len(result)
86 | macro_r = sumr * 1.0 / len(result)
87 | macro_f = sumf * 1.0 / len(result)
88 |
89 | average_micro_macro_f = (macro_f + micro_f) / 2.0 # 这是法研杯要素抽取的评价指标,两者平均
90 | average_micro_macro_p = (macro_p + micro_p) / 2.0
91 | average_micro_macro_r = (macro_r + micro_r) / 2.0
92 |
93 | final_result['macro avg'] = {"precision": macro_p, "recall": macro_r, "f1-score": macro_f}
94 | report += row_fmt.format(*['macro avg', macro_p, macro_r, macro_f], width=width, digits=4)
95 |
96 | final_result['micro avg'] = {"precision": micro_p, "recall": micro_r, "f1-score": micro_f}
97 | report += row_fmt.format(*['micro avg', micro_p, micro_r, micro_f], width=width, digits=4)
98 |
99 | final_result['micro macro avg'] = {"precision": average_micro_macro_p,
100 | "recall": average_micro_macro_r,
101 | "f1-score": average_micro_macro_f}
102 |
103 | report += row_fmt.format(*['micro macro avg', average_micro_macro_p,
104 | average_micro_macro_r, average_micro_macro_f], width=width, digits=4)
105 |
106 | if dict_report:
107 | return report, final_result
108 | return report
109 |
--------------------------------------------------------------------------------
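`multi_label_metric` expects one-hot rows for both `y_true` and `y_pred` and reports per-label P/R/F plus macro, micro, and their average (the 'micro macro avg' used for the 法研杯 element-extraction task). A tiny sketch:

```python
from tfbert.metric.multi_label import multi_label_metric

label_list = ["A", "B", "C"]
y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 0]]

report, result = multi_label_metric(y_true, y_pred, label_list, dict_report=True)
print(report)                     # formatted per-label / averaged table
print(result["micro macro avg"])  # dict with precision / recall / f1-score
```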
/tfbert/metric/ner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Project Name: my_project
4 | File Name: ner
5 | date: 2020/9/13
6 | author: HuangHui
7 | """
8 | import numpy as np
9 | from collections import defaultdict
10 |
11 |
12 | def get_entities(seq, suffix=False):
13 | """Gets entities from sequence.
14 |
15 | Args:
16 | seq (list): sequence of labels.
17 |
18 | Returns:
19 | list: list of (chunk_type, chunk_start, chunk_end).
20 |
21 | Example:
22 | seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
23 | get_entities(seq)
24 | [('PER', 0, 1), ('LOC', 3, 3)]
25 | """
26 | # for nested list
27 | if any(isinstance(s, list) for s in seq):
28 | seq = [item for sublist in seq for item in sublist + ['O']]
29 | prev_tag = 'O'
30 | prev_type = ''
31 | begin_offset = 0
32 | chunks = []
33 | for i, chunk in enumerate(seq + ['O']):
34 | if suffix:
35 | tag = chunk[-1]
36 | type_ = chunk.split('-')[0]
37 | else:
38 | tag = chunk[0]
39 | type_ = chunk.split('-')[-1]
40 |
41 | if end_of_chunk(prev_tag, tag, prev_type, type_):
42 | chunks.append((prev_type, begin_offset, i - 1))
43 | if start_of_chunk(prev_tag, tag, prev_type, type_):
44 | begin_offset = i
45 | prev_tag = tag
46 | prev_type = type_
47 |
48 | return chunks
49 |
50 |
51 | def end_of_chunk(prev_tag, tag, prev_type, type_):
52 | """Checks if a chunk ended between the previous and current word.
53 |
54 | Args:
55 | prev_tag: previous chunk tag.
56 | tag: current chunk tag.
57 | prev_type: previous type.
58 | type_: current type.
59 |
60 | Returns:
61 | chunk_end: boolean.
62 | """
63 | chunk_end = False
64 |
65 | if prev_tag == 'E': chunk_end = True
66 | if prev_tag == 'S': chunk_end = True
67 |
68 | if prev_tag == 'B' and tag == 'B': chunk_end = True
69 | if prev_tag == 'B' and tag == 'S': chunk_end = True
70 | if prev_tag == 'B' and tag == 'O': chunk_end = True
71 | if prev_tag == 'I' and tag == 'B': chunk_end = True
72 | if prev_tag == 'I' and tag == 'S': chunk_end = True
73 | if prev_tag == 'I' and tag == 'O': chunk_end = True
74 |
75 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
76 | chunk_end = True
77 |
78 | return chunk_end
79 |
80 |
81 | def start_of_chunk(prev_tag, tag, prev_type, type_):
82 | """Checks if a chunk started between the previous and current word.
83 |
84 | Args:
85 | prev_tag: previous chunk tag.
86 | tag: current chunk tag.
87 | prev_type: previous type.
88 | type_: current type.
89 |
90 | Returns:
91 | chunk_start: boolean.
92 | """
93 | chunk_start = False
94 |
95 | if tag == 'B': chunk_start = True
96 | if tag == 'S': chunk_start = True
97 |
98 | if prev_tag == 'E' and tag == 'E': chunk_start = True
99 | if prev_tag == 'E' and tag == 'I': chunk_start = True
100 | if prev_tag == 'S' and tag == 'E': chunk_start = True
101 | if prev_tag == 'S' and tag == 'I': chunk_start = True
102 | if prev_tag == 'O' and tag == 'E': chunk_start = True
103 | if prev_tag == 'O' and tag == 'I': chunk_start = True
104 |
105 | if tag != 'O' and tag != '.' and prev_type != type_:
106 | chunk_start = True
107 |
108 | return chunk_start
109 |
110 |
111 | def prf_score(y_true, y_pred, suffix=False):
112 | true_entities = set(get_entities(y_true, suffix))
113 | pred_entities = set(get_entities(y_pred, suffix))
114 |
115 | nb_correct = len(true_entities & pred_entities)
116 | nb_pred = len(pred_entities)
117 | nb_true = len(true_entities)
118 |
119 | p = 100 * nb_correct / nb_pred if nb_pred > 0 else 0
120 | r = 100 * nb_correct / nb_true if nb_true > 0 else 0
121 | score = 2 * p * r / (p + r) if p + r > 0 else 0
122 |
123 | return p, r, score
124 |
125 |
126 | def accuracy_score(y_true, y_pred):
127 | """Accuracy classification score.
128 |
129 | In multilabel classification, this function computes subset accuracy:
130 | the set of labels predicted for a sample must *exactly* match the
131 | corresponding set of labels in y_true.
132 |
133 | Args:
134 | y_true : 2d array. Ground truth (correct) target values.
135 | y_pred : 2d array. Estimated targets as returned by a tagger.
136 |
137 | Returns:
138 | score : float.
139 |
140 | Example:
141 | >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
142 | >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
143 | >>> accuracy_score(y_true, y_pred)
144 | 0.80
145 | """
146 | if any(isinstance(s, list) for s in y_true):
147 | y_true = [item for sublist in y_true for item in sublist]
148 | y_pred = [item for sublist in y_pred for item in sublist]
149 |
150 | nb_correct = sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred))
151 | nb_true = len(y_true)
152 |
153 | score = nb_correct / nb_true
154 |
155 | return score
156 |
157 |
158 | def ner_metric(y_true, y_pred, digits=4, dict_report=False):
159 | """Build a text report showing the main classification metrics.
160 |
161 | Args:
162 | y_true : 2d array. Ground truth (correct) target values.
163 | y_pred : 2d array. Estimated targets as returned by a classifier.
164 | digits : int. Number of digits for formatting output floating point values.
165 |         dict_report : bool. If True, also return the metrics as a dict in addition to the report string.
166 | Returns:
167 | report : string. Text summary of the precision, recall, F1 score for each class.
168 |
169 | Examples:
170 | >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
171 | >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
172 | >>> print(ner_metric(y_true, y_pred))
173 | precision recall f1-score support
174 |
175 | MISC 0.00 0.00 0.00 1
176 | PER 1.00 1.00 1.00 1
177 |
178 | micro avg 0.50 0.50 0.50 2
179 |         macro avg       0.50      0.50      0.50         2
180 | weighted avg 0.50 0.50 0.50 2
181 |
182 | """
183 | true_entities = set(get_entities(y_true))
184 | pred_entities = set(get_entities(y_pred))
185 |
186 | name_width = 0
187 | d1 = defaultdict(set)
188 | d2 = defaultdict(set)
189 | for e in true_entities:
190 | d1[e[0]].add((e[1], e[2]))
191 | name_width = max(name_width, len(e[0]))
192 | for e in pred_entities:
193 | d2[e[0]].add((e[1], e[2]))
194 |
195 | last_longest_line_heading = 'weighted avg'
196 | width = max(name_width, len(last_longest_line_heading), digits)
197 |
198 | headers = ["precision", "recall", "f1-score", "support"]
199 | head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
200 | report = head_fmt.format(u'', *headers, width=width)
201 | report += u'\n\n'
202 |
203 | row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
204 |
205 | ps, rs, f1s, s = [], [], [], []
206 | dict_result = {}
207 | all_correct = 0
208 | all_pred = 0
209 | all_true = 0
210 | for type_name, true_entities in d1.items():
211 | pred_entities = d2[type_name]
212 | nb_correct = len(true_entities & pred_entities)
213 | nb_pred = len(pred_entities)
214 | nb_true = len(true_entities)
215 |
216 | all_correct += nb_correct
217 | all_pred += nb_pred
218 | all_true += nb_true
219 |
220 | p = 1.0 * nb_correct / nb_pred if nb_pred > 0 else 0
221 | r = 1.0 * nb_correct / nb_true if nb_true > 0 else 0
222 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0
223 |
224 | report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)
225 | dict_result[type_name] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
226 |
227 | ps.append(p)
228 | rs.append(r)
229 | f1s.append(f1)
230 | s.append(nb_true)
231 |
232 | report += u'\n'
233 |
234 |     # compute averages (micro: pooled over all entity spans; macro: unweighted mean over types)
235 |     micro_p = 1.0 * all_correct / all_pred if all_pred > 0 else 0
236 |     micro_r = 1.0 * all_correct / all_true if all_true > 0 else 0
237 |     micro_f = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r > 0 else 0
238 | 
239 |     macro_p = float(np.average(ps))
240 |     macro_r = float(np.average(rs))
241 |     macro_f = float(np.average(f1s))
242 |
243 | weighted_p = float(np.average(ps, weights=s))
244 | weighted_r = float(np.average(rs, weights=s))
245 | weighted_f = float(np.average(f1s, weights=s))
246 |
247 | avg_support = float(np.sum(s))
248 |
249 | avg_reports = [["micro avg", micro_p, micro_r, micro_f, avg_support],
250 | ["macro avg", macro_p, macro_r, macro_f, avg_support],
251 | ["weighted avg", weighted_p, weighted_r, weighted_f, avg_support]]
252 |
253 | for i in range(len(avg_reports)):
254 | report += row_fmt.format(*avg_reports[i],
255 | width=width, digits=digits)
256 | dict_result[avg_reports[i][0]] = {'precision': avg_reports[i][1], 'recall': avg_reports[i][2],
257 | 'f1-score': avg_reports[i][3], 'support': avg_reports[i][4]}
258 |
259 | if dict_report:
260 | return report, dict_result
261 | return report
262 |
--------------------------------------------------------------------------------
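A minimal usage sketch for the NER metrics above (not part of the repository; the label sequences are made-up BIO tags, and the printed values follow directly from the functions shown):

```python
from tfbert.metric.ner import get_entities, ner_metric

y_true = [['O', 'B-PER', 'I-PER', 'O', 'B-LOC'], ['B-ORG', 'I-ORG', 'O']]
y_pred = [['O', 'B-PER', 'I-PER', 'O', 'O'], ['B-ORG', 'I-ORG', 'O']]

# get_entities turns a tag sequence into (type, start, end) spans
print(get_entities(['B-PER', 'I-PER', 'O', 'B-LOC']))  # [('PER', 0, 1), ('LOC', 3, 3)]

# ner_metric builds the per-type report; dict_report=True additionally returns a dict
report, result = ner_metric(y_true, y_pred, digits=4, dict_report=True)
print(report)
print(result['PER'])  # {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 1}
```
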
/tfbert/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName  :__init__.py
3 | # @Time :2021/1/31 15:20
4 | # @Author :huanghui
5 |
6 | from .bert import BertModel
7 | from .bert import BertModel as WoBertModel
8 | from .albert import ALBertModel
9 | from .electra import ElectraModel
10 | from .nezha import NezhaModel
11 | from .glyce_bert import GlyceBertModel
12 | from .model_utils import (
13 | dropout, layer_norm_and_dropout, layer_norm,
14 | create_weight, get_shape_list, gather_indexes, create_initializer)
15 |
16 | from .embeddings import (create_word_embeddings, create_position_embeddings, create_token_type_embeddings)
17 | from . import crf
18 |
19 |
20 | MODELS = {
21 | 'bert': BertModel,
22 | 'albert': ALBertModel,
23 | 'electra': ElectraModel,
24 | 'wobert': WoBertModel,
25 | 'nezha': NezhaModel,
26 | 'glyce_bert': GlyceBertModel
27 | }
28 |
29 | from .for_task import (
30 | SequenceClassification, TokenClassification,
31 | MultiLabelClassification, MaskedLM, PretrainingLM,
32 | QuestionAnswering)
33 |
34 |
--------------------------------------------------------------------------------
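The `MODELS` dict above works as a small registry, so task code can pick a backbone class by name. A rough sketch (the config path and shapes are placeholders, not from the repository); note that `glyce_bert` additionally expects `pinyin_ids`:

```python
import tensorflow.compat.v1 as tf
from tfbert import BertConfig
from tfbert.models import MODELS

config = BertConfig.from_pretrained('path/to/config_dir')  # placeholder path
input_ids = tf.placeholder(tf.int64, [None, 128], name='input_ids')

model_class = MODELS['bert']  # same call signature for 'nezha', 'electra', 'wobert', ...
model = model_class(config=config, is_training=True, input_ids=input_ids)

sequence_output = model.get_sequence_output()  # [batch, seq_len, hidden_size]
pooled_output = model.get_pooled_output()      # [batch, hidden_size]
```
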
/tfbert/models/activations.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :activations.py
3 | # @Time :2021/1/31 15:34
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | import numpy as np
7 | import six
8 |
9 |
10 | def gelu(x):
11 | """Gaussian Error Linear Unit.
12 |
13 | This is a smoother version of the RELU.
14 | Original paper: https://arxiv.org/abs/1606.08415
15 | Args:
16 | x: float Tensor to perform activation.
17 |
18 | Returns:
19 | `x` with the GELU activation applied.
20 | """
21 | cdf = 0.5 * (1.0 + tf.tanh(
22 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
23 | return x * cdf
24 |
25 |
26 | def get_activation(activation_string):
27 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
28 |
29 | Args:
30 | activation_string: String name of the activation function.
31 |
32 | Returns:
33 | A Python function corresponding to the activation function. If
34 | `activation_string` is None, empty, or "linear", this will return None.
35 | If `activation_string` is not a string, it will return `activation_string`.
36 |
37 | Raises:
38 | ValueError: The `activation_string` does not correspond to a known
39 | activation.
40 | """
41 |
42 |     # We assume that anything that's not a string is already an activation
43 | # function, so we just return it.
44 | if not isinstance(activation_string, six.string_types):
45 | return activation_string
46 |
47 | if not activation_string:
48 | return None
49 |
50 | act = activation_string.lower()
51 | if act == "linear":
52 | return None
53 | elif act == "relu":
54 | return tf.nn.relu
55 | elif act == "gelu":
56 | return gelu
57 | elif act == "tanh":
58 | return tf.tanh
59 | else:
60 | raise ValueError("Unsupported activation: %s" % act)
61 |
--------------------------------------------------------------------------------
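A quick sketch of resolving an activation from a config's `hidden_act` string (the constants are illustrative only):

```python
import tensorflow.compat.v1 as tf
from tfbert.models.activations import get_activation, gelu

act_fn = get_activation("gelu")          # resolves to the gelu function defined above
assert act_fn is gelu
assert get_activation("linear") is None  # "linear" / empty string mean no activation

x = tf.constant([-1.0, 0.0, 1.0])
y = act_fn(x)  # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
```
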
/tfbert/models/albert.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :albert.py
3 | # @Time :2021/1/31 19:34
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from .bert import bert_embedding as albert_embedding
7 | from .base import BaseModel
8 | from . import model_utils, layers, activations
9 | from ..config import ALBertConfig
10 |
11 |
12 | def albert_layer(input_tensor, attention_mask, config: ALBertConfig):
13 | with tf.variable_scope("attention_1"):
14 | with tf.variable_scope("self"):
15 | attention_layer_outputs = layers.attention(
16 | input_tensor,
17 | attention_mask=attention_mask,
18 | hidden_size=config.hidden_size,
19 | num_attention_heads=config.num_attention_heads,
20 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
21 | initializer_range=config.initializer_range,
22 |                 use_relative_position=False,  # relative position attention is only used by NeZha
23 | do_return_attentions_probs=config.output_attentions
24 | )
25 |
26 | attention_output = attention_layer_outputs[0]
27 |
28 | with tf.variable_scope("output"):
29 | attention_output = layers.attention_output_layer(
30 | input_tensor=attention_output,
31 | hidden_size=config.hidden_size,
32 | initializer_range=config.initializer_range,
33 | hidden_dropout_prob=config.hidden_dropout_prob
34 | )
35 |
36 |         # ALBERT's layer norm lives in a different variable scope than BERT's
37 | attention_output = model_utils.layer_norm(attention_output + input_tensor)
38 | with tf.variable_scope("ffn_1"):
39 | with tf.variable_scope("intermediate"):
40 | intermediate_output = layers.intermediate_layer(
41 | attention_output,
42 | config.intermediate_size,
43 | activations.get_activation(config.hidden_act),
44 | config.initializer_range
45 | )
46 |
47 | with tf.variable_scope("output"):
48 | layer_output = layers.attention_output_layer(
49 | input_tensor=intermediate_output,
50 | hidden_size=config.hidden_size,
51 | initializer_range=config.initializer_range,
52 | hidden_dropout_prob=config.hidden_dropout_prob
53 | )
54 |
55 |         # ALBERT's layer norm lives in a different variable scope than BERT's
56 | layer_output = model_utils.layer_norm(layer_output + attention_output)
57 | if config.output_attentions:
58 | outputs = (layer_output, attention_layer_outputs[1])
59 | else:
60 | outputs = (layer_output,)
61 | return outputs
62 |
63 |
64 | def albert_encoder(input_tensor, attention_mask, config: ALBertConfig):
65 | # The Transformer performs sum residuals on all layers so the input needs
66 | # to be the same as the hidden size.
67 | if input_tensor.shape[-1] != config.hidden_size:
68 | input_tensor = layers.dense(
69 | input_tensor, config.hidden_size,
70 | name="embedding_hidden_mapping_in",
71 | initializer_range=config.initializer_range
72 | )
73 |
74 | all_layer_outputs = []
75 | all_layer_attention_probs = []
76 | prev_output = input_tensor
77 |
78 |     # ALBERT shares transformer parameters across layers, so the scope must be reused
79 | with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
80 | for layer_idx in range(config.num_hidden_layers):
81 | group_idx = int(layer_idx / config.num_hidden_layers * config.num_hidden_groups)
82 |
83 | with tf.variable_scope("group_%d" % group_idx):
84 | with tf.name_scope("layer_%d" % layer_idx):
85 | for inner_group_idx in range(config.inner_group_num):
86 | with tf.variable_scope("inner_group_%d" % inner_group_idx):
87 | layer_output = albert_layer(prev_output, attention_mask, config)
88 | prev_output = layer_output[0]
89 | if config.output_hidden_states:
90 | all_layer_outputs.append(layer_output[0])
91 | if config.output_attentions:
92 | all_layer_attention_probs.append(layer_output[1])
93 |
94 | outputs = (prev_output,)
95 |
96 | if config.output_hidden_states:
97 | outputs = outputs + (all_layer_outputs,)
98 |
99 | if config.output_attentions:
100 | outputs = outputs + (all_layer_attention_probs,)
101 | return outputs # (last layer output, all layer outputs, all layer att probs)
102 |
103 |
104 | class ALBertModel(BaseModel):
105 | def __init__(
106 | self,
107 | config,
108 | is_training,
109 | input_ids,
110 | attention_mask=None,
111 | token_type_ids=None,
112 | return_pool=True,
113 | scope=None,
114 | reuse=False,
115 | compute_type=tf.float32):
116 | super().__init__(config, is_training)
117 |
118 | input_shape = model_utils.get_shape_list(input_ids)
119 | batch_size = input_shape[0]
120 | seq_length = input_shape[1]
121 |
122 | if attention_mask is None:
123 | attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
124 |
125 | if token_type_ids is None:
126 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
127 |
128 | with tf.variable_scope(
129 | scope, default_name="bert",
130 | reuse=tf.AUTO_REUSE if reuse else None,
131 | custom_getter=model_utils.get_custom_getter(compute_type)):
132 | with tf.variable_scope("embeddings"):
133 | self.embedding_output, self.embedding_table = albert_embedding(
134 | config=self.config,
135 | input_ids=input_ids,
136 | token_type_ids=token_type_ids,
137 | )
138 |
139 | with tf.variable_scope("encoder"):
140 | attention_mask = model_utils.create_bert_mask(
141 | input_ids, attention_mask)
142 | encoder_outputs = albert_encoder(
143 | config=self.config,
144 | input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
145 | attention_mask=attention_mask
146 | )
147 | if return_pool:
148 | with tf.variable_scope("pooler"):
149 | pooled_output = layers.pooler_layer(
150 | sequence_output=encoder_outputs[0],
151 | hidden_size=self.config.hidden_size,
152 | initializer_range=self.config.initializer_range
153 | )
154 | else:
155 | pooled_output = None
156 | # (pooled output, sequence output, all layer outputs, all layer att probs)
157 | self.outputs = (pooled_output,) + encoder_outputs
158 |
--------------------------------------------------------------------------------
/tfbert/models/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName  :base.py
3 | # @Time :2021/1/31 17:55
4 | # @Author :huanghui
5 | import copy
6 |
7 |
8 | class BaseModel(object):
9 | def __init__(self, config, is_training):
10 |
11 | self.config = copy.deepcopy(config)
12 | if not is_training:
13 | self.config.hidden_dropout_prob = 0.0
14 | self.config.attention_probs_dropout_prob = 0.0
15 |
16 | self.outputs = ()
17 | self.embedding_output = None
18 | self.embedding_table = None
19 |
20 | def get_pooled_output(self):
21 | return self.outputs[0]
22 |
23 | def get_sequence_output(self):
24 | return self.outputs[1]
25 |
26 | def get_outputs(self):
27 | return self.outputs
28 |
29 | def get_all_encoder_layers(self):
30 | if self.config.output_hidden_states:
31 | return self.outputs[2]
32 | else:
33 | raise ValueError('Please set {} with value {} for config'.format('output_hidden_states', 'True'))
34 |
35 | def get_all_attention_probs(self):
36 | if self.config.output_attentions:
37 | return self.outputs[3]
38 | else:
39 | raise ValueError('Please set {} with value {} for config'.format('output_attentions', 'True'))
40 |
41 | def get_embedding_output(self):
42 | return self.embedding_output
43 |
44 | def get_embedding_table(self):
45 | return self.embedding_table
46 |
--------------------------------------------------------------------------------
/tfbert/models/bert.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :bert.py
3 | # @Time :2021/1/31 17:19
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from . import embeddings, layers, model_utils, activations
7 | from .base import BaseModel
8 |
9 |
10 | def bert_embedding(
11 | config,
12 | input_ids,
13 | token_type_ids=None,
14 | add_position_embedding=True
15 | ):
16 | (embedding_output, embedding_table) = embeddings.create_word_embeddings(
17 | input_ids=input_ids,
18 | vocab_size=config.vocab_size,
19 | embedding_size=config.embedding_size,
20 | initializer_range=config.initializer_range,
21 | word_embedding_name="word_embeddings"
22 | )
23 |
24 | token_type_embeddings = embeddings.create_token_type_embeddings(
25 | token_type_ids=token_type_ids,
26 | embedding_size=config.embedding_size,
27 | token_type_vocab_size=config.type_vocab_size,
28 | token_type_embedding_name='token_type_embeddings',
29 | initializer_range=config.initializer_range
30 | )
31 | embedding_output += token_type_embeddings
32 |
33 | if add_position_embedding:
34 | position_embeddings = embeddings.create_position_embeddings(
35 | seq_len=model_utils.get_shape_list(input_ids)[1],
36 | embedding_size=config.embedding_size,
37 | position_embedding_name='position_embeddings',
38 | initializer_range=config.initializer_range,
39 | max_position_embeddings=config.max_position_embeddings
40 | )
41 | embedding_output += position_embeddings
42 |
43 | embedding_output = model_utils.layer_norm_and_dropout(
44 | embedding_output,
45 | config.hidden_dropout_prob
46 | )
47 | return embedding_output, embedding_table
48 |
49 |
50 | def bert_layer(input_tensor, attention_mask, config, use_relative_position=False):
51 | with tf.variable_scope("attention"):
52 |         # multi-head self-attention layer
53 | with tf.variable_scope("self"):
54 | attention_layer_outputs = layers.attention(
55 | input_tensor,
56 | attention_mask=attention_mask,
57 | hidden_size=config.hidden_size,
58 | num_attention_heads=config.num_attention_heads,
59 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
60 | initializer_range=config.initializer_range,
61 |                 use_relative_position=use_relative_position,  # relative position attention, used by NeZha
62 | do_return_attentions_probs=config.output_attentions
63 | )
64 | attention_output = attention_layer_outputs[0]
65 |
66 | # Run a linear projection of `hidden_size` then add a residual
67 | # with `layer_input`.
68 |         # attention output layer
69 | with tf.variable_scope("output"):
70 | attention_output = layers.attention_output_layer(
71 | input_tensor=attention_output,
72 | hidden_size=config.hidden_size,
73 | initializer_range=config.initializer_range,
74 | hidden_dropout_prob=config.hidden_dropout_prob
75 | )
76 | attention_output = model_utils.layer_norm(attention_output + input_tensor)
77 |
78 |     # transformer intermediate (feed-forward) layer
79 | # The activation is only applied to the "intermediate" hidden layer.
80 | with tf.variable_scope("intermediate"):
81 | intermediate_output = layers.intermediate_layer(
82 | attention_output,
83 | config.intermediate_size,
84 | activations.get_activation(config.hidden_act),
85 | config.initializer_range
86 | )
87 |
88 | # Down-project back to `hidden_size` then add the residual.
89 |     # transformer output layer
90 | with tf.variable_scope("output"):
91 | layer_output = layers.attention_output_layer(
92 | input_tensor=intermediate_output,
93 | hidden_size=config.hidden_size,
94 | initializer_range=config.initializer_range,
95 | hidden_dropout_prob=config.hidden_dropout_prob
96 | )
97 | layer_output = model_utils.layer_norm(layer_output + attention_output)
98 | if config.output_attentions:
99 | outputs = (layer_output, attention_layer_outputs[1])
100 | else:
101 | outputs = (layer_output,)
102 | return outputs
103 |
104 |
105 | def bert_encoder(input_tensor, attention_mask,
106 | config, use_relative_position=False):
107 | # The Transformer performs sum residuals on all layers so the input needs
108 | # to be the same as the hidden size.
109 |
110 | all_layer_outputs = []
111 | all_layer_attention_probs = []
112 | prev_output = input_tensor
113 | for layer_idx in range(config.num_hidden_layers):
114 | with tf.variable_scope("layer_%d" % layer_idx):
115 |
116 | layer_output = bert_layer(
117 | prev_output, attention_mask, config,
118 | use_relative_position=use_relative_position
119 | )
120 | prev_output = layer_output[0]
121 |
122 | if config.output_hidden_states:
123 | all_layer_outputs.append(layer_output[0])
124 | if config.output_attentions:
125 | all_layer_attention_probs.append(layer_output[1])
126 |
127 | outputs = (prev_output,)
128 | if config.output_hidden_states:
129 | outputs = outputs + (all_layer_outputs,)
130 |
131 | if config.output_attentions:
132 | outputs = outputs + (all_layer_attention_probs,)
133 | return outputs # (last layer output, all layer outputs, all layer att probs)
134 |
135 |
136 | class BertModel(BaseModel):
137 | def __init__(
138 | self,
139 | config,
140 | is_training,
141 | input_ids,
142 | attention_mask=None,
143 | token_type_ids=None,
144 | return_pool=True,
145 | scope=None,
146 | reuse=False,
147 | compute_type=tf.float32
148 | ):
149 | super().__init__(config, is_training)
150 |
151 | input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
152 | batch_size = input_shape[0]
153 | seq_length = input_shape[1]
154 |
155 | if attention_mask is None:
156 | attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int64)
157 |
158 | if token_type_ids is None:
159 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int64)
160 |
161 | with tf.variable_scope(
162 | scope, default_name="bert",
163 | reuse=tf.AUTO_REUSE if reuse else None,
164 | custom_getter=model_utils.get_custom_getter(compute_type)):
165 | with tf.variable_scope("embeddings"):
166 | self.embedding_output, self.embedding_table = bert_embedding(
167 | config=self.config,
168 | input_ids=input_ids,
169 | token_type_ids=token_type_ids,
170 | add_position_embedding=True
171 | )
172 |
173 | with tf.variable_scope("encoder"):
174 | attention_mask = model_utils.create_bert_mask(
175 | input_ids, attention_mask)
176 | if model_utils.get_shape_list(self.embedding_output)[-1] != self.config.hidden_size:
177 | self.embedding_output = layers.dense(
178 | self.embedding_output, self.config.hidden_size,
179 | 'embedding_hidden_mapping_in', initializer_range=self.config.initializer_range
180 | )
181 | encoder_outputs = bert_encoder(
182 | input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
183 | attention_mask=attention_mask,
184 | config=self.config,
185 | use_relative_position=False
186 | )
187 | if return_pool:
188 | with tf.variable_scope("pooler"):
189 | pooled_output = layers.pooler_layer(
190 | sequence_output=encoder_outputs[0],
191 | hidden_size=self.config.hidden_size,
192 | initializer_range=self.config.initializer_range
193 | )
194 | else:
195 | pooled_output = None
196 | # (pooled output, sequence output, all layer outputs, all layer att probs)
197 | self.outputs = (pooled_output,) + encoder_outputs
198 |
--------------------------------------------------------------------------------
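A sketch of building the BERT graph directly (placeholder shapes and a hypothetical config path; it assumes `output_hidden_states` can simply be set on the loaded config object, which is what the encoder code above reads):

```python
import tensorflow.compat.v1 as tf
from tfbert import BertConfig
from tfbert.models import BertModel

config = BertConfig.from_pretrained('path/to/config_dir')  # hypothetical path
config.output_hidden_states = True  # needed for get_all_encoder_layers()

input_ids = tf.placeholder(tf.int64, [None, 128])
attention_mask = tf.placeholder(tf.int64, [None, 128])

model = BertModel(
    config=config,
    is_training=False,  # BaseModel zeroes both dropout probs in this case
    input_ids=input_ids,
    attention_mask=attention_mask)

all_layers = model.get_all_encoder_layers()    # list of [batch, seq_len, hidden_size] tensors
embedding_table = model.get_embedding_table()  # word embedding matrix, e.g. for an MLM head
```
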
/tfbert/models/electra.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :electra.py
3 | # @Time :2021/1/31 22:27
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from .base import BaseModel
7 | from .bert import bert_embedding, bert_encoder
8 | from . import model_utils, layers
9 |
10 |
11 | class ElectraModel(BaseModel):
12 | def __init__(
13 | self,
14 | config,
15 | is_training,
16 | input_ids,
17 | attention_mask=None,
18 | token_type_ids=None,
19 | return_pool=True,
20 | scope=None,
21 | reuse=False,
22 | compute_type=tf.float32
23 | ):
24 | super().__init__(config, is_training)
25 |
26 | input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
27 | batch_size = input_shape[0]
28 | seq_length = input_shape[1]
29 |
30 | if attention_mask is None:
31 | attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
32 |
33 | if token_type_ids is None:
34 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
35 |
36 | with tf.variable_scope(scope, default_name="electra",
37 | reuse=tf.AUTO_REUSE if reuse else None,
38 | custom_getter=model_utils.get_custom_getter(compute_type)):
39 | with tf.variable_scope("embeddings"):
40 | self.embedding_output, self.embedding_table = bert_embedding(
41 | config=self.config,
42 | input_ids=input_ids,
43 | token_type_ids=token_type_ids,
44 | add_position_embedding=True
45 | )
46 |
47 | if model_utils.get_shape_list(self.embedding_output)[-1] != self.config.hidden_size:
48 | self.embedding_output = layers.dense(
49 | self.embedding_output, self.config.hidden_size,
50 | 'embeddings_project', initializer_range=self.config.initializer_range
51 | )
52 |
53 | with tf.variable_scope("encoder"):
54 | attention_mask = model_utils.create_bert_mask(
55 | input_ids, attention_mask)
56 | encoder_outputs = bert_encoder(
57 | config=self.config,
58 | input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
59 | attention_mask=attention_mask,
60 | use_relative_position=False
61 | )
62 |
63 |             # ELECTRA's pooled output is simply the first token's vector
64 | if return_pool:
65 | pooled_output = encoder_outputs[0][:, 0]
66 | else:
67 | pooled_output = None
68 | # (pooled output, sequence output, all layer outputs, all layer att probs)
69 | self.outputs = (pooled_output,) + encoder_outputs
70 |
--------------------------------------------------------------------------------
/tfbert/models/embeddings.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :embeddings.py
3 | # @Time :2021/1/31 15:32
4 | # @Author :huanghui
5 |
6 | import numpy as np
7 | import tensorflow.compat.v1 as tf
8 | from . import model_utils, layers
9 |
10 |
11 | def create_word_embeddings(
12 | input_ids,
13 | vocab_size,
14 | embedding_size=128,
15 | initializer_range=0.02,
16 | word_embedding_name="word_embeddings"):
17 |     # create the embedding table
18 | embedding_table = model_utils.create_weight(
19 | [vocab_size, embedding_size],
20 | word_embedding_name,
21 | initializer_range)
22 |
23 | flat_input_ids = tf.reshape(input_ids, [-1])
24 | output = tf.gather(embedding_table, flat_input_ids)
25 | input_shape = model_utils.get_shape_list(input_ids)
26 |
27 | output = tf.reshape(output,
28 | input_shape + [embedding_size])
29 | # output = tf.nn.embedding_lookup(embedding_table, input_ids)
30 |
31 | return (output, embedding_table)
32 |
33 |
34 | def create_token_type_embeddings(
35 | token_type_ids,
36 | embedding_size,
37 | token_type_vocab_size=2,
38 | token_type_embedding_name='token_type_embeddings',
39 | initializer_range=0.02):
40 | input_shape = model_utils.get_shape_list(token_type_ids)
41 | token_type_table = model_utils.create_weight(
42 | shape=[token_type_vocab_size, embedding_size],
43 | var_name=token_type_embedding_name,
44 | initializer_range=initializer_range
45 | )
46 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
47 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
48 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
49 | token_type_embeddings = tf.reshape(token_type_embeddings,
50 | [input_shape[0], input_shape[1], -1])
51 | # token_type_embeddings = tf.nn.embedding_lookup(token_type_table, token_type_ids)
52 | return token_type_embeddings
53 |
54 |
55 | def create_position_embeddings(
56 | seq_len,
57 | embedding_size,
58 | position_embedding_name="position_embeddings",
59 | initializer_range=0.02,
60 | max_position_embeddings=512
61 | ):
62 | full_position_embeddings = model_utils.create_weight(
63 | shape=[max_position_embeddings, embedding_size],
64 | var_name=position_embedding_name,
65 | initializer_range=initializer_range
66 | )
67 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
68 | [seq_len, -1])
69 | # num_dims = len(output.shape.as_list())
70 |
71 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
72 | # we broadcast among the first dimensions, which is typically just
73 | # the batch size.
74 |     # broadcast only over the batch dimension
75 |     position_broadcast_shape = [1]
76 |     position_broadcast_shape.extend([seq_len, embedding_size])
77 |     # final shape: (1, seq_len, embedding_size)
78 |     position_embeddings = tf.reshape(position_embeddings,
79 |                                      position_broadcast_shape)
80 | # position_embeddings = tf.nn.embedding_lookup(full_position_embeddings, tf.range(0, seq_len))
81 |
82 | return position_embeddings
83 |
84 |
85 | def create_pinyin_embeddings(pinyin_ids, embedding_size: int, pinyin_out_dim: int, initializer_range,
86 | pinyin_vocab_size):
87 |     """Pinyin embeddings for ChineseBERT."""
88 | input_shape = model_utils.get_shape_list(pinyin_ids) # bs, seq_len, pinyin_locs
89 | pinyin_table = model_utils.create_weight(
90 | shape=[pinyin_vocab_size, embedding_size],
91 | var_name='pinyin_embeddings/embeddings',
92 | initializer_range=initializer_range
93 | )
94 | flat_pinyin_ids = tf.reshape(pinyin_ids, [-1])
95 | pinyin_embeddings = tf.gather(pinyin_table, flat_pinyin_ids)
96 | pinyin_embeddings = tf.reshape(pinyin_embeddings,
97 | [input_shape[0] * input_shape[1], input_shape[2],
98 | embedding_size]) # bs * seq_len, pinyin_locs, embed_size
99 | pinyin_embeddings = tf.expand_dims(pinyin_embeddings, -1) # bs * seq_len, pinyin_locs, embed_size, 1
100 | with tf.variable_scope("pinyin_embeddings/conv"):
101 |         # followed by a char-CNN over the pinyin sequence
102 | filter_shape = [2, embedding_size, 1, pinyin_out_dim]
103 | pinyin_embeddings = layers.conv2d_layer(
104 | pinyin_embeddings, filter_shape, padding="VALID", act=None,
105 | initializer_range=0.1) # bs * seq_len, pinyin_locs - 2 + 1, 1, pinyin_out_dim
106 | pinyin_embeddings = layers.max_pooling_layer(
107 | pinyin_embeddings, ksize=[1, input_shape[2] - 2 + 1, 1, 1]) # bs * seq_len, 1, 1, pinyin_out_dim
108 | pinyin_embeddings = tf.reshape(pinyin_embeddings, input_shape[:2] + [pinyin_out_dim])
109 | return pinyin_embeddings
110 |
111 |
112 | def create_glyph_embeddings(input_ids, font_npy_files):
113 | font_arrays = [
114 | np.load(np_file).astype(np.float32) for np_file in font_npy_files
115 | ]
116 | vocab_size = font_arrays[0].shape[0]
117 | font_num = len(font_arrays)
118 | font_size = font_arrays[0].shape[-1]
119 | font_array = np.stack(font_arrays, axis=1)
120 | glyph_table = tf.get_variable(
121 | name="glyph_embeddings/embeddings",
122 | shape=[vocab_size, font_size ** 2 * font_num],
123 | initializer=tf.constant_initializer(font_array.reshape([vocab_size, -1])))
124 |
125 | flat_input_ids = tf.reshape(input_ids, [-1])
126 | output = tf.gather(glyph_table, flat_input_ids)
127 | input_shape = model_utils.get_shape_list(input_ids)
128 |
129 | output = tf.reshape(output,
130 | input_shape + [font_size ** 2 * font_num])
131 | return output
132 |
--------------------------------------------------------------------------------
/tfbert/models/glyce_bert.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :glyce_bert.py
3 | # @Time :2021/7/29 14:11
4 | # @Author :huanghui
5 | import os
6 | import json
7 | import tensorflow.compat.v1 as tf
8 | from . import embeddings, layers, model_utils
9 | from .base import BaseModel
10 | from .bert import bert_encoder
11 |
12 |
13 | def glyph_bert_embeddings(
14 | config,
15 | input_ids,
16 | pinyin_ids,
17 | token_type_ids=None
18 | ):
19 | (word_embeddings, embedding_table) = embeddings.create_word_embeddings(
20 | input_ids=input_ids,
21 | vocab_size=config.vocab_size,
22 | embedding_size=config.embedding_size,
23 | initializer_range=config.initializer_range,
24 | word_embedding_name="word_embeddings"
25 | )
26 |
27 | with open(os.path.join(config.config_path, 'pinyin_map.json')) as fin:
28 | pinyin_dict = json.load(fin)
29 | pinyin_embeddings = embeddings.create_pinyin_embeddings(
30 | pinyin_ids,
31 | embedding_size=128,
32 | pinyin_out_dim=config.embedding_size,
33 | initializer_range=config.initializer_range,
34 | pinyin_vocab_size=len(pinyin_dict['idx2char']))
35 |
36 | font_files = []
37 | for file in os.listdir(config.config_path):
38 | if file.endswith(".npy"):
39 | font_files.append(os.path.join(config.config_path, file))
40 | glyph_embeddings = embeddings.create_glyph_embeddings(
41 | input_ids, font_files
42 | )
43 | glyph_embeddings = layers.dense(glyph_embeddings, config.embedding_size, name="glyph_map")
44 |
45 | # fusion layer
46 | concat_embeddings = tf.concat([word_embeddings, pinyin_embeddings, glyph_embeddings], axis=2)
47 | inputs_embeds = layers.dense(concat_embeddings, config.embedding_size, name='map_fc')
48 |
49 | token_type_embeddings = embeddings.create_token_type_embeddings(
50 | token_type_ids=token_type_ids,
51 | embedding_size=config.embedding_size,
52 | token_type_vocab_size=config.type_vocab_size,
53 | token_type_embedding_name='token_type_embeddings',
54 | initializer_range=config.initializer_range
55 | )
56 |
57 | position_embeddings = embeddings.create_position_embeddings(
58 | seq_len=model_utils.get_shape_list(input_ids)[1],
59 | embedding_size=config.embedding_size,
60 | position_embedding_name='position_embeddings',
61 | initializer_range=config.initializer_range,
62 | max_position_embeddings=config.max_position_embeddings
63 | )
64 |
65 | embedding_output = inputs_embeds + position_embeddings + token_type_embeddings
66 | embedding_output = model_utils.layer_norm_and_dropout(
67 | embedding_output,
68 | config.hidden_dropout_prob
69 | )
70 |
71 | return embedding_output, embedding_table
72 |
73 |
74 | class GlyceBertModel(BaseModel):
75 | def __init__(
76 | self,
77 | config,
78 | is_training,
79 | input_ids,
80 | pinyin_ids,
81 | attention_mask=None,
82 | token_type_ids=None,
83 | return_pool=True,
84 | scope=None,
85 | reuse=False,
86 | compute_type=tf.float32
87 | ):
88 | super().__init__(config, is_training)
89 |
90 | input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
91 | batch_size = input_shape[0]
92 | seq_length = input_shape[1]
93 |
94 | if attention_mask is None:
95 | attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int64)
96 |
97 | if token_type_ids is None:
98 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int64)
99 |
100 | with tf.variable_scope(
101 | scope, default_name="bert",
102 | reuse=tf.AUTO_REUSE if reuse else None,
103 | custom_getter=model_utils.get_custom_getter(compute_type)):
104 | with tf.variable_scope("embeddings"):
105 | self.embedding_output, self.embedding_table = glyph_bert_embeddings(
106 | config=self.config,
107 | input_ids=input_ids,
108 | pinyin_ids=pinyin_ids,
109 | token_type_ids=token_type_ids
110 | )
111 |
112 | with tf.variable_scope("encoder"):
113 | attention_mask = model_utils.create_bert_mask(
114 | input_ids, attention_mask)
115 | if model_utils.get_shape_list(self.embedding_output)[-1] != self.config.hidden_size:
116 | self.embedding_output = layers.dense(
117 | self.embedding_output, self.config.hidden_size,
118 | 'embedding_hidden_mapping_in', initializer_range=self.config.initializer_range
119 | )
120 | encoder_outputs = bert_encoder(
121 | input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
122 | attention_mask=attention_mask,
123 | config=self.config,
124 | use_relative_position=False
125 | )
126 | if return_pool:
127 | with tf.variable_scope("pooler"):
128 | pooled_output = layers.pooler_layer(
129 | sequence_output=encoder_outputs[0],
130 | hidden_size=self.config.hidden_size,
131 | initializer_range=self.config.initializer_range
132 | )
133 | else:
134 | pooled_output = None
135 | # (pooled output, sequence output, all layer outputs, all layer att probs)
136 | self.outputs = (pooled_output,) + encoder_outputs
137 |
--------------------------------------------------------------------------------
/tfbert/models/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: loss.py
5 | @date: 2020/09/08
6 | """
7 | import tensorflow.compat.v1 as tf
8 | from tensorflow.python.ops import array_ops
9 | from ..utils import search_layer
10 |
11 |
12 | def cross_entropy_loss(logits, targets, depth):
13 | log_probs = tf.nn.log_softmax(logits, axis=-1)
14 | one_hot_labels = tf.one_hot(targets, depth=depth, dtype=tf.float32)
15 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
16 | loss = tf.reduce_mean(per_example_loss)
17 | return loss
18 |
19 |
20 | def loss_with_gradient_penalty(
21 | loss, epsilon=1,
22 | layer_name='word_embeddings',
23 | gradients_fn=None):
24 |     '''
25 |     Loss with gradient penalty, following Su Jianlin's formulation.
26 |     :param loss: the loss computed beforehand
27 |     :param epsilon: penalty coefficient
28 |     :param layer_name: name of the embedding layer whose gradient is penalized
29 |     :param gradients_fn: gradient function, tf.gradients or optimizer.compute_gradients
30 |     :return: loss plus the gradient penalty term
31 |     '''
32 | if gradients_fn is None:
33 | gradients_fn = tf.gradients
34 | embeddings = search_layer(layer_name)
35 | gp = tf.reduce_sum(gradients_fn(loss, [embeddings])[0] ** 2)
36 | return loss + 0.5 * epsilon * gp
37 |
38 |
39 | def mlm_loss(logits, targets, depth, label_weights):
40 | log_probs = tf.nn.log_softmax(logits, axis=-1)
41 |
42 | label_ids = tf.reshape(targets, [-1])
43 | label_weights = tf.reshape(label_weights, [-1])
44 |
45 | one_hot_labels = tf.one_hot(
46 | label_ids, depth=depth, dtype=tf.float32)
47 |
48 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
49 | numerator = tf.reduce_sum(label_weights * per_example_loss)
50 | denominator = tf.reduce_sum(label_weights) + 1e-5
51 | loss = numerator / denominator
52 | return loss
53 |
54 |
55 | def soft_cross_entropy(logits, targets):
56 |     log_probs = tf.nn.log_softmax(logits, axis=-1)
57 |     targets_prob = tf.nn.softmax(targets, axis=-1)
58 | per_example_loss = -tf.reduce_sum(targets_prob * log_probs, axis=-1)
59 | loss = tf.reduce_mean(per_example_loss)
60 | return loss
61 |
62 |
63 | def mse_loss(logits, targets):
64 | return tf.reduce_mean(tf.square(targets - logits))
65 |
66 |
67 | def focal_loss(prediction_tensor, target_tensor, alpha=0.25, gamma=2):
68 | r"""Compute focal loss for predictions.
69 | Multi-labels Focal loss formula:
70 | FL = -alpha * (z-p)^gamma * log(p) -(1-alpha) * p^gamma * log(1-p)
71 | ,which alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor.
72 | Args:
73 | prediction_tensor: A float tensor of shape [batch_size, num_anchors,
74 | num_classes] representing the predicted logits for each class
75 | target_tensor: A float tensor of shape [batch_size, num_anchors,
76 | num_classes] representing one-hot encoded classification targets
77 | weights: A float tensor of shape [batch_size, num_anchors]
78 | alpha: A scalar tensor for focal loss alpha hyper-parameter
79 | gamma: A scalar tensor for focal loss gamma hyper-parameter
80 | Returns:
81 | loss: A (scalar) tensor representing the value of the loss function
82 | """
83 | sigmoid_p = tf.nn.sigmoid(prediction_tensor)
84 | zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
85 |
86 |     # For positive prediction, only need consider front part loss, back part is 0;
87 |     # target_tensor > zeros <=> z=1, so positive coefficient = z - p.
88 | pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
89 |
90 | # For negative prediction, only need consider back part loss, front part is 0;
91 | # target_tensor > zeros <=> z=1, so negative coefficient = 0.
92 | neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p)
93 | per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0)) \
94 | - (1 - alpha) * (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
95 |
96 | return tf.reduce_mean(per_entry_cross_ent)
97 |
--------------------------------------------------------------------------------
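For orientation, a sketch of the tensor shapes the losses above expect (placeholders only; the class count of 10 is arbitrary):

```python
import tensorflow.compat.v1 as tf
from tfbert.models.loss import cross_entropy_loss, focal_loss

logits = tf.placeholder(tf.float32, [None, 10])  # arbitrary 10-class example
label_ids = tf.placeholder(tf.int32, [None])     # integer class ids
ce = cross_entropy_loss(logits, label_ids, depth=10)

# focal_loss is multi-label style: targets are 0/1 tensors with the same shape as the logits
targets = tf.placeholder(tf.float32, [None, 10])
fl = focal_loss(logits, targets, alpha=0.25, gamma=2)

# loss_with_gradient_penalty(ce, layer_name='word_embeddings') can wrap either loss,
# but only after the model graph containing the 'word_embeddings' variable has been built.
```
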
/tfbert/models/nezha.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :nezha.py
3 | # @Time :2021/2/2 20:54
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from . import layers, model_utils
7 | from .base import BaseModel
8 | from .bert import bert_embedding, bert_encoder
9 |
10 |
11 | class NezhaModel(BaseModel):
12 | def __init__(
13 | self,
14 | config,
15 | is_training,
16 | input_ids,
17 | attention_mask=None,
18 | token_type_ids=None,
19 | return_pool=True,
20 | scope=None,
21 | reuse=False,
22 | compute_type=tf.float32
23 | ):
24 | super().__init__(config, is_training)
25 |
26 | input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
27 | batch_size = input_shape[0]
28 | seq_length = input_shape[1]
29 |
30 | if attention_mask is None:
31 | attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int64)
32 |
33 | if token_type_ids is None:
34 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int64)
35 |
36 | with tf.variable_scope(
37 | scope, default_name="bert",
38 | reuse=tf.AUTO_REUSE if reuse else None,
39 | custom_getter=model_utils.get_custom_getter(compute_type)):
40 | with tf.variable_scope("embeddings"):
41 | self.embedding_output, self.embedding_table = bert_embedding(
42 | config=self.config,
43 | input_ids=input_ids,
44 | token_type_ids=token_type_ids,
45 | add_position_embedding=False
46 | )
47 |
48 | with tf.variable_scope("encoder"):
49 | attention_mask = model_utils.create_bert_mask(
50 | input_ids, attention_mask)
51 | if model_utils.get_shape_list(self.embedding_output)[-1] != self.config.hidden_size:
52 | self.embedding_output = layers.dense(
53 | self.embedding_output, self.config.hidden_size,
54 | 'embedding_hidden_mapping_in', initializer_range=self.config.initializer_range
55 | )
56 | encoder_outputs = bert_encoder(
57 | input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
58 | attention_mask=attention_mask,
59 | config=self.config,
60 | use_relative_position=True
61 | )
62 | if return_pool:
63 | with tf.variable_scope("pooler"):
64 | pooled_output = layers.pooler_layer(
65 | sequence_output=encoder_outputs[0],
66 | hidden_size=self.config.hidden_size,
67 | initializer_range=self.config.initializer_range
68 | )
69 | else:
70 | pooled_output = None
71 | # (pooled output, sequence output, all layer outputs, all layer att probs)
72 | self.outputs = (pooled_output,) + encoder_outputs
73 |
--------------------------------------------------------------------------------
/tfbert/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName  :__init__.py
3 | # @Time :2021/1/31 19:54
4 | # @Author :huanghui
5 |
6 | from .adamw import AdamWeightDecayOptimizer
7 | from .lamb import LAMBOptimizer
8 | from .schedule import lr_schedule
9 | from .create_optimizer import create_optimizer, create_train_op
10 |
--------------------------------------------------------------------------------
/tfbert/optimization/adamw.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :adamw.py
3 | # @Time :2021/1/31 19:54
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | import re
7 |
8 |
9 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
10 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
11 |
12 | def __init__(self,
13 | learning_rate,
14 | weight_decay_rate=0.0,
15 | beta_1=0.9,
16 | beta_2=0.999,
17 | epsilon=1e-6,
18 | exclude_from_weight_decay=None,
19 | name="AdamWeightDecayOptimizer"):
20 |         """Constructs an AdamWeightDecayOptimizer."""
21 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
22 |
23 | self.learning_rate = tf.identity(learning_rate, name='learning_rate')
24 | self.weight_decay_rate = weight_decay_rate
25 | self.beta_1 = beta_1
26 | self.beta_2 = beta_2
27 | self.epsilon = epsilon
28 | self.exclude_from_weight_decay = exclude_from_weight_decay
29 |
30 | def apply_gradients(self, grads_and_vars, global_step=None, name=None,
31 | manual_fp16=False):
32 | """See base class."""
33 | assignments = []
34 | for (grad, param) in grads_and_vars:
35 | if grad is None or param is None:
36 | continue
37 |
38 | param_name = self._get_variable_name(param.name)
39 | has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
40 | if has_shadow:
41 | # create shadow fp32 weights for fp16 variable
42 | param_fp32 = tf.get_variable(
43 | name=param_name + "/shadow",
44 | dtype=tf.float32,
45 | trainable=False,
46 | initializer=tf.cast(param.initialized_value(), tf.float32))
47 | else:
48 | param_fp32 = param
49 |
50 | m = tf.get_variable(
51 | name=param_name + "/adam_m",
52 | shape=param.shape.as_list(),
53 | dtype=tf.float32,
54 | trainable=False,
55 | initializer=tf.zeros_initializer())
56 | v = tf.get_variable(
57 | name=param_name + "/adam_v",
58 | shape=param.shape.as_list(),
59 | dtype=tf.float32,
60 | trainable=False,
61 | initializer=tf.zeros_initializer())
62 |
63 | # Standard Adam update.
64 | next_m = (
65 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
66 | next_v = (
67 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
68 | tf.square(grad)))
69 |
70 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
71 |
72 | # Just adding the square of the weights to the loss function is *not*
73 | # the correct way of using L2 regularization/weight decay with Adam,
74 | # since that will interact with the m and v parameters in strange ways.
75 | #
76 | # Instead we want to decay the weights in a manner that doesn't interact
77 | # with the m/v parameters. This is equivalent to adding the square
78 | # of the weights to the loss with plain (non-momentum) SGD.
79 | if self._do_use_weight_decay(param_name):
80 | update += self.weight_decay_rate * param_fp32
81 |
82 | update_with_lr = self.learning_rate * update
83 |
84 | next_param = param_fp32 - update_with_lr
85 |
86 | if has_shadow:
87 | # cast shadow fp32 weights to fp16 and assign to trainable variable
88 | param.assign(tf.cast(next_param, param.dtype.base_dtype))
89 | assignments.extend(
90 | [param_fp32.assign(next_param),
91 | m.assign(next_m),
92 | v.assign(next_v)])
93 | return tf.group(*assignments, name=name)
94 |
95 | def _do_use_weight_decay(self, param_name):
96 | """Whether to use L2 weight decay for `param_name`."""
97 | if not self.weight_decay_rate:
98 | return False
99 | if self.exclude_from_weight_decay:
100 | for r in self.exclude_from_weight_decay:
101 | if re.search(r, param_name) is not None:
102 | return False
103 | return True
104 |
105 | def _get_variable_name(self, param_name):
106 | """Get the variable name from the tensor name."""
107 | m = re.match("^(.*):\\d+$", param_name)
108 | if m is not None:
109 | param_name = m.group(1)
110 | return param_name
111 |
--------------------------------------------------------------------------------
/tfbert/optimization/create_optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :create_optimizer.py
3 | # @Time :2021/1/31 19:58
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from .adamw import AdamWeightDecayOptimizer
7 | from .lamb import LAMBOptimizer
8 | from .schedule import lr_schedule
9 |
10 |
11 | def create_optimizer(
12 | learning_rate,
13 | num_train_steps=None,
14 | num_warmup_steps=None,
15 | optimizer_type='adamw',
16 | epsilon=1e-6,
17 | momentum=0.,
18 | weight_decay=0.01,
19 | decay_method='poly',
20 | mixed_precision=False,
21 | init_loss_scale=2 ** 32
22 | ):
23 | if decay_method is not None and num_train_steps is not None and num_warmup_steps is not None:
24 | num_train_steps = int(num_train_steps)
25 | num_warmup_steps = int(num_warmup_steps)
26 | learning_rate = lr_schedule(
27 | learning_rate, num_train_steps, num_warmup_steps,
28 | decay_method=decay_method, optimizer_type=optimizer_type
29 | )
30 |
31 | if optimizer_type == 'adamw':
32 | optimizer = AdamWeightDecayOptimizer(
33 | learning_rate=learning_rate,
34 | weight_decay_rate=weight_decay,
35 | beta_1=0.9,
36 | beta_2=0.999,
37 | epsilon=epsilon,
38 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
39 | )
40 | elif optimizer_type == 'lamb':
41 | optimizer = LAMBOptimizer(
42 | learning_rate,
43 | weight_decay_rate=weight_decay,
44 | beta_1=0.9,
45 | beta_2=0.999,
46 | epsilon=epsilon,
47 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
48 | )
49 | elif optimizer_type == 'adam':
50 | optimizer = tf.train.AdamOptimizer(
51 | learning_rate=learning_rate,
52 | beta1=0.9,
53 | beta2=0.999,
54 | epsilon=epsilon)
55 | elif optimizer_type == 'sgd':
56 | optimizer = tf.train.GradientDescentOptimizer(
57 | learning_rate=learning_rate
58 | )
59 | elif optimizer_type == 'adadelta':
60 | optimizer = tf.train.AdadeltaOptimizer(
61 | learning_rate=learning_rate,
62 | rho=0.95,
63 | epsilon=epsilon,
64 | )
65 | elif optimizer_type == 'adagrad':
66 | optimizer = tf.train.AdagradOptimizer(
67 | learning_rate=learning_rate,
68 | initial_accumulator_value=0.1
69 | )
70 | elif optimizer_type == 'rmsp':
71 | optimizer = tf.train.RMSPropOptimizer(
72 | learning_rate=learning_rate,
73 | decay=0.9,
74 | momentum=momentum,
75 | epsilon=epsilon,
76 | )
77 | else:
78 | raise ValueError('Unsupported optimizer option: %s' % optimizer_type)
79 |
80 | if mixed_precision:
81 | loss_scaler = tf.train.experimental.DynamicLossScale(
82 | initial_loss_scale=init_loss_scale, increment_period=1000,
83 | multiplier=2.0)
84 | optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler)
85 | loss_scale_value = tf.identity(loss_scaler(), name="loss_scale")
86 | return optimizer
87 |
88 |
89 | def create_train_op(
90 | optimizer,
91 | grads_and_vars,
92 | max_grad=1.0,
93 | mixed_precision=False,
94 | gradient_accumulation_steps=1):
95 | global_step = tf.train.get_or_create_global_step()
96 |
97 | if gradient_accumulation_steps > 1:
98 | local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False,
99 | initializer=tf.zeros_initializer)
100 | batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False,
101 | initializer=tf.ones_initializer)
102 | accum_vars = [tf.get_variable(
103 | name=tvar.name.split(":")[0] + "/accum",
104 | shape=tvar.shape.as_list(),
105 | dtype=tf.float32,
106 | trainable=False,
107 | initializer=tf.zeros_initializer()) for tvar in tf.trainable_variables()]
108 |
109 | reset_step = tf.cast(tf.math.equal(local_step % gradient_accumulation_steps, 0), dtype=tf.bool)
110 | local_step = tf.cond(reset_step, lambda: local_step.assign(tf.ones_like(local_step)),
111 | lambda: local_step.assign_add(1))
112 |
113 | grads_and_vars_and_accums = [(gv[0], gv[1], accum_vars[i]) for i, gv in enumerate(grads_and_vars) if
114 | gv[0] is not None]
115 | grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums))
116 |
117 | all_are_finite = tf.reduce_all(
118 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if mixed_precision else tf.constant(
119 | True,
120 | dtype=tf.bool)
121 | batch_finite = tf.cond(reset_step,
122 | lambda: batch_finite.assign(
123 | tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)),
124 | lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite)))
125 |
126 | # This is how the model was pre-trained.
127 | # ensure global norm is a finite number
128 |         # to prevent clip_by_global_norm from having a hissy fit.
129 | (clipped_grads, _) = tf.clip_by_global_norm(
130 | grads, clip_norm=max_grad)
131 |
132 | accum_vars = tf.cond(reset_step,
133 | lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)],
134 | lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)])
135 |
136 | def update(accum_vars):
137 | return optimizer.apply_gradients(list(zip(accum_vars, tvars)))
138 |
139 | update_step = tf.identity(
140 | tf.cast(tf.math.equal(local_step % gradient_accumulation_steps, 0), dtype=tf.bool),
141 | name="update_step")
142 | update_op = tf.cond(update_step,
143 | lambda: update(accum_vars), lambda: tf.no_op())
144 |
145 | new_global_step = tf.cond(tf.math.logical_and(update_step, batch_finite),
146 | lambda: global_step + 1,
147 | lambda: global_step)
148 | new_global_step = tf.identity(new_global_step, name='step_update')
149 | train_op = tf.group(update_op, [global_step.assign(new_global_step)])
150 | else:
151 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
152 | grads, tvars = list(zip(*grads_and_vars))
153 | all_are_finite = tf.reduce_all(
154 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if mixed_precision else tf.constant(True,
155 | dtype=tf.bool)
156 |
157 | # This is how the model was pre-trained.
158 | # ensure global norm is a finite number
159 |         # to prevent clip_by_global_norm from having a hissy fit.
160 | (clipped_grads, _) = tf.clip_by_global_norm(
161 | grads, clip_norm=max_grad)
162 |
163 |         # do not pass global_step into apply_gradients here: AdamWeightDecayOptimizer does not increment it,
164 |         # while TF's built-in optimizers (e.g. Adam) would, causing global_step to be advanced twice
165 | train_op = optimizer.apply_gradients(
166 | list(zip(clipped_grads, tvars)))
167 |
168 | new_global_step = tf.cond(all_are_finite, lambda: global_step + 1, lambda: global_step)
169 | new_global_step = tf.identity(new_global_step, name='step_update')
170 |
171 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
172 | return train_op
173 |
--------------------------------------------------------------------------------
/tfbert/optimization/lamb.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :lamb.py
3 | # @Time :2021/1/31 19:55
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 | from tensorflow.python.ops import array_ops
7 | from tensorflow.python.ops import linalg_ops
8 | from tensorflow.python.ops import math_ops
9 | import re
10 |
11 |
12 | class LAMBOptimizer(tf.train.Optimizer):
13 | """A LAMB optimizer that includes "correct" L2 weight decay."""
14 |
15 | def __init__(self,
16 | learning_rate,
17 | weight_decay_rate=0.0,
18 | beta_1=0.9,
19 | beta_2=0.999,
20 | epsilon=1e-6,
21 | exclude_from_weight_decay=None,
22 | name="LAMBOptimizer"):
23 | """Constructs a LAMBOptimizer."""
24 | super(LAMBOptimizer, self).__init__(False, name)
25 |
26 | self.learning_rate = tf.identity(learning_rate, name='learning_rate')
27 | self.weight_decay_rate = weight_decay_rate
28 | self.beta_1 = beta_1
29 | self.beta_2 = beta_2
30 | self.epsilon = epsilon
31 | self.exclude_from_weight_decay = exclude_from_weight_decay
32 |
33 | def apply_gradients(self, grads_and_vars, global_step, name=None,
34 | manual_fp16=False):
35 | """See base class."""
36 | assignments = []
37 | steps = tf.cast(global_step, tf.float32)
38 | for (grad, param) in grads_and_vars:
39 | if grad is None or param is None:
40 | continue
41 |
42 | param_name = self._get_variable_name(param.name)
43 | has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
44 | if has_shadow:
45 | # create shadow fp32 weights for fp16 variable
46 | param_fp32 = tf.get_variable(
47 | name=param_name + "/shadow",
48 | dtype=tf.float32,
49 | trainable=False,
50 | initializer=tf.cast(param.initialized_value(), tf.float32))
51 | else:
52 | param_fp32 = param
53 |
54 | m = tf.get_variable(
55 | name=param_name + "/adam_m",
56 | shape=param.shape.as_list(),
57 | dtype=tf.float32,
58 | trainable=False,
59 | initializer=tf.zeros_initializer())
60 | v = tf.get_variable(
61 | name=param_name + "/adam_v",
62 | shape=param.shape.as_list(),
63 | dtype=tf.float32,
64 | trainable=False,
65 | initializer=tf.zeros_initializer())
66 |
67 | # LAMB update
68 | next_m = (
69 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
70 | next_v = (
71 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
72 | tf.square(grad)))
73 |
74 | beta1_correction = (1 - self.beta_1 ** steps)
75 | beta2_correction = (1 - self.beta_2 ** steps)
76 |
77 | next_m_unbiased = next_m / beta1_correction
78 | next_v_unbiased = next_v / beta2_correction
79 |
80 | update = next_m_unbiased / (tf.sqrt(next_v_unbiased) + self.epsilon)
81 |
82 | # Just adding the square of the weights to the loss function is *not*
83 | # the correct way of using L2 regularization/weight decay with Adam,
84 | # since that will interact with the m and v parameters in strange ways.
85 | #
86 | # Instead we want to decay the weights in a manner that doesn't interact
87 | # with the m/v parameters. This is equivalent to adding the square
88 | # of the weights to the loss with plain (non-momentum) SGD.
89 | if self._do_use_weight_decay(param_name):
90 | update += self.weight_decay_rate * param_fp32
91 |
92 | w_norm = linalg_ops.norm(param, ord=2)
93 | g_norm = linalg_ops.norm(update, ord=2)
94 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
95 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
96 |
97 | update_with_lr = ratio * self.learning_rate * update
98 |
99 | next_param = param_fp32 - update_with_lr
100 |
101 | if has_shadow:
102 | # cast shadow fp32 weights back to fp16; keep the assign op so it actually runs as part of the train op
103 | assignments.append(param.assign(tf.cast(next_param, param.dtype.base_dtype)))
104 | assignments.extend(
105 | [param_fp32.assign(next_param),
106 | m.assign(next_m),
107 | v.assign(next_v)])
108 | return tf.group(*assignments, name=name)
109 |
110 | def _do_use_weight_decay(self, param_name):
111 | """Whether to use L2 weight decay for `param_name`."""
112 | if not self.weight_decay_rate:
113 | return False
114 | if self.exclude_from_weight_decay:
115 | for r in self.exclude_from_weight_decay:
116 | if re.search(r, param_name) is not None:
117 | return False
118 | return True
119 |
120 | def _get_variable_name(self, param_name):
121 | """Get the variable name from the tensor name."""
122 | m = re.match("^(.*):\\d+$", param_name)
123 | if m is not None:
124 | param_name = m.group(1)
125 | return param_name
126 |
--------------------------------------------------------------------------------
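A hedged usage sketch for LAMBOptimizer.apply_gradients with a toy variable and loss and illustrative hyperparameters; it assumes the package is importable as tfbert.optimization.lamb per the layout above. Note that the bias-correction terms divide by (1 - beta ** step), so the step value fed in should be at least 1.

```python
import tensorflow.compat.v1 as tf
from tfbert.optimization.lamb import LAMBOptimizer

# Toy variable and loss so the update can actually be run.
w = tf.get_variable("w", shape=[4], initializer=tf.ones_initializer())
loss = tf.reduce_sum(tf.square(w - 3.0))
grads = tf.gradients(loss, [w])

optimizer = LAMBOptimizer(
    learning_rate=0.1,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

# global_step is only read for bias correction here; the optimizer does not increment it.
step = tf.constant(1.0)
train_op = optimizer.apply_gradients(list(zip(grads, [w])), global_step=step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run(w))
```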
/tfbert/optimization/schedule.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :schedule.py
3 | # @Time :2021/1/31 19:56
4 | # @Author :huanghui
5 | import tensorflow.compat.v1 as tf
6 |
7 |
8 | def lr_schedule(init_lr,
9 | num_train_steps,
10 | num_warmup_steps,
11 | decay_method='poly',
12 | optimizer_type='adamw'):
13 | '''
14 | Linear warmup: before num_warmup_steps, lr = global_step / num_warmup_steps * init_lr
15 | :param init_lr: initial (peak) learning rate
16 | :param num_train_steps: total number of training steps
17 | :param num_warmup_steps: number of warmup steps
18 | :param decay_method: learning rate decay method, one of 'poly' or 'cos'
19 | :param optimizer_type: optimizer type; 'adamw' uses polynomial decay power 1.0, otherwise 0.5
20 | :return: learning rate tensor
21 | '''
22 |
23 | global_step = tf.train.get_or_create_global_step()
24 |
25 | # avoid step change in learning rate at end of warmup phase
26 | if optimizer_type == "adamw":
27 | power = 1.0
28 | decayed_learning_rate_at_crossover_point = init_lr * (
29 | (1.0 - float(num_warmup_steps) / float(num_train_steps)) ** power)
30 | else:
31 | power = 0.5
32 | decayed_learning_rate_at_crossover_point = init_lr
33 |
34 | adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
35 | init_lr = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)
36 |
37 | # increase the learning rate linearly
38 | if num_warmup_steps > 0:
39 | warmup_lr = (tf.cast(global_step, tf.float32)
40 | / tf.cast(num_warmup_steps, tf.float32)
41 | * init_lr)  # linear warmup
42 | else:
43 | warmup_lr = 0.0
44 |
45 | # decay the learning rate
46 | if decay_method == "poly":
47 | decay_lr = tf.train.polynomial_decay(
48 | init_lr,
49 | global_step,
50 | num_train_steps,
51 | end_learning_rate=0.0,
52 | power=power,
53 | cycle=False)
54 | elif decay_method == "cos":
55 | decay_lr = tf.train.cosine_decay(
56 | init_lr,
57 | global_step,
58 | num_train_steps,
59 | alpha=0.0)
60 | else:
61 | raise ValueError(decay_method)
62 |
63 | learning_rate = tf.where(global_step < num_warmup_steps,
64 | warmup_lr, decay_lr)
65 |
66 | return learning_rate
67 |
--------------------------------------------------------------------------------
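A brief usage sketch for lr_schedule with illustrative hyperparameters; it assumes the module is importable as tfbert.optimization.schedule per the layout above. The returned tensor depends on the graph's global step, so it is normally passed straight into the optimizer.

```python
import tensorflow.compat.v1 as tf
from tfbert.optimization.schedule import lr_schedule

# Warm up linearly for 100 steps, then decay polynomially to 0 over 1000 steps.
learning_rate = lr_schedule(
    init_lr=5e-5,
    num_train_steps=1000,
    num_warmup_steps=100,
    decay_method='poly',
    optimizer_type='adamw')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(learning_rate))  # 0.0 at global_step == 0, i.e. the start of warmup
```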
/tfbert/serving.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :serving.py
3 | # @Time :2021/2/3 21:52
4 | # @Author :huanghui
5 | import os
6 | import tensorflow.compat.v1 as tf
7 | from tensorflow.python.framework import ops
8 | from tensorflow.python.saved_model import builder
9 | from tensorflow.python.saved_model import signature_constants
10 | from tensorflow.python.saved_model import signature_def_utils
11 | from tensorflow.python.saved_model import tag_constants
12 | from tensorflow.contrib import predictor
13 |
14 |
15 | def export_model_to_pb(model_name_or_path, export_path,
16 | inputs: dict, outputs: dict):
17 | """
18 | config = BertConfig.from_pretrained('ckpt')
19 | input_ids = tf.placeholder(shape=[None, 32], dtype=tf.int64, name='input_ids')
20 | input_mask = tf.placeholder(shape=[None, 32], dtype=tf.int64, name='input_mask')
21 | token_type_ids = tf.placeholder(shape=[None, 32], dtype=tf.int64, name='token_type_ids')
22 | model = SequenceClassification(
23 | model_type='bert',
24 | config=config,
25 | num_classes=len(labels),
26 | is_training=False,
27 | input_ids=input_ids,
28 | input_mask=input_mask,
29 | token_type_ids=token_type_ids)
30 | export_model_to_pb('ckpt/model.ckpt-1875', 'pb',
31 | inputs={'input_ids': input_ids, 'input_mask': input_mask, 'token_type_ids': token_type_ids},
32 | outputs={'logits': model.logits}
33 | )
34 | :param model_name_or_path:
35 | :param export_path:
36 | :param inputs:
37 | :param outputs:
38 | :return:
39 | """
40 | gpu_config = tf.ConfigProto()
41 | gpu_config.gpu_options.allow_growth = True
42 | sess = tf.Session(config=gpu_config)
43 | saver = tf.train.Saver()
44 |
45 | if os.path.isdir(model_name_or_path):
46 | ckpt_file = tf.train.latest_checkpoint(model_name_or_path)
47 | if ckpt_file is None:
48 | ckpt_file = os.path.join(model_name_or_path, 'model.ckpt')
49 | else:
50 | ckpt_file = model_name_or_path
51 | saver.restore(sess, ckpt_file)
52 | save_pb(sess, export_path, inputs=inputs, outputs=outputs, saver=saver)
53 | tf.logging.info('export model to {}'.format(export_path))
54 |
55 |
56 | def save_pb(session, export_dir, inputs, outputs, legacy_init_op=None, saver=None):
57 | '''
58 | Re-implemented pb (SavedModel) export helper; a custom saver can be passed in so that extra parameters (e.g. Adam states) are excluded from the export
59 | :param session:
60 | :param export_dir:
61 | :param inputs:
62 | :param outputs:
63 | :param legacy_init_op:
64 | :param saver:
65 | :return:
66 | '''
67 | signature_def_map = {
68 | signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
69 | signature_def_utils.predict_signature_def(inputs, outputs)
70 | }
71 | b = builder.SavedModelBuilder(export_dir)
72 | b.add_meta_graph_and_variables(
73 | session,
74 | tags=[tag_constants.SERVING],
75 | signature_def_map=signature_def_map,
76 | assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
77 | main_op=legacy_init_op,
78 | clear_devices=True,
79 | saver=saver
80 | )
81 | b.save()
82 |
83 |
84 | def load_pb(pb_dir):
85 | '''
86 | Load a saved pb (SavedModel) model. This helper is meant for testing the pb model offline; for online deployment, use TF Serving for prediction instead.
87 |
88 | Test example:
89 | using the same classification model as above.
90 |
91 | predict_fn, input_names, output_names = load_pb('pb')
92 | tokenizer = BertTokenizer.from_pretrained('ckpt', do_lower_case=True)
93 | inputs = tokenizer.encode("名人堂故事之威斯康辛先生:大范&联盟总裁的前辈",
94 | add_special_tokens=True, max_length=32,
95 | pad_to_max_length=True)
96 | prediction = predict_fn(
97 | {
98 | 'input_ids': [inputs['input_ids']],
99 | 'input_mask': [inputs['input_mask']],
100 | 'token_type_ids': [inputs['token_type_ids']]
101 | }
102 | )
103 | print(prediction)
104 |
105 | Output: {'logits': array([[ 5.1162577, -3.842629 , -0.2090739, 1.629769 , -2.6358554]],
106 | dtype=float32)}
107 |
108 | :param pb_dir: directory of the saved pb model
109 | :return: prediction fn, the names of the inputs it accepts, and the names of its outputs
110 | '''
111 | predict_fn = predictor.from_saved_model(pb_dir)
112 | input_names = list(predict_fn._feed_tensors.keys())
113 | output_names = list(predict_fn._fetch_tensors.keys())
114 | return predict_fn, input_names, output_names
115 |
--------------------------------------------------------------------------------
/tfbert/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :__init__.py
3 | # @Time :2021/1/31 15:17
4 | # @Author :huanghui
5 | from .tokenization_base import (
6 | PTMTokenizer, BasicTokenizer, WordpieceTokenizer)
7 | from .bert import BertTokenizer
8 | from .albert import ALBertTokenizer
9 | from .bert import BertTokenizer as NeZhaTokenizer
10 | from .bert import BertTokenizer as ElectraTokenizer
11 | from .wobert import WoBertTokenizer
12 | from .glyce_bert import GlyceBertTokenizer
13 |
14 | TOKENIZERS = {
15 | 'bert': BertTokenizer, 'albert': ALBertTokenizer,
16 | 'nezha': NeZhaTokenizer, 'electra': ElectraTokenizer,
17 | 'wobert': WoBertTokenizer, 'glyce_bert': GlyceBertTokenizer
18 | }
19 |
--------------------------------------------------------------------------------
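The TOKENIZERS mapping resolves a tokenizer class from a model-type string. A small sketch follows; 'pretrained_dir' is a placeholder directory that should contain a vocab.txt.

```python
from tfbert.tokenizer import TOKENIZERS

# Pick the tokenizer class by model type, then load it like any BertTokenizer.
tokenizer_class = TOKENIZERS['bert']
tokenizer = tokenizer_class.from_pretrained('pretrained_dir', do_lower_case=True)
print(tokenizer.tokenize('测试样例'))
```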
/tfbert/tokenizer/bert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: bert.py
5 | @date: 2020/09/08
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import collections
13 | import os
14 | from . import PTMTokenizer
15 | from .tokenization_base import load_vocab, BasicTokenizer, WordpieceTokenizer
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | class BertTokenizer(PTMTokenizer):
20 | padding_side = 'right'
21 | model_max_length = 512
22 | model_input_names = ['input_ids', 'token_type_ids', 'attention_mask']
23 |
24 | def __init__(
25 | self,
26 | vocab_file,
27 | do_lower_case=True,
28 | unk_token="[UNK]",
29 | sep_token="[SEP]",
30 | pad_token="[PAD]",
31 | cls_token="[CLS]",
32 | mask_token="[MASK]",
33 | **kwargs
34 | ):
35 | super().__init__(
36 | unk_token=unk_token,
37 | sep_token=sep_token,
38 | pad_token=pad_token,
39 | cls_token=cls_token,
40 | mask_token=mask_token,
41 | **kwargs,
42 | )
43 |
44 | self.vocab = load_vocab(vocab_file)
45 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
46 |
47 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
48 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
49 | self.do_lower_case = do_lower_case
50 | self.num_special_tokens = 2
51 |
52 | @property
53 | def vocab_size(self):
54 | return len(self.vocab)
55 |
56 | def convert_token_to_id(self, token):
57 | return self.vocab.get(token, self.vocab.get(self.unk_token))
58 |
59 | def convert_id_to_token(self, index):
60 | return self.ids_to_tokens.get(index, self.unk_token)
61 |
62 | def convert_tokens_to_string(self, tokens):
63 | out_string = " ".join(tokens).replace(" ##", "").strip()
64 | return out_string
65 |
66 | def num_special_tokens_to_add(self, pair=False):
67 | """
68 | Returns the number of added tokens when encoding a sequence with special tokens.
69 | Note:
70 | This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
71 | inside your training loop.
72 | Args:
73 | pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
74 | number of added tokens in the case of a single sequence if set to False.
75 | Returns:
76 | Number of tokens added to sequences
77 | """
78 | token_ids_0 = []
79 | token_ids_1 = []
80 | return len(
81 | self.build_inputs_with_special_tokens(
82 | token_ids_0, token_ids_1 if pair else None))
83 |
84 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
85 | """
86 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
87 | adding special tokens.
88 |
89 | A BERT sequence has the following format:
90 | ::
91 | - single sequence: ``[CLS] X [SEP]``
92 | - pair of sequences: ``[CLS] A [SEP] B [SEP]``
93 | Args:
94 | token_ids_0 (:obj:`List[int]`):
95 | List of IDs to which the special tokens will be added.
96 | token_ids_1 (:obj:`List[int]`, `optional`):
97 | Optional second list of IDs for sequence pairs.
98 | Returns:
99 | :obj:`List[int]`: List of input_id with the appropriate special tokens.
100 | """
101 | if token_ids_1 is None:
102 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
103 | _cls = [self.cls_token_id]
104 | _sep = [self.sep_token_id]
105 | return _cls + token_ids_0 + _sep + token_ids_1 + _sep
106 |
107 | def build_offset_mapping_with_special_tokens(self,
108 | offset_mapping_0,
109 | offset_mapping_1=None):
110 | """
111 | Build offset map from a pair of offset map by concatenating and adding offsets of special tokens.
112 |
113 | A BERT offset_mapping has the following format:
114 | ::
115 | - single sequence: ``(0,0) X (0,0)``
116 | - pair of sequences: ``(0,0) A (0,0) B (0,0)``
117 |
118 | Args:
119 | offset_mapping_0 (:obj:`List[tuple]`):
120 | List of char offsets to which the special tokens will be added.
121 | offset_mapping_1 (:obj:`List[tuple]`, `optional`):
122 | Optional second list of char offsets for offset mapping pairs.
123 | Returns:
124 | :obj:`List[tuple]`: List of char offsets with the appropriate offsets of special tokens.
125 | """
126 | if offset_mapping_1 is None:
127 | return [(0, 0)] + offset_mapping_0 + [(0, 0)]
128 |
129 | return [(0, 0)] + offset_mapping_0 + [(0, 0)
130 | ] + offset_mapping_1 + [(0, 0)]
131 |
132 | def create_token_type_ids_from_sequences(self,
133 | token_ids_0,
134 | token_ids_1=None):
135 | """
136 | Create a mask from the two sequences passed to be used in a sequence-pair classification task.
137 | A BERT sequence pair mask has the following format:
138 | ::
139 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
140 | | first sequence | second sequence |
141 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
142 | Args:
143 | token_ids_0 (:obj:`List[int]`):
144 | List of IDs.
145 | token_ids_1 (:obj:`List[int]`, `optional`):
146 | Optional second list of IDs for sequence pairs.
147 | Returns:
148 | :obj:`List[int]`: List of token_type_id according to the given sequence(s).
149 | """
150 | _sep = [self.sep_token_id]
151 | _cls = [self.cls_token_id]
152 | if token_ids_1 is None:
153 | return len(_cls + token_ids_0 + _sep) * [0]
154 | return len(_cls + token_ids_0 + _sep) * [0] + len(token_ids_1 +
155 | _sep) * [1]
156 |
157 | def get_special_tokens_mask(self,
158 | token_ids_0,
159 | token_ids_1=None,
160 | already_has_special_tokens=False):
161 | """
162 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
163 | special tokens using the tokenizer ``encode`` methods.
164 | Args:
165 | token_ids_0 (List[int]): List of ids of the first sequence.
166 | token_ids_1 (List[int], optional): List of ids of the second sequence.
167 | already_has_special_tokens (bool, optional): Whether or not the token list is already
168 | formatted with special tokens for the model. Defaults to False.
169 | Returns:
170 | results (List[int]): The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
171 | """
172 |
173 | if already_has_special_tokens:
174 | if token_ids_1 is not None:
175 | raise ValueError(
176 | "You should not supply a second sequence if the provided sequence of "
177 | "ids is already formatted with special tokens for the model."
178 | )
179 | return list(
180 | map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
181 | token_ids_0))
182 |
183 | if token_ids_1 is not None:
184 | return [1] + ([0] * len(token_ids_0)) + [1] + (
185 | [0] * len(token_ids_1)) + [1]
186 | return [1] + ([0] * len(token_ids_0)) + [1]
187 |
188 | @classmethod
189 | def from_pretrained(cls, vocab_dir_or_file, **kwargs):
190 | do_lower_case = kwargs.pop('do_lower_case', True)
191 | if os.path.isdir(vocab_dir_or_file):
192 | filename = 'vocab.txt'
193 | vocab_file = os.path.join(vocab_dir_or_file, filename)
194 | else:
195 | vocab_file = vocab_dir_or_file
196 |
197 | return cls(vocab_file=vocab_file, do_lower_case=do_lower_case, **kwargs)
198 |
199 | def save_pretrained(self, save_directory):
200 | if os.path.isdir(save_directory):
201 | vocab_file = os.path.join(save_directory, 'vocab.txt')
202 | else:
203 | vocab_file = save_directory
204 |
205 | with open(vocab_file, 'w', encoding='utf-8') as writer:
206 | for token, index in self.vocab.items():
207 | writer.write(token.strip() + '\n')
208 | tf.logging.info(" Tokenizer vocab saved in {}".format(vocab_file))
209 | return vocab_file
210 |
211 | def tokenize(self, text):
212 | split_tokens = []
213 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
214 | # If the token is part of the never_split set
215 | if token in self.basic_tokenizer.never_split:
216 | split_tokens.append(token)
217 | else:
218 | split_tokens += self.wordpiece_tokenizer.tokenize(token)
219 |
220 | return split_tokens
221 |
--------------------------------------------------------------------------------
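A short sketch of how the helpers above compose; 'vocab_dir' is a placeholder directory containing a vocab.txt.

```python
from tfbert.tokenizer import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('vocab_dir', do_lower_case=True)

tokens = tokenizer.tokenize('测试样例')  # basic + wordpiece tokenization
ids = [tokenizer.convert_token_to_id(t) for t in tokens]

# [CLS] ids [SEP], plus the matching segment ids for a single sequence.
input_ids = tokenizer.build_inputs_with_special_tokens(ids)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(ids)
assert len(input_ids) == len(token_type_ids)
```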
/tfbert/tokenizer/glyce_bert.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # @FileName :glyce_bert.py
3 | # @Time :2021/7/29 18:19
4 | # @Author :huanghui
5 | import os
6 | import tensorflow.compat.v1 as tf
7 | import json
8 | from .tokenization_base import convert_to_unicode, PaddingStrategy, TruncationStrategy
9 | from .bert import BertTokenizer
10 | from typing import List, Union, Tuple, Optional
11 |
12 |
13 | class GlyceBertTokenizer(BertTokenizer):
14 | def __init__(self, config_path, **kwargs):
15 | super(GlyceBertTokenizer, self).__init__(**kwargs)
16 | # load pinyin map dict
17 | with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
18 | self.pinyin_dict = json.load(fin)
19 | # load char id map tensor
20 | with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
21 | self.id2pinyin = json.load(fin)
22 | # load pinyin map tensor
23 | with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
24 | self.pinyin2tensor = json.load(fin)
25 |
26 | def save_pretrained(self, save_directory):
27 |
28 | if os.path.isdir(save_directory):
29 | vocab_file = os.path.join(save_directory, 'vocab.txt')
30 | config_path = os.path.join(save_directory, 'config')
31 | else:
32 | vocab_file = save_directory
33 | config_path = os.path.join(os.path.split(save_directory)[0], "config")
34 |
35 | if not os.path.exists(config_path):
36 | os.makedirs(config_path)
37 |
38 | with open(os.path.join(config_path, 'pinyin_map.json'), "w", encoding='utf8') as fout:
39 | fout.write(json.dumps(self.pinyin_dict, ensure_ascii=False))
40 |
41 | with open(os.path.join(config_path, 'id2pinyin.json'), "w", encoding='utf8') as fout:
42 | fout.write(json.dumps(self.id2pinyin, ensure_ascii=False))
43 |
44 | with open(os.path.join(config_path, 'pinyin2tensor.json'), "w", encoding='utf8') as fout:
45 | fout.write(json.dumps(self.pinyin2tensor, ensure_ascii=False))
46 |
47 | with open(vocab_file, 'w', encoding='utf-8') as writer:
48 | for token, index in self.vocab.items():
49 | writer.write(token.strip() + '\n')
50 | tf.logging.info(" Tokenizer vocab saved in {}".format(vocab_file))
51 | return vocab_file
52 |
53 | @classmethod
54 | def from_pretrained(cls, vocab_dir_or_file, **kwargs):
55 | do_lower_case = kwargs.pop('do_lower_case', True)
56 | if os.path.isdir(vocab_dir_or_file):
57 | filename = 'vocab.txt'
58 | vocab_file = os.path.join(vocab_dir_or_file, filename)
59 | config_path = os.path.join(vocab_dir_or_file, "config")
60 | else:
61 | vocab_file = vocab_dir_or_file
62 | config_path = os.path.join(os.path.split(vocab_dir_or_file)[0], "config")
63 |
64 | return cls(config_path=config_path, vocab_file=vocab_file, do_lower_case=do_lower_case, **kwargs)
65 |
66 | def convert_token_ids_to_pinyin_ids(self, ids):
67 | from pypinyin import pinyin, Style
68 |
69 | tokens = self.convert_ids_to_tokens(ids)
70 | offsets = []
71 | pos = 0
72 | sentence = ""
73 | for token in tokens:
74 | token = token.replace("##", "").strip()
75 |
76 | if len(token) == 0:
77 | token = " "
78 | if token in self.all_special_tokens:
79 | token = " "
80 | offsets.append((0, 0))
81 | else:
82 | offsets.append((pos, pos + len(token)))
83 | pos += len(token)
84 | sentence += token
85 |
86 | pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
87 | pinyin_locs = {}
88 | # get pinyin of each location
89 | for index, item in enumerate(pinyin_list):
90 | pinyin_string = item[0]
91 | # not a Chinese character, pass
92 | if pinyin_string == "not chinese":
93 | continue
94 | if pinyin_string in self.pinyin2tensor:
95 | pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
96 | else:
97 | ids = [0] * 8
98 | for i, p in enumerate(pinyin_string):
99 | if p not in self.pinyin_dict["char2idx"]:
100 | ids = [0] * 8
101 | break
102 | ids[i] = self.pinyin_dict["char2idx"][p]
103 | pinyin_locs[index] = ids
104 |
105 | # find chinese character location, and generate pinyin ids
106 | pinyin_ids = []
107 | for idx, offset in enumerate(offsets):
108 | if offset[1] - offset[0] != 1:
109 | pinyin_ids.append([0] * 8)
110 | continue
111 | if offset[0] in pinyin_locs:
112 | pinyin_ids.append(pinyin_locs[offset[0]])
113 | else:
114 | pinyin_ids.append([0] * 8)
115 |
116 | return pinyin_ids
117 |
118 | def _encode_plus(
119 | self,
120 | text: Union[str, List[str], List[int]],
121 | text_pair: Optional[Union[str, List[str], List[int]]] = None,
122 | add_special_tokens: bool = True,
123 | padding_strategy: Union[bool, str, PaddingStrategy] = PaddingStrategy.DO_NOT_PAD,
124 | truncation_strategy: Union[bool, str, TruncationStrategy] = TruncationStrategy.DO_NOT_TRUNCATE,
125 | max_length: Optional[int] = None,
126 | stride: int = 0,
127 | return_token_type_ids: Optional[bool] = None,
128 | return_attention_mask: Optional[bool] = None,
129 | return_overflowing_tokens: bool = False,
130 | return_special_tokens_mask: bool = False,
131 | return_length: bool = False,
132 | ):
133 | first_ids = self.get_input_ids(text)
134 | second_ids = self.get_input_ids(text_pair) if text_pair is not None else None
135 | encoded = self.prepare_for_model(
136 | first_ids,
137 | pair_ids=second_ids,
138 | add_special_tokens=add_special_tokens,
139 | padding=padding_strategy,
140 | truncation=truncation_strategy,
141 | max_length=max_length,
142 | stride=stride,
143 | return_attention_mask=return_attention_mask,
144 | return_token_type_ids=return_token_type_ids,
145 | return_overflowing_tokens=return_overflowing_tokens,
146 | return_special_tokens_mask=return_special_tokens_mask,
147 | return_length=return_length
148 | )
149 | pinyin_ids = self.convert_token_ids_to_pinyin_ids(encoded['input_ids'])
150 | assert len(pinyin_ids) == len(encoded['input_ids'])
151 | encoded['pinyin_ids'] = pinyin_ids
152 | return encoded
153 |
154 | def _batch_encode_plus(
155 | self,
156 | batch_text_or_text_pairs: Union[
157 | List[str],
158 | List[Tuple[str, str]],
159 | List[Tuple[List[str], List[str]]],
160 | List[Tuple[str, str]],
161 | List[List[int]],
162 | List[Tuple[List[int], List[int]]],
163 | ],
164 | add_special_tokens: bool = True,
165 | padding_strategy: Union[bool, str, PaddingStrategy] = PaddingStrategy.DO_NOT_PAD,
166 | truncation_strategy: Union[bool, str, TruncationStrategy] = TruncationStrategy.DO_NOT_TRUNCATE,
167 | max_length: Optional[int] = None,
168 | stride: int = 0,
169 | is_split_into_words: bool = False,
170 | return_token_type_ids: Optional[bool] = None,
171 | return_attention_mask: Optional[bool] = None,
172 | return_overflowing_tokens: bool = False,
173 | return_special_tokens_mask: bool = False,
174 | return_length: bool = False,
175 | ):
176 | input_ids = []
177 | for ids_or_pair_ids in batch_text_or_text_pairs:
178 | if not isinstance(ids_or_pair_ids, (list, tuple)):
179 | ids, pair_ids = ids_or_pair_ids, None
180 | elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
181 | ids, pair_ids = ids_or_pair_ids, None
182 | else:
183 | ids, pair_ids = ids_or_pair_ids
184 |
185 | first_ids = self.get_input_ids(ids)
186 | second_ids = self.get_input_ids(pair_ids) if pair_ids is not None else None
187 | input_ids.append((first_ids, second_ids))
188 |
189 | batch_outputs = self._batch_prepare_for_model(
190 | input_ids,
191 | add_special_tokens=add_special_tokens,
192 | padding_strategy=padding_strategy,
193 | truncation_strategy=truncation_strategy,
194 | max_length=max_length,
195 | stride=stride,
196 | return_attention_mask=return_attention_mask,
197 | return_token_type_ids=return_token_type_ids,
198 | return_overflowing_tokens=return_overflowing_tokens,
199 | return_special_tokens_mask=return_special_tokens_mask,
200 | return_length=return_length
201 | )
202 | batch_pinyin_ids = []
203 | for ids in batch_outputs['input_ids']:
204 | pinyin_ids = self.convert_token_ids_to_pinyin_ids(ids)
205 | assert len(pinyin_ids) == len(ids)
206 | batch_pinyin_ids.append(pinyin_ids)
207 | batch_outputs['pinyin_ids'] = batch_pinyin_ids
208 | return batch_outputs
209 |
--------------------------------------------------------------------------------
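A hedged usage sketch for GlyceBertTokenizer. It assumes a ChineseBert-style directory ('chinese_bert_dir' is a placeholder) holding vocab.txt plus a config/ folder with pinyin_map.json, id2pinyin.json and pinyin2tensor.json as read in from_pretrained above, that pypinyin is installed, and that the base encode_plus dispatches to the _encode_plus override shown here.

```python
from tfbert.tokenizer import GlyceBertTokenizer

tokenizer = GlyceBertTokenizer.from_pretrained('chinese_bert_dir')

encoded = tokenizer.encode_plus('测试样例', max_length=16, padding='max_length')
# pinyin_ids is aligned with input_ids; non-Chinese, special and pad positions get [0] * 8.
print(len(encoded['input_ids']), len(encoded['pinyin_ids']))
```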
/tfbert/tokenizer/wobert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | """
3 | @author: huanghui
4 | @file: wobert.py
5 | @date: 2020/09/19
6 | """
7 | from .tokenization_base import convert_to_unicode
8 | from .bert import BertTokenizer
9 |
10 |
11 | class WoBertTokenizer(BertTokenizer):
12 | def __init__(self, seg_fn=None, **kwargs):
13 | super(WoBertTokenizer, self).__init__(**kwargs)
14 | import jieba
15 | if seg_fn is None:
16 | self.seg_fn = lambda x: jieba.cut(x, HMM=False)
17 | else:
18 | self.seg_fn = seg_fn
19 |
20 | def tokenize(self, text):
21 | text = convert_to_unicode(text)
22 | split_tokens = []
23 | for token in self.seg_fn(text):
24 | if token in self.vocab:
25 | split_tokens.append(token)
26 | else:
27 | for t in self.basic_tokenizer.tokenize(token):
28 | for sub_token in self.wordpiece_tokenizer.tokenize(t):
29 | split_tokens.append(sub_token)
30 | return split_tokens
31 |
--------------------------------------------------------------------------------
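A short usage sketch for WoBertTokenizer; 'wobert_dir' is a placeholder vocab directory, and jieba must be installed as listed in the requirements.

```python
from tfbert.tokenizer import WoBertTokenizer

# WoBert segments with jieba first (HMM disabled) and only falls back to
# basic + wordpiece splitting for words that are not in the vocab.
tokenizer = WoBertTokenizer.from_pretrained('wobert_dir', do_lower_case=True)
print(tokenizer.tokenize('基于tensorflow的预训练模型工具'))
```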