├── README.md
├── test
│   ├── retrieval.py
│   └── generate.py
├── train
│   ├── supervised.py
│   ├── stage1.py
│   └── stage2.py
└── LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RoFormer-Sim
RoFormer-Sim, also known as SimBERTv2, is an upgraded version of our previously released [SimBERT](https://github.com/ZhuiyiTechnology/simbert) model.

## Introduction
- Weakly supervised training: https://kexue.fm/archives/8454
- Supervised fine-tuning: https://kexue.fm/archives/8541

## Training environment
tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6

## Downloads
- [chinese_roformer-sim-char_L-12_H-768_A-12.zip](https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roformer-sim-char_L-12_H-768_A-12.zip)
- [chinese_roformer-sim-char_L-6_H-384_A-6.zip](https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roformer-sim-char_L-6_H-384_A-6.zip)
- [chinese_roformer-sim-char-ft_L-12_H-768_A-12.zip](https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roformer-sim-char-ft_L-12_H-768_A-12.zip)
- [chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip](https://open.zhuiyi.ai/releases/nlp/models/zhuiyi/chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip)

## Citation

BibTeX:

```tex
@techreport{roformer-sim,
  title={RoFormer-Sim: Integrating Retrieval and Generation into RoFormer},
  author={Jianlin Su},
  year={2021},
  url={https://github.com/ZhuiyiTechnology/roformer-sim},
}
```

## Contact

Email: ai@wezhuiyi.com

## Links

Zhuiyi Technology: https://zhuiyi.ai
--------------------------------------------------------------------------------
/test/retrieval.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Test the sentence-similarity performance of the supervised RoFormer-Sim-FT model

import numpy as np
from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding
from keras.models import Model

maxlen = 64

# BERT configuration
config_path = '/root/kg/bert/chinese_roformer-sim-char-ft_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roformer-sim-char-ft_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roformer-sim-char-ft_L-12_H-768_A-12/vocab.txt'

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Build and load the model
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    with_pool='linear',
    application='unilm',
    return_keras_model=False,
)

encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])


def similarity(text1, text2):
    """Compute the cosine similarity between text1 and text2.
    """
    texts = [text1, text2]
    X, S = [], []
    for t in texts:
        x, s = tokenizer.encode(t, maxlen=maxlen)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
    return (Z[0] * Z[1]).sum()


# Quick sanity checks (printed so the script actually shows its results);
# paraphrase pairs should score higher than contradictory ones.
print(similarity(u'今天天气不错', u'今天天气很好'))
print(similarity(u'今天天气不错', u'今天天气不好'))
print(similarity(u'我喜欢北京', u'我很喜欢北京'))
print(similarity(u'我喜欢北京', u'我不喜欢北京'))
print(similarity(u'电影不错', u'电影很好'))
print(similarity(u'电影不错', u'电影不好'))
print(similarity(u'红色的苹果', u'绿色的苹果'))
print(similarity(u'给我推荐一款红色的车', u'给我推荐一款黑色的车'))
print(similarity(u'给我推荐一款红色的车', u'推荐一辆红车'))
print(similarity(u'给我推荐一款红色的车', u'麻烦来一辆红车'))
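
# A minimal retrieval sketch (an addition for illustration, not part of the
# original test script): encode a small candidate pool, then rank it against a
# query by cosine similarity, reusing the encoder built above.
def rank_candidates(query, candidates):
    """Return (candidate, score) pairs sorted by similarity to the query."""
    X, S = [], []
    for t in [query] + candidates:
        x, s = tokenizer.encode(t, maxlen=maxlen)
        X.append(x)
        S.append(s)
    Z = encoder.predict([sequence_padding(X), sequence_padding(S)])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5  # L2-normalize sentence vectors
    scores = Z[1:].dot(Z[0])  # cosine similarity of each candidate to the query
    order = scores.argsort()[::-1]  # best match first
    return [(candidates[i], float(scores[i])) for i in order]


print(rank_candidates(u'给我推荐一款红色的车', [u'推荐一辆红车', u'麻烦来一辆红车', u'今天天气不错']))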
--------------------------------------------------------------------------------
/test/generate.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Basic example for RoFormer-Sim base
# Test environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
from bert4keras.snippets import uniout

maxlen = 64

# Model configuration
config_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/vocab.txt'

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Build and load the model
roformer = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_pool='linear'
)

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(seq2seq).predict([token_ids, segment_ids])

    def generate(self, text, n=1, topp=0.95, mask_idxs=[]):
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        for i in mask_idxs:
            token_ids[i] = tokenizer._token_mask_id
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topp=topp)  # nucleus (top-p) random sampling
        return [tokenizer.decode(ids) for ids in output_ids]


synonyms_generator = SynonymsGenerator(
    start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen
)


def gen_synonyms(text, n=100, k=20, mask_idxs=[]):
    """Meaning: generate n paraphrases of text, then return the k most similar ones.
    Method: generate with the seq2seq half, then score and rank with the encoder half.
    """
    r = synonyms_generator.generate(text, n, mask_idxs=mask_idxs)
    r = [i for i in set(r) if i != text]
    r = [text] + r
    X, S = [], []
    for t in r:
        x, s = tokenizer.encode(t)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
    argsort = np.dot(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]


"""
gen_synonyms(u'广州和深圳哪个好?')
[
    '深圳和广州哪个好?',
    '广州和深圳哪个好',
    '广州和深圳哪个更好?',
    '深圳和广州哪个更好?',
    '深圳和广州,那个更好?',
    '深圳和广州哪个好一些呢?',
    '深圳好还是广州好?',
    '广州和深圳哪个地方好点?',
    '广州好还是深圳好?',
    '广州和深圳哪个好一点',
    '广州和深圳哪个发展好?',
    '深圳好还是广州好',
    '深圳和广州哪个城市更好些',
    '深圳比广州好吗?',
    '到底深圳和广州哪个好?为什么呢?',
    '深圳究竟好还是广州好',
    '一般是深圳好还是广州好',
    '广州和深圳那个发展好点',
    '好一点的深圳和广州那边好?',
    '深圳比广州好在哪里?'
]

gen_synonyms(u'科学技术是第一生产力。')
[
    '科学技术是第一生产力!',
    '科学技术是第一生产力',
    '一、科学技术是第一生产力。',
    '一是科学技术是第一生产力。',
    '第一,科学技术是第一生产力。',
    '第一生产力是科学技术。',
    '因为科学技术是第一生产力。',
    '科学技术是第一生产力知。',
    '也即科学技术是第一生产力。',
    '科学技术是第一生产力吗',
    '科技是第一生产力。',
    '因此,科学技术是第一生产力。',
    '其次,科学技术是第一生产力。',
    '科学技术才是第一生产力。',
    '科学技术是第一生产力吗?',
    '第二,科学技术是第一生产力。',
    '所以说科学技术是第一生产力。',
    '科学技术确实是第一生产力。',
    '科学技术还是第一生产力',
    '科学技术是第一生产力对吗?'
]
"""
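
# A note on mask_idxs (an illustration; sampled outputs vary, so none are
# shown): positions listed in mask_idxs are replaced with [MASK] in the input
# before decoding, which pushes the model to rewrite those spots. For example,
#
#     gen_synonyms(u'科学技术是第一生产力。', n=50, k=10, mask_idxs=[1, 2])
#
# masks the first two characters ([CLS] sits at position 0) and tends to
# yield paraphrases that replace them.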
--------------------------------------------------------------------------------
/train/supervised.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# SimBERT v2 supervised training code
# Training environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6

import json, glob
import numpy as np
import tensorflow as tf
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam, extend_with_weight_decay
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.snippets import text_segmentate, truncate_sequences
from bert4keras.snippets import AutoRegressiveDecoder, open

# Basic settings
maxlen = 64
batch_size = 192
steps_per_epoch = 1000
epochs = 10000
labels = ['contradiction', 'entailment', 'neutral']

# BERT configuration
config_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/vocab.txt'

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def split(text):
    """Split text into sentence-level pieces.
    """
    seps, strips = u'\n。!?!?;;,, ', u';;,, '
    return text_segmentate(text, maxlen * 1.2, seps, strips)


def load_data_1(filename, threshold=0.5):
    """Load labelled similarity data.
    One sample per line: (text1, text2, label)
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            l = l.strip().split('\t')
            if len(l) == 3:
                l[0], l[1] = split(l[0])[0], split(l[1])[0]
                D.append((l[0], l[1], int(float(l[2]) > threshold)))
    return D


# Load the similarity datasets
data_path = '/root/senteval_cn/'
datasets_1 = []
for task_name in ['ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STS-B', 'SOHU21-SSB']:
    for f in ['train', 'valid']:
        threshold = 2.5 if task_name == 'STS-B' else 0.5
        filename = '%s%s/%s.%s.data' % (data_path, task_name, task_name, f)
        datasets_1 += load_data_1(filename, threshold)


def load_data_2(filename):
    """Load labelled NLI data (CNSD jsonl).
    One sample per line: (text1, text2, label)
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            if l['gold_label'] not in labels:
                continue
            text1 = split(l['sentence1'])[0]
            text2 = split(l['sentence2'])[0]
            label = labels.index(l['gold_label']) + 2  # shift NLI labels to 2/3/4
            D.append((text1, text2, label))
    return D


# Load the NLI datasets
datasets_2 = []
for f in glob.glob('/root/cnsd/cnsd-*/*.jsonl'):
    datasets_2 += load_data_2(f)


def corpus():
    """Merge the two corpora with 1:1 sampling.
    """
    def generator(dataset):
        while True:
            idxs = np.random.permutation(len(dataset))
            for i in idxs:
                yield dataset[i]

    corpus_1 = generator(datasets_1)
    corpus_2 = generator(datasets_2)

    while True:
        yield next(corpus_1)
        yield next(corpus_2)


class data_generator(DataGenerator):
    """Data generator
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text1, text2, label) in self.sample(random):
            for text in [text1, text2]:
                token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = np.array(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []


def merge(inputs):
    """Merge sentence vectors: concatenate a, b and |a-b|.
    Sentences arrive in pairs along the batch axis, so the result is repeated
    twice to stay aligned with the labels.
    """
    a, b = inputs[::2], inputs[1::2]
    o = K.concatenate([a, b, K.abs(a - b)], axis=1)
    return K.repeat_elements(o, 2, 0)


def special_crossentropy(y_true, y_pred):
    """Task-dependent cross-entropy.
    Labels 0/1 (binary similarity data) are scored against logits 0-1;
    labels 2/3/4 (NLI data) are scored against logits 2-4.
    """
    task = K.cast(y_true < 1.5, K.floatx())
    mask = K.constant([[0, 0, 1, 1, 1]])
    y_pred_1 = y_pred - mask * 1e12
    y_pred_2 = y_pred - (1 - mask) * 1e12
    y_pred = task * y_pred_1 + (1 - task) * y_pred_2
    y_true = K.cast(y_true, 'int32')
    loss = K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    return K.mean(loss)


# Build and load the model
encoder = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    with_pool='linear',
    dropout_rate=0.2,
    ignore_invalid_weights=True
)
output = keras.layers.Lambda(merge)(encoder.output)
output = keras.layers.Dense(5, use_bias=False)(output)

model = keras.models.Model(encoder.inputs, output)
AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=2e-5, weight_decay_rate=0.01)
model.compile(loss=special_crossentropy, optimizer=optimizer)
model.summary()


class Evaluator(keras.callbacks.Callback):
    """Save the model at the end of each epoch.
    """
    def on_epoch_end(self, epoch, logs=None):
        encoder.save_weights('./latest_model.weights')
        if (epoch + 1) % 5 == 0:
            encoder.save_weights('roformer-sim.%s.weights' % (epoch + 1))


if __name__ == '__main__':

    train_generator = data_generator(corpus(), batch_size)
    evaluator = Evaluator()

    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator]
    )

else:

    encoder.load_weights('./latest_model.weights')
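
# A small numpy sanity check of the loss routing above (a sketch, not part of
# the original script; call it by hand): a label below 1.5 should compete only
# among logits 0-1, and a label of 2 or more only among logits 2-4.
def _demo_special_crossentropy():
    logits = np.array([4.0, 1.0, 2.0, 9.0, 0.0])
    mask = np.array([0, 0, 1, 1, 1], dtype=float)
    for label in [0, 3]:
        task = float(label < 1.5)
        masked = logits - (task * mask + (1 - task) * (1 - mask)) * 1e12
        e = np.exp(masked - masked.max())
        print(label, np.round(e / e.sum(), 3))  # softmax mass stays in the routed slice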
--------------------------------------------------------------------------------
/train/stage1.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# SimBERT v2 training code (stage 1)
# Training environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6

import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam, extend_with_weight_decay
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.snippets import text_segmentate, truncate_sequences
from bert4keras.snippets import AutoRegressiveDecoder
import jieba
jieba.initialize()

# Basic settings
maxlen = 64
batch_size = 96
steps_per_epoch = 1000
epochs = 10000

# BERT configuration
config_path = '/root/kg/bert/chinese_roformer-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roformer-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roformer-char_L-12_H-768_A-12/vocab.txt'

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def read(filename):
    """Read a corpus file, one JSON object per line, looping forever.
    """
    while True:
        with open(filename) as f:
            for l in f:
                yield json.loads(l)


def split(text):
    """Split text into sentence-level pieces.
    """
    seps, strips = u'\n。!?!?;;,, ', u';;,, '
    return text_segmentate(text, maxlen * 1.2, seps, strips)


def corpus():
    """Read and mix the corpora.
    Each cycle yields from f1, f2, f1, f3, i.e. 2:1:1 sampling of the sources.
    """
    f1 = read('/root/data_pretrain/synonyms_shuf.json')
    f2 = read('/root/data_pretrain/synonym_answers_shuf.json')
    f3 = read('/root/data_pretrain/synonym/synonym_gen_2_shuf.json')
    while True:
        d = next(f1)
        text, synonyms = d['text'], d['synonyms']
        text, synonym = np.random.permutation([text] + synonyms)[:2]
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f2)
        text, synonym = d['text_a'], d['text_b']
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f1)
        text, synonyms = d['text'], d['synonyms']
        text, synonym = np.random.permutation([text] + synonyms)[:2]
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f3)
        text, synonym = d['text_a'], d['text_b']
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym


def masked_encode(text):
    """Whole-word masking (wwm) with random selection, BERT-style.
    """
    words = jieba.lcut(text)
    rands = np.random.random(len(words))
    source, target = [tokenizer._token_start_id], [0]
    for r, w in zip(rands, words):
        ids = tokenizer.encode(w)[0][1:-1]
        if r < 0.15 * 0.8:
            source.extend([tokenizer._token_mask_id] * len(ids))  # mask the word
            target.extend(ids)
        elif r < 0.15 * 0.9:
            source.extend(ids)  # keep the word unchanged
            target.extend(ids)
        elif r < 0.15:
            source.extend(  # replace with random tokens
                np.random.choice(tokenizer._vocab_size - 1, size=len(ids)) + 1
            )
            target.extend(ids)
        else:
            source.extend(ids)
            target.extend([0] * len(ids))
    source = source[:maxlen - 1] + [tokenizer._token_end_id]
    target = target[:maxlen - 1] + [0]
    return source, target
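
# Illustration of masked_encode (hypothetical, the selection is random): for
# text = u'科学技术是第一生产力', jieba may split off the word 科学, and if that
# word falls into the 15% * 80% branch, `source` holds the ids of
# [CLS] [MASK] [MASK] 技 术 是 第 一 生 产 力 [SEP] while `target` holds the ids
# of 科 and 学 at the masked positions and 0 elsewhere. Note that the training
# code below only uses `source` (index [0]) as a corrupted input.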
class data_generator(DataGenerator):
    """Data generator
    """
    def __init__(self, *args, **kwargs):
        super(data_generator, self).__init__(*args, **kwargs)
        self.some_samples = []

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        for is_end, (text, synonym) in self.sample(random):
            for i in range(2):
                if np.random.random() < 0.5:
                    text_ids = masked_encode(text)[0]  # corrupted input half the time
                else:
                    text_ids = tokenizer.encode(text)[0]
                synonym_ids = tokenizer.encode(synonym)[0][1:]
                truncate_sequences(maxlen * 2, -2, text_ids, synonym_ids)
                token_ids = text_ids + synonym_ids
                segment_ids = [0] * len(text_ids) + [1] * len(synonym_ids)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                self.some_samples.append(synonym)
                if len(self.some_samples) > 1000:
                    self.some_samples.pop(0)
                text, synonym = synonym, text  # swap so the pair is seen in both directions
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []


class TotalLoss(Loss):
    """The loss has two parts: the seq2seq cross-entropy and the in-batch
    similarity cross-entropy.
    """
    def compute_loss(self, inputs, mask=None):
        loss1 = self.compute_loss_of_seq2seq(inputs, mask)
        loss2 = self.compute_loss_of_similarity(inputs, mask)
        self.add_metric(loss1, name='seq2seq_loss')
        self.add_metric(loss2, name='similarity_loss')
        return loss1 + loss2

    def compute_loss_of_seq2seq(self, inputs, mask=None):
        y_true, y_mask, _, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids mark exactly the tokens to predict
        y_pred = y_pred[:, :-1]  # predictions, shifted by one position
        loss = K.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss

    def compute_loss_of_similarity(self, inputs, mask=None):
        _, _, y_pred, _ = inputs
        y_true = self.get_labels_of_similarity(y_pred)  # build in-batch labels
        y_pred = K.l2_normalize(y_pred, axis=1)  # normalize sentence vectors
        similarities = K.dot(y_pred, K.transpose(y_pred))  # similarity matrix
        similarities = similarities - K.eye(K.shape(y_pred)[0]) * 1e12  # exclude the diagonal
        similarities = similarities * 20  # scale
        loss = K.categorical_crossentropy(
            y_true, similarities, from_logits=True
        )
        return loss

    def get_labels_of_similarity(self, y_pred):
        idxs = K.arange(0, K.shape(y_pred)[0])
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = K.equal(idxs_1, idxs_2)
        labels = K.cast(labels, K.floatx())
        return labels
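
# A quick numpy sketch of the in-batch label pattern built by
# get_labels_of_similarity (a sketch, not part of the original script; call it
# by hand): samples are appended in pairs, so row i's positive is its
# neighbour (0<->1, 2<->3, ...).
def _demo_similarity_labels(batch_size=4):
    idxs = np.arange(batch_size)
    targets = idxs + 1 - idxs % 2 * 2  # 0->1, 1->0, 2->3, 3->2, ...
    labels = (idxs[None, :] == targets[:, None]).astype('float32')
    print(labels)  # a permutation matrix pairing each sample with its neighbour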
# Build and load the model
roformer = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_pool='linear',
    with_mlm='linear',
    dropout_rate=0.2,
    ignore_invalid_weights=True
)

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])

outputs = TotalLoss([2, 3])(roformer.inputs + roformer.outputs)
model = keras.models.Model(roformer.inputs, outputs)

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=1e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(seq2seq).predict([token_ids, segment_ids])

    def generate(self, text, n=1, topp=0.95):
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topp=topp)  # nucleus (top-p) random sampling
        return [tokenizer.decode(ids) for ids in output_ids]


synonyms_generator = SynonymsGenerator(
    start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen
)


def gen_synonyms(text, n=100, k=20):
    """Meaning: generate n paraphrases of text, then return the k most similar ones.
    Method: generate with the seq2seq half, then score and rank with the encoder half.
    Example:
    >>> gen_synonyms(u'微信和支付宝哪个好?')
    [
        u'微信和支付宝,哪个好?',
        u'微信和支付宝哪个好',
        u'支付宝和微信哪个好',
        u'支付宝和微信哪个好啊',
        u'微信和支付宝那个好用?',
        u'微信和支付宝哪个好用',
        u'支付宝和微信那个更好',
        u'支付宝和微信哪个好用',
        u'微信和支付宝用起来哪个好?',
        u'微信和支付宝选哪个好',
    ]
    """
    r = synonyms_generator.generate(text, n)
    r = [i for i in set(r) if i != text]
    r = [text] + r
    X, S = [], []
    for t in r:
        x, s = tokenizer.encode(t)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
    argsort = np.dot(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]


def just_show():
    """Show the effect on a few random samples.
    """
    some_samples = train_generator.some_samples
    S = [np.random.choice(some_samples) for i in range(3)]
    for s in S:
        try:
            print(u'Original sentence: %s' % s)
            print(u'Paraphrases:')
            print(gen_synonyms(s, 10, 10))
            print()
        except:
            pass


class Evaluate(keras.callbacks.Callback):
    """Evaluate and save the model.
    """
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('./latest_model.weights')
        # save the best model
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model.weights')
        # show some examples
        just_show()


if __name__ == '__main__':

    train_generator = data_generator(corpus(), batch_size)
    evaluator = Evaluate()

    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator]
    )

else:

    model.load_weights('./latest_model.weights')
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/train/stage2.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# SimBERT v2 training code (stage 2: distillation from SimBERT)
# Training environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6

import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam, extend_with_weight_decay
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.snippets import text_segmentate, truncate_sequences
from bert4keras.snippets import AutoRegressiveDecoder
import jieba
jieba.initialize()

# Basic settings
maxlen = 64
batch_size = 96
steps_per_epoch = 1000
epochs = 10000

# BERT configuration
config_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roformer-sim-char_L-12_H-768_A-12/vocab.txt'

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def read(filename):
    """Read a corpus file, one JSON object per line, looping forever.
    """
    while True:
        with open(filename) as f:
            for l in f:
                yield json.loads(l)


def split(text):
    """Split text into sentence-level pieces.
    """
    seps, strips = u'\n。!?!?;;,, ', u';;,, '
    return text_segmentate(text, maxlen * 1.2, seps, strips)


def corpus():
    """Read and mix the corpora.
    Each cycle yields from f1, f2, f1, f3, i.e. 2:1:1 sampling of the sources.
    """
    f1 = read('/root/data_pretrain/synonyms_shuf.json')
    f2 = read('/root/data_pretrain/synonym_answers_shuf.json')
    f3 = read('/root/data_pretrain/synonym/synonym_gen_2_shuf.json')
    while True:
        d = next(f1)
        text, synonyms = d['text'], d['synonyms']
        text, synonym = np.random.permutation([text] + synonyms)[:2]
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f2)
        text, synonym = d['text_a'], d['text_b']
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f1)
        text, synonyms = d['text'], d['synonyms']
        text, synonym = np.random.permutation([text] + synonyms)[:2]
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym
        d = next(f3)
        text, synonym = d['text_a'], d['text_b']
        text, synonym = split(text)[0], split(synonym)[0]
        yield text, synonym


def masked_encode(text):
    """Whole-word masking (wwm) with random selection, BERT-style.
    """
    words = jieba.lcut(text)
    rands = np.random.random(len(words))
    source, target = [tokenizer._token_start_id], [0]
    for r, w in zip(rands, words):
        ids = tokenizer.encode(w)[0][1:-1]
        if r < 0.15 * 0.8:
            source.extend([tokenizer._token_mask_id] * len(ids))  # mask the word
            target.extend(ids)
        elif r < 0.15 * 0.9:
            source.extend(ids)  # keep the word unchanged
            target.extend(ids)
        elif r < 0.15:
            source.extend(  # replace with random tokens
                np.random.choice(tokenizer._vocab_size - 1, size=len(ids)) + 1
            )
            target.extend(ids)
        else:
            source.extend(ids)
            target.extend([0] * len(ids))
    source = source[:maxlen - 1] + [tokenizer._token_end_id]
    target = target[:maxlen - 1] + [0]
    return source, target
# ========== For distillation: begin ==========

# SimBERT (teacher) configuration
sim_config_path = '/root/kg/bert/chinese_simbert_L-12_H-768_A-12/bert_config.json'
sim_checkpoint_path = '/root/kg/bert/chinese_simbert_L-12_H-768_A-12/bert_model.ckpt'
sim_dict_path = '/root/kg/bert/chinese_simbert_L-12_H-768_A-12/vocab.txt'

# Build the teacher tokenizer
sim_tokenizer = Tokenizer(sim_dict_path, do_lower_case=True)

# Build and load the teacher model
simbert = build_transformer_model(
    sim_config_path,
    sim_checkpoint_path,
    with_pool='linear',
    application='unilm',
    return_keras_model=False,
)

sim_encoder = keras.models.Model(simbert.model.inputs, simbert.model.outputs[0])

# ========== For distillation: end ==========


class data_generator(DataGenerator):
    """Data generator
    """
    def __init__(self, *args, **kwargs):
        super(data_generator, self).__init__(*args, **kwargs)
        self.some_samples = []

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_sim_token_ids, batch_sim_segment_ids = [], []
        for is_end, (text, synonym) in self.sample(random):
            for i in range(2):
                if np.random.random() < 0.5:
                    text_ids = masked_encode(text)[0]  # corrupted input half the time
                else:
                    text_ids = tokenizer.encode(text)[0]
                synonym_ids = tokenizer.encode(synonym)[0][1:]
                truncate_sequences(maxlen * 2, -2, text_ids, synonym_ids)
                token_ids = text_ids + synonym_ids
                segment_ids = [0] * len(text_ids) + [1] * len(synonym_ids)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                # ==== for distillation: begin ====
                token_ids, segment_ids = sim_tokenizer.encode(text, maxlen=maxlen)
                batch_sim_token_ids.append(token_ids)
                batch_sim_segment_ids.append(segment_ids)
                # ==== for distillation: end ====
                self.some_samples.append(synonym)
                if len(self.some_samples) > 1000:
                    self.some_samples.pop(0)
                text, synonym = synonym, text  # swap so the pair is seen in both directions
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                # ==== for distillation: begin ====
                batch_sim_token_ids = sequence_padding(batch_sim_token_ids)
                batch_sim_segment_ids = sequence_padding(batch_sim_segment_ids)
                sim_vecs = sim_encoder.predict([batch_sim_token_ids, batch_sim_segment_ids])
                sim_vecs /= (sim_vecs**2).sum(axis=1, keepdims=True)**0.5
                sims = sim_vecs.dot(sim_vecs.T)  # teacher similarity matrix
                # ==== for distillation: end ====
                yield [batch_token_ids, batch_segment_ids, sims], None
                batch_token_ids, batch_segment_ids = [], []
                batch_sim_token_ids, batch_sim_segment_ids = [], []


class TotalLoss(Loss):
    """The loss has two parts: the seq2seq cross-entropy and an MSE that
    distils the teacher's similarity matrix.
    """
    def compute_loss(self, inputs, mask=None):
        loss1 = self.compute_loss_of_seq2seq(inputs, mask)
        loss2 = self.compute_loss_of_similarity(inputs, mask)
        self.add_metric(loss1, name='seq2seq_loss')
        self.add_metric(loss2, name='similarity_loss')
        return loss1 + loss2

    def compute_loss_of_seq2seq(self, inputs, mask=None):
        y_true, y_mask, _, y_pred, _ = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids mark exactly the tokens to predict
        y_pred = y_pred[:, :-1]  # predictions, shifted by one position
        loss = K.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss

    def compute_loss_of_similarity(self, inputs, mask=None):
        _, _, y_pred, _, y_true = inputs
        y_pred = K.l2_normalize(y_pred, axis=1)  # normalize sentence vectors
        similarities = K.dot(y_pred, K.transpose(y_pred))  # similarity matrix
        loss = 100 * K.mean((similarities - y_true)**2)  # MSE against the teacher, scaled by 100
        return loss

    def get_labels_of_similarity(self, y_pred):
        # kept from stage 1; unused here because the teacher provides soft targets
        idxs = K.arange(0, K.shape(y_pred)[0])
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = K.equal(idxs_1, idxs_2)
        labels = K.cast(labels, K.floatx())
        return labels
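
# A small numpy sketch of the distillation objective above (not part of the
# original script; call it by hand): the teacher's cosine-similarity matrix
# becomes a soft target, replacing the hard 0/1 in-batch labels of stage 1.
def _demo_distillation_loss(batch_size=4, dim=8):
    rng = np.random.RandomState(0)
    teacher = rng.randn(batch_size, dim)
    student = rng.randn(batch_size, dim)
    teacher /= (teacher**2).sum(axis=1, keepdims=True)**0.5
    student /= (student**2).sum(axis=1, keepdims=True)**0.5
    sims_t = teacher.dot(teacher.T)  # what data_generator feeds in as `sims`
    sims_s = student.dot(student.T)  # what the student model would predict
    print(100 * ((sims_s - sims_t)**2).mean())  # mirrors compute_loss_of_similarity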
# Build and load the model
roformer = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_pool='linear',
    with_mlm='linear',
    dropout_rate=0.2,
    ignore_invalid_weights=True
)

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])

sim_in = keras.layers.Input(shape=(None,))  # teacher similarity matrix fed by the generator
outputs = TotalLoss([2, 3])(roformer.inputs + roformer.outputs + [sim_in])
model = keras.models.Model(roformer.inputs + [sim_in], outputs)

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=1e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(seq2seq).predict([token_ids, segment_ids])

    def generate(self, text, n=1, topp=0.95):
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topp=topp)  # nucleus (top-p) random sampling
        return [tokenizer.decode(ids) for ids in output_ids]


synonyms_generator = SynonymsGenerator(
    start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen
)


def gen_synonyms(text, n=100, k=20):
    """Meaning: generate n paraphrases of text, then return the k most similar ones.
    Method: generate with the seq2seq half, then score and rank with the encoder half.
    Example:
    >>> gen_synonyms(u'微信和支付宝哪个好?')
    [
        u'微信和支付宝,哪个好?',
        u'微信和支付宝哪个好',
        u'支付宝和微信哪个好',
        u'支付宝和微信哪个好啊',
        u'微信和支付宝那个好用?',
        u'微信和支付宝哪个好用',
        u'支付宝和微信那个更好',
        u'支付宝和微信哪个好用',
        u'微信和支付宝用起来哪个好?',
        u'微信和支付宝选哪个好',
    ]
    """
    r = synonyms_generator.generate(text, n)
    r = [i for i in set(r) if i != text]
    r = [text] + r
    X, S = [], []
    for t in r:
        x, s = tokenizer.encode(t)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
    argsort = np.dot(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]
def just_show():
    """Show the effect on a few random samples.
    """
    some_samples = train_generator.some_samples
    S = [np.random.choice(some_samples) for i in range(3)]
    for s in S:
        try:
            print(u'Original sentence: %s' % s)
            print(u'Paraphrases:')
            print(gen_synonyms(s, 10, 10))
            print()
        except:
            pass


class Evaluate(keras.callbacks.Callback):
    """Evaluate and save the model.
    """
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('./latest_model_2.weights')
        # save the best model
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model_2.weights')
        # show some examples
        just_show()


if __name__ == '__main__':

    train_generator = data_generator(corpus(), batch_size)
    evaluator = Evaluate()

    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator],
        use_multiprocessing=False,
        workers=0  # the generator itself calls sim_encoder.predict, so keep it in the main process
    )

else:

    model.load_weights('./latest_model_2.weights')
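
# A minimal export sketch (an assumption, not part of the original script):
# after stage-2 training only the encoder half is needed for retrieval, so it
# can be saved on its own, e.g.
#
#     encoder.save_weights('./roformer_sim_stage2_encoder.weights')
--------------------------------------------------------------------------------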