├── tokenizer ├── sentencepiece_cn.model ├── reduce.py └── count.py ├── data_sample.tsv ├── README.md └── task_autotitle_csl.py /tokenizer/sentencepiece_cn.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bojone/t5_in_bert4keras/HEAD/tokenizer/sentencepiece_cn.model -------------------------------------------------------------------------------- /tokenizer/reduce.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 根据count的结果精简sentencepiece模型 3 | # 注:需要 sentencepiece>=0.1.94 4 | 5 | from tqdm import tqdm 6 | import json 7 | import pandas as pd 8 | from sentencepiece import sentencepiece_model_pb2 as model 9 | import sentencepiece as spm 10 | 11 | min_count = 1000 12 | old_model = '/root/kg/bert/mt5/sentencepiece.model' 13 | new_model = '/root/kg/bert/mt5/sentencepiece_cn.model' 14 | new_model_keep_tokens = '/root/kg/bert/mt5/sentencepiece_cn_keep_tokens.json' 15 | 16 | dic = json.load(open('result.json')) 17 | dic = pd.Series(dic).sort_values(ascending=False) 18 | dic = dic[dic >= min_count] 19 | dic = set(dic.index) 20 | 21 | m = model.ModelProto() 22 | m.ParseFromString(open(old_model, 'rb').read()) 23 | pieces = m.pieces[:259] + [p for p in m.pieces[259:] if p.piece in dic] + m.pieces[-100:] 24 | 25 | for i in tqdm(range(len(m.pieces))): 26 | del m.pieces[-1] 27 | 28 | m.pieces.extend(pieces) 29 | 30 | with open(new_model, 'wb') as f: 31 | f.write(m.SerializeToString()) 32 | 33 | sp1 = spm.SentencePieceProcessor() 34 | sp2 = spm.SentencePieceProcessor() 35 | 36 | sp1.load(old_model) 37 | sp2.load(new_model) 38 | 39 | keep_tokens = [] 40 | 41 | for i in range(sp2.get_piece_size()): 42 | keep_tokens.append(sp1.piece_to_id(sp2.id_to_piece(i))) 43 | 44 | keep_tokens.append(sp1.get_piece_size()) 45 | keep_tokens.append(sp1.get_piece_size() + 1) 46 | 47 | with open(new_model_keep_tokens, 'w') as f: 48 | json.dump(keep_tokens, f) 49 | -------------------------------------------------------------------------------- /tokenizer/count.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 用tokenizer对语料分词,然后统计每个token的词频 3 | 4 | import glob, json, re 5 | import sentencepiece as spm 6 | from bert4keras.snippets import parallel_apply 7 | from tqdm import tqdm 8 | 9 | 10 | spm_path = '/root/kg/bert/mt5/sentencepiece.model' 11 | sp_model = spm.SentencePieceProcessor() 12 | sp_model.Load(spm_path) 13 | global_tokens = {} 14 | 15 | 16 | def corpus(): 17 | filenames = glob.glob('/root/data_pretrain/*/*') 18 | count, texts = 0, [] 19 | for filename in filenames: 20 | with open(filename) as f: 21 | for l in f: 22 | l = json.loads(l)['text'].strip() 23 | texts.append(l) 24 | count += 1 25 | if count == 1000: 26 | yield texts 27 | count, texts = 0, [] 28 | if texts: 29 | yield texts 30 | 31 | 32 | def count(texts): 33 | tokens = {} 34 | for text in texts: 35 | for t in sp_model.encode_as_pieces(text): 36 | tokens[t] = tokens.get(t, 0) + 1 37 | return tokens 38 | 39 | 40 | def callback(tokens): 41 | for k, v in tokens.items(): 42 | global_tokens[k] = global_tokens.get(k, 0) + v 43 | 44 | 45 | parallel_apply( 46 | func=count, 47 | iterable=tqdm(corpus()), 48 | workers=20, 49 | max_queue_size=1000, 50 | callback=callback, 51 | ) 52 | 53 | 54 | import pandas as pd 55 | 56 | dic = pd.Series(global_tokens).sort_values(ascending=False) 57 | dic.to_csv('result.csv', header=None, encoding='utf-8', sep='\t') 58 | json.dump(global_tokens, open('result.json', 'w')) 59 | -------------------------------------------------------------------------------- /data_sample.tsv: -------------------------------------------------------------------------------- 1 | 交换超立方体网络容错路由研究 为了研究交换超立方体网络容错路由问题,引入了相邻结点集合类的概念,提出了相邻结点集的求解公式。对于满足任意子连通性条件的交换超立方体网络,给出了基于相邻结点集合类的自适应容错路由算法及算法的步长上界。仿真实验结果表明算法是有效的。 2 | 一种基于通讯痕迹的社会网络团伙分析模型 研究在已知目标团伙中某节点以及目标团伙特征的前提下,基于通讯痕迹特征寻找社会网络团伙。研究过程中引入了社会圈、节点中心度和事件集合关联矩阵等概念,重点将聚类分析方法与社会团伙发现相结合,以期得到一种基于通讯痕迹的社会网络团伙分析模型。 3 | 基于Hadoop平台的XML文档重复数据检测 XML数据越来越广泛地被用于信息交换与集成中,其数据质量问题引起了人们的关注.解决由数据质量引发的问题,实体识别技术非常关键.为了克服现有方法的不足,在海量XML数据上进行高效的重复对象检测,以实体识别技术为基础提出了基于Hadoop平台的XML文档重复检测算法,它将所有标签节点统称为属性,用实体来描述属性,通过属性的比较,快速地找到在某些属性上相同的所有实体对象,并利用Hadoop应用框架处理海量数据的优势实现并行处理.经过试验验证该方法良好的扩展性,伸缩性和高效性. 4 | 快速码字搜索算法中一维特征量的最佳选择方法 矢量量化编码过程中的最近邻码字搜索需要进行大量的矢量间距离的计算,这个过程的计算复杂度极高,严重限制了其实际使用.为了加速矢量量化的编码过程,许多文献提出了各种不同组合的基于均值、2-范数、方差和角度的矢量一维特征量的快速最近邻矢量量化码字搜索算法.通过实验给出了这四个一维特征量单独使用以及相互组合的所有情况下各算法的搜索范围和编码时间,并对它们进行了比较和分析,进而提出了在实际进行编码时如何最优地进行一维特征量选取的准则. 5 | 海量病例CT图像的快速查找检索模型仿真 在海量病例CT图像的快速查找检索过程中,采用传统算法进行检索,由于计算复杂、计算量大等原因,造成病例CT图像查找检索效率过低的问题。为解决上述问题,提出了一种改进高阶统计量算法的海量病例CT图像的快速查找检索方法。通过Radon变换方法将病例CT图像代入到一维空间中,获取病例CT图像投影数据的双谱信息,将高阶统计量算法与亚像素边缘特征算法相融合,将亚像素级精度位置搜索的问题变为最小化函数,对病例CT图像的亚像素边缘特征进行有效的提取。采用奇异值-迭代最近点法(SVD-ICP)和小波极大值完成病例CT图像轮廓间配准融合,进而实现了海量病例CT图像的快速查找与检索。实验结果表明,提出的改进高阶统计量算法的海量病例CT图像的快速查找检索方法精确度高,实用性强。 6 | 基于像素分解的圆形标志点亚像素定位研究 影像中圆形标志点的定位对于数字摄影测量具有重要作用.通过对圆形标志点边缘处的混合像素进行亚像素定位,提取出标志点的亚像素级边缘,再基于最小二乘原理进行椭圆拟合得到圆形标志的中心坐标.运用三种实验表明,与直接采用像素级边缘进行拟合定位相比,该方法的精度明显提高. 7 | 基于稀疏低秩描述的图像检索方法 使用颜色、形状、纹理等特征的基于内容的图像检索技术,将图像看作向量空间中的点,通过计算两点之间的某种距离来衡量图像间的相似度,然而在提取图像特征时相同类型的图像会出现不一致的特征,极大地影响了检索算法的准确率。针对该问题,提出一种稀疏低秩描述的多特征图像检索方法。通过对图像集的稀疏低秩描述,保持了相同类别特征的全局结构,同时也降低了对于局部噪声的敏感度,增强了检索算法的鲁棒性。在Corel图像集上的检索实验结果表明,该方法较已有的基于内容的图像检索方法有更好的检索效果。 8 | 基于神经网络的铁水KR脱硫预报模型 将神经网络理论应用于铁水脱硫过程,研究工艺参数与其影响因子之间的关系,建立预报模型,为生产过程中工艺参数(搅拌时间、搅拌次数和加入剂量)的设定选择提供准确的预报。研究分析表明,该预报模型可以应用于实际生产,提高铁水的脱硫成功的命中率,降低铁水的脱硫成本。 9 | VANET安全技术综述 随着车载自组织网络技术的不断发展,研究者对车载自组织网络系统安全进行了深入研究.论文阐述了车载自组织网络领域中安全研究的重要性;介绍了该领域中目前最新研究进展和存在的主要问题;讨论并比较了各种安全协议应用于车载自组织网络的优缺点;分析总结了系统中安全协议的设计要素;最后展望了车载自组织网络安全技术的未来研究方向. 10 | 基于DSpace构建传统蒙古文学科机构知识库平台 本文主要阐述了基于DSpace构建传统蒙古文学科机构知识库的难点以及解决的技术路线,包括蒙古文数字资料的采集、存储、检索以及显示等。针对蒙古文的构词和语法等方面的特点,对开源搜索引擎Lucene进行改进——采用B树管理Term、简化了特征词权值的计算、采用EC方法确定了蒙古文停用词表,实现了基于Lucene的蒙古文检索。 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # T5 in bert4keras 2 | 整理一下在keras中使用T5模型的要点,尤其是中文场景下的使用要点。以多国语言版mT5为例。 3 | 4 | 博客链接:https://kexue.fm/archives/7867 5 | 6 | 本项目实验环境:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.9.1 7 | 8 | ## 模型下载 9 | 10 | 首先,要想办法下载Google开放的权重,最简单的方式,是找一台能科学上网的服务器,在上面安装gsutil,然后执行 11 | ```shell 12 | gsutil cp -r gs://t5-data/pretrained_models/mt5/small . 13 | ``` 14 | 15 | T5使用sentencepiece作为tokenizer,mT5的tokenizer模型下载地址为 16 | ```shell 17 | gsutil cp -r gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model . 18 | ``` 19 | 20 | 笔者精简好的tokenizer文件:[sentencepiece_cn.model](https://github.com/bojone/t5_in_bert4keras/blob/main/tokenizer/sentencepiece_cn.model)和[sentencepiece_cn_keep_tokens.json](https://github.com/bojone/t5_in_bert4keras/blob/main/tokenizer/sentencepiece_cn_keep_tokens.json) 21 | 22 | 另外,为了方便国内用户,笔者将small版和base版整理分享到[百度网盘](https://pan.baidu.com/s/1YWaStqB6Epkxfyx6WcOzWw)(mwfc)了。 23 | 24 | ## Config 25 | 26 | T5模型的配置文件是gin格式的,这不符合bert4keras的输入,使用者请根据所给的gin和下述模版构建对应的config.json文件。 27 | 28 | 下面是mT5 small版的参考config.json: 29 | ```python 30 | { 31 | "hidden_dropout_prob": 0.1, 32 | "hidden_size": 512, 33 | "initializer_range": 0.02, 34 | "intermediate_size": 1024, 35 | "num_attention_heads": 6, 36 | "attention_head_size": 64, 37 | "num_hidden_layers": 8, 38 | "vocab_size": 250112, 39 | "hidden_act": ["gelu", "linear"] 40 | } 41 | ``` 42 | 43 | 一般要修改的是`hidden_size`、`intermediate_size`、`num_attention_heads`、`attention_head_size`和`num_hidden_layers`这几个参数。 44 | 45 | ## 基本使用 46 | 47 | ```python 48 | # 模型路径 49 | config_path = '/root/kg/bert/mt5/mt5_small/t5_config.json' 50 | checkpoint_path = '/root/kg/bert/mt5/mt5_small/model.ckpt-1000000' 51 | spm_path = '/root/kg/bert/mt5/sentencepiece.model' 52 | 53 | # 加载分词器 54 | tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>') 55 | 56 | # 加载模型 57 | t5 = build_transformer_model( 58 | config_path=config_path, 59 | checkpoint_path=checkpoint_path, 60 | model='t5.1.1', 61 | return_keras_model=False, 62 | name='T5', 63 | ) 64 | 65 | encoder = t5.encoder 66 | decoder = t5.decoder 67 | model = t5.model 68 | ``` 69 | 70 | ## 中文优化 71 | 72 | ```python 73 | # 模型路径 74 | config_path = '/root/kg/bert/mt5/mt5_base/t5_config.json' 75 | checkpoint_path = '/root/kg/bert/mt5/mt5_base/model.ckpt-1000000' 76 | spm_path = '/root/kg/bert/mt5/sentencepiece_cn.model' 77 | keep_tokens_path = '/root/kg/bert/mt5/sentencepiece_cn_keep_tokens.json' 78 | 79 | # 加载分词器 80 | tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>') 81 | keep_tokens = json.load(open(keep_tokens_path)) 82 | 83 | # 加载模型 84 | t5 = build_transformer_model( 85 | config_path=config_path, 86 | checkpoint_path=checkpoint_path, 87 | keep_tokens=keep_tokens, 88 | model='t5.1.1', 89 | return_keras_model=False, 90 | name='T5', 91 | ) 92 | 93 | encoder = t5.encoder 94 | decoder = t5.decoder 95 | model = t5.model 96 | ``` 97 | 98 | 细节请参考:[task_autotitle_csl.py](https://github.com/bojone/t5_in_bert4keras/blob/main/task_autotitle_csl.py)。 99 | 100 | ## 交流联系 101 | QQ交流群:67729435,微信群请加机器人微信号spaces_ac_cn 102 | -------------------------------------------------------------------------------- /task_autotitle_csl.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 微调多国语言版T5做Seq2Seq任务 3 | # 介绍链接:kexue.fm/archives/7867 4 | # 数据集:https://github.com/CLUEbenchmark/CLGE 中的CSL数据集 5 | # 补充了评测指标bleu、rouge-1、rouge-2、rouge-l 6 | 7 | from __future__ import print_function 8 | import json 9 | import numpy as np 10 | from tqdm import tqdm 11 | from bert4keras.backend import keras, K 12 | from bert4keras.layers import Loss 13 | from bert4keras.models import build_transformer_model 14 | from bert4keras.tokenizers import SpTokenizer 15 | from bert4keras.optimizers import Adam 16 | from bert4keras.snippets import sequence_padding, open 17 | from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder 18 | from keras.models import Model 19 | from rouge import Rouge # pip install rouge 20 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 21 | 22 | # 基本参数 23 | max_c_len = 256 24 | max_t_len = 32 25 | batch_size = 16 26 | epochs = 40 27 | 28 | # 模型路径 29 | config_path = '/root/kg/bert/mt5/mt5_base/mt5_base_config.json' 30 | checkpoint_path = '/root/kg/bert/mt5/mt5_base/model.ckpt-1000000' 31 | spm_path = '/root/kg/bert/mt5/sentencepiece_cn.model' 32 | keep_tokens_path = '/root/kg/bert/mt5/sentencepiece_cn_keep_tokens.json' 33 | 34 | 35 | def load_data(filename): 36 | D = [] 37 | with open(filename, encoding='utf-8') as f: 38 | for l in f: 39 | title, content = l.strip().split('\t') 40 | D.append((title, content)) 41 | return D 42 | 43 | 44 | # 加载数据集 45 | train_data = load_data('/root/csl/train.tsv') 46 | valid_data = load_data('/root/csl/val.tsv') 47 | test_data = load_data('/root/csl/test.tsv') 48 | 49 | # 加载分词器 50 | tokenizer = SpTokenizer(spm_path, token_start=None, token_end='') 51 | keep_tokens = json.load(open(keep_tokens_path)) 52 | 53 | 54 | class data_generator(DataGenerator): 55 | """数据生成器 56 | """ 57 | def __iter__(self, random=False): 58 | batch_c_token_ids, batch_t_token_ids = [], [] 59 | for is_end, (title, content) in self.sample(random): 60 | c_token_ids, _ = tokenizer.encode(content, maxlen=max_c_len) 61 | t_token_ids, _ = tokenizer.encode(title, maxlen=max_t_len) 62 | batch_c_token_ids.append(c_token_ids) 63 | batch_t_token_ids.append([0] + t_token_ids) 64 | if len(batch_c_token_ids) == self.batch_size or is_end: 65 | batch_c_token_ids = sequence_padding(batch_c_token_ids) 66 | batch_t_token_ids = sequence_padding(batch_t_token_ids) 67 | yield [batch_c_token_ids, batch_t_token_ids], None 68 | batch_c_token_ids, batch_t_token_ids = [], [] 69 | 70 | 71 | class CrossEntropy(Loss): 72 | """交叉熵作为loss,并mask掉输入部分 73 | """ 74 | def compute_loss(self, inputs, mask=None): 75 | y_true, y_pred = inputs 76 | y_true = y_true[:, 1:] # 目标token_ids 77 | y_mask = K.cast(mask[1], K.floatx())[:, :-1] # 解码器自带mask 78 | y_pred = y_pred[:, :-1] # 预测序列,错开一位 79 | loss = K.sparse_categorical_crossentropy(y_true, y_pred) 80 | loss = K.sum(loss * y_mask) / K.sum(y_mask) 81 | return loss 82 | 83 | 84 | t5 = build_transformer_model( 85 | config_path=config_path, 86 | checkpoint_path=checkpoint_path, 87 | keep_tokens=keep_tokens, 88 | model='t5.1.1', 89 | return_keras_model=False, 90 | name='T5', 91 | ) 92 | 93 | encoder = t5.encoder 94 | decoder = t5.decoder 95 | model = t5.model 96 | model.summary() 97 | 98 | output = CrossEntropy(1)([model.inputs[1], model.outputs[0]]) 99 | 100 | model = Model(model.inputs, output) 101 | model.compile(optimizer=Adam(2e-4)) 102 | 103 | 104 | class AutoTitle(AutoRegressiveDecoder): 105 | """seq2seq解码器 106 | """ 107 | @AutoRegressiveDecoder.wraps(default_rtype='probas') 108 | def predict(self, inputs, output_ids, states): 109 | c_encoded = inputs[0] 110 | return decoder.predict([c_encoded, output_ids])[:, -1] 111 | 112 | def generate(self, text, topk=1): 113 | c_token_ids, _ = tokenizer.encode(text, maxlen=max_c_len) 114 | c_encoded = encoder.predict(np.array([c_token_ids]))[0] 115 | output_ids = self.beam_search([c_encoded], topk) # 基于beam search 116 | return tokenizer.decode([int(i) for i in output_ids]) 117 | 118 | 119 | # 注:T5有一个很让人不解的设置,它的标记id是0,即其实都是0 120 | autotitle = AutoTitle(start_id=0, end_id=tokenizer._token_end_id, maxlen=32) 121 | 122 | 123 | class Evaluator(keras.callbacks.Callback): 124 | """评估与保存 125 | """ 126 | def __init__(self): 127 | self.rouge = Rouge() 128 | self.smooth = SmoothingFunction().method1 129 | self.best_bleu = 0. 130 | 131 | def on_epoch_end(self, epoch, logs=None): 132 | metrics = self.evaluate(valid_data) # 评测模型 133 | if metrics['bleu'] > self.best_bleu: 134 | self.best_bleu = metrics['bleu'] 135 | model.save_weights('./best_model.weights') # 保存模型 136 | metrics['best_bleu'] = self.best_bleu 137 | print('valid_data:', metrics) 138 | 139 | def evaluate(self, data, topk=1): 140 | total = 0 141 | rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0 142 | for title, content in tqdm(data): 143 | total += 1 144 | title = ' '.join(title).lower() 145 | pred_title = ' '.join(autotitle.generate(content, topk)).lower() 146 | if pred_title.strip(): 147 | scores = self.rouge.get_scores(hyps=pred_title, refs=title) 148 | rouge_1 += scores[0]['rouge-1']['f'] 149 | rouge_2 += scores[0]['rouge-2']['f'] 150 | rouge_l += scores[0]['rouge-l']['f'] 151 | bleu += sentence_bleu( 152 | references=[title.split(' ')], 153 | hypothesis=pred_title.split(' '), 154 | smoothing_function=self.smooth 155 | ) 156 | rouge_1 /= total 157 | rouge_2 /= total 158 | rouge_l /= total 159 | bleu /= total 160 | return { 161 | 'rouge-1': rouge_1, 162 | 'rouge-2': rouge_2, 163 | 'rouge-l': rouge_l, 164 | 'bleu': bleu, 165 | } 166 | 167 | 168 | if __name__ == '__main__': 169 | 170 | evaluator = Evaluator() 171 | train_generator = data_generator(train_data, batch_size) 172 | 173 | model.fit( 174 | train_generator.forfit(), 175 | steps_per_epoch=len(train_generator), 176 | epochs=epochs, 177 | callbacks=[evaluator] 178 | ) 179 | 180 | else: 181 | 182 | model.load_weights('./best_model.weights') 183 | --------------------------------------------------------------------------------