├── .gitignore
├── README.md
├── data
│   ├── alias_schema
│   ├── ccf2019_corpus.json
│   ├── data_report.py
│   ├── other.txt
│   └── sougouqa_webqa_corpus.json
├── model_predict.py
├── model_train.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
.idea
data/data_from_ccf2019.py
data/alias_*
models
chinese-RoBERTa-wwm-ext
data/sougouqa_webqa_tagging_corpus.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Alias Discovery System (别名发现系统)

### Project Overview

Alias discovery has important applications in everyday life, knowledge graphs, and search and recommendation.

This project uses deep learning to discover aliases and short names (abbreviations) in unstructured text, for example:

```
卡定沟:又名嘎定沟,位于西藏318国道拉萨至林芝段距八一镇24公里处,海拔2980米,地处雅鲁藏布江支流尼洋河畔
```

The discovered alias triple is `(卡定沟, 别名, 嘎定沟)`.

This project provides an annotated corpus for alias discovery together with a corresponding deep learning model: a relation extraction model is used as the alias discovery model.

### Corpus

The main contribution of this project is a high-quality, manually annotated alias corpus. `The project was started on 11 August 2021 and is maintained long-term`; since it is maintained by a single person, updates come slowly.

The corpus sources are listed below (the source corpora have already been processed programmatically by the author, so the table only indicates where the alias corpus originates):

|File|Data source|Annotated samples|
|---|---|---|
|data/ccf2019_corpus.json|CCF 2019 relation extraction competition data|3369|
|data/sougouqa_webqa_corpus.json|Reading comprehension data (WebQA & SougouQA)|1056|
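
Each record pairs a piece of text with its alias triples: a `text` field plus an `spo` field holding `[subject, predicate, object]` triples. The corpus files store a JSON list of such records; `data/data_report.py` re-shards them into line-delimited `alias_train.json` / `alias_test.json`, which `model_train.py` loads. A representative record (illustrative, built from the example sentence above) looks like:

```
{"text": "卡定沟:又名嘎定沟,位于西藏318国道拉萨至林芝段距八一镇24公里处,海拔2980米,地处雅鲁藏布江支流尼洋河畔", "spo": [["卡定沟", "别名", "嘎定沟"]]}
```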

### Model Training

Train : test split = 8 : 2

maxlen=200, batch_size=16, epochs=20, trained on Google Colab (pre-trained model: HIT's Chinese RoBERTa, `chinese-RoBERTa-wwm-ext`); the F1 score on the test set is 88.97%.

### Model Prediction

Example predictions from the trained model:

```
杨桃,又名阳桃、羊桃、五棱子,学名“五敛子”,又因横切面如五角星,故国外又称之为“星梨”。 杨桃原产地,传统认为产於东南亚的马来西亚等地。我国於汉朝就有栽培记载,今我国福建、广东、广西、云南等地;亚洲东南亚、印度;美洲巴西等热带地区均普遍栽培。...

Prediction:

[("杨桃", "别名", "阳桃"), ("杨桃", "别名", "羊桃"), ("杨桃", "别名", "五棱子"), ("杨桃", "别名", "五敛子"), ("杨桃", "别名", "星梨")]
```

```
关羽去世后,逐渐被神化,被民间尊为“关公”,又称美髯公。

Prediction:

[("关羽", "别名", "关公"), ("关羽", "别名", "美髯公")]
```

```
英吉利海峡隧道(thechanneltunnel)又称英法海底隧道或欧洲隧道(eurotunnel),是一条把英国英伦三岛连接往欧洲法国的铁路隧道,于1994年5月6日开通。

Prediction:

[("英吉利海峡隧道", "别名", "英法海底隧道"), ("英吉利海峡隧道", "别名", "欧洲隧道")]
```

### Notes

Maintaining the corpus is time-consuming, so some errors may remain; if you find any, please point them out.
--------------------------------------------------------------------------------
/data/alias_schema:
--------------------------------------------------------------------------------
{"object_type": "实体1", "predicate": "别名", "subject_type": "实体2"}
{"object_type": "实体1", "predicate": "简称", "subject_type": "实体2"}
--------------------------------------------------------------------------------
/data/data_report.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021/9/2 13:46
# @Author : Jclian91
# @File : data_report.py
# @Place : Yangpu, Shanghai
import json
from random import shuffle

with open("ccf2019_corpus.json", "r", encoding="utf-8") as f:
    content = json.loads(f.read())

with open("sougouqa_webqa_corpus.json", "r", encoding="utf-8") as f:
    content.extend([_ for _ in json.loads(f.read()) if not _["spo"] or _["spo"][0][0]])

# Simple statistics over the merged corpus
print(f"共有{len(content)}条标注样本")
print(len(content) - 3369)  # samples beyond the 3369 from the CCF2019 corpus
print("data review for last 10 samples: ")
for example in content[-10:]:
    print(example)

# Split the data into train and test sets with an 8:2 ratio
shuffle(content)
train_data = content[:int(len(content)*0.8)]
test_data = content[int(len(content)*0.8):]
with open("alias_train.json", "w", encoding="utf-8") as f:
    for _ in train_data:
        f.write(json.dumps(_, ensure_ascii=False)+"\n")
with open("alias_test.json", "w", encoding="utf-8") as f:
    for _ in test_data:
        f.write(json.dumps(_, ensure_ascii=False)+"\n")
--------------------------------------------------------------------------------
/data/other.txt:
--------------------------------------------------------------------------------
汉兵用绿旗,号称绿旗营,又称绿营。
次年各省举人赴北京礼部应考,称为会试。中式后赴宫中太和殿对答皇帝策问,称为廷试或殿试。
康熙三十一年(甲寅)三合会成立。三合会或称天地会,或称三点会。支派有清水会、匕首(小刀)会、双刀会等名目。
哥老会又称哥弟会,成立约在乾隆年间。
--------------------------------------------------------------------------------
/model_predict.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021/9/5 23:54
# @Author : Jclian91
# @File : model_predict.py
# @Place : Yangpu, Shanghai

import json
import numpy as np
from bert4keras.backend import K, batch_gather
from bert4keras.layers import Loss
from bert4keras.layers import LayerNormalization
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import open, to_array
from keras.layers import Input, Dense, Lambda, Reshape
from keras.models import Model


maxlen = 200
config_path = './chinese-RoBERTa-wwm-ext/bert_config.json'
checkpoint_path = './chinese-RoBERTa-wwm-ext/bert_model.ckpt'
dict_path = './chinese-RoBERTa-wwm-ext/vocab.txt'

predicate2id, id2predicate = {}, {}

with open('./data/alias_schema', "r", encoding="utf-8") as f:
    for l in f:
        l = json.loads(l)
        if l['predicate'] not in predicate2id:
            id2predicate[len(predicate2id)] = l['predicate']
            predicate2id[l['predicate']] = len(predicate2id)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Additional inputs (label tensors and the chosen subject span)
subject_labels = Input(shape=(None, 2), name='Subject-Labels')
subject_ids = Input(shape=(2,), name='Subject-Ids')
object_labels = Input(shape=(None, len(predicate2id), 2), name='Object-Labels')

# Load the pre-trained model
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
)

# Predict subjects
output = Dense(
    units=2, activation='sigmoid', kernel_initializer=bert.initializer
)(bert.model.output)
subject_preds = Lambda(lambda x: x**2)(output)

subject_model = Model(bert.model.inputs, subject_preds)


# Extract the subject's vector representation from output according to subject_ids
def extract_subject(inputs):
    output, subject_ids = inputs
    start = batch_gather(output, subject_ids[:, :1])
    end = batch_gather(output, subject_ids[:, 1:])
    subject = K.concatenate([start, end], 2)
    return subject[:, 0]


# Predict objects given the subject
# Conditional Layer Normalization injects the subject into object prediction
output = bert.model.layers[-2].get_output_at(-1)
subject = Lambda(extract_subject)([output, subject_ids])
output = LayerNormalization(conditional=True)([output, subject])
output = Dense(
    units=len(predicate2id) * 2,
    activation='sigmoid',
    kernel_initializer=bert.initializer
)(output)
output = Lambda(lambda x: x**4)(output)
object_preds = Reshape((-1, len(predicate2id), 2))(output)

object_model = Model(bert.model.inputs + [subject_ids], object_preds)
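
# Note on the two sub-models above: decoding happens in two stages at inference
# time. subject_model scores start/end positions of candidate subjects over the
# whole sequence; object_model is then evaluated once per detected subject
# (the subject span is fed in through subject_ids, batched via np.repeat) and
# scores object start/end positions for every predicate. extract_spoes() below
# chains the two calls together.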


class TotalLoss(Loss):
    """Sum of subject_loss and object_loss; both are binary cross-entropy.
    """
    def compute_loss(self, inputs, mask=None):
        subject_labels, object_labels = inputs[:2]
        subject_preds, object_preds, _ = inputs[2:]
        if mask[4] is None:
            mask = 1.0
        else:
            mask = K.cast(mask[4], K.floatx())
        # subject loss
        subject_loss = K.binary_crossentropy(subject_labels, subject_preds)
        subject_loss = K.mean(subject_loss, 2)
        subject_loss = K.sum(subject_loss * mask) / K.sum(mask)
        # object loss
        object_loss = K.binary_crossentropy(object_labels, object_preds)
        object_loss = K.sum(K.mean(object_loss, 3), 2)
        object_loss = K.sum(object_loss * mask) / K.sum(mask)
        # total loss
        return subject_loss + object_loss


subject_preds, object_preds = TotalLoss([2, 3])([
    subject_labels, object_labels, subject_preds, object_preds,
    bert.model.output
])

# Training model (rebuilt here so the saved weights can be loaded)
train_model = Model(
    bert.model.inputs + [subject_labels, subject_ids, object_labels],
    [subject_preds, object_preds]
)

AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(lr=1e-5)
train_model.compile(optimizer=optimizer)


train_model.load_weights('./models/alias_best_model.weights')


# Extract the (subject, predicate, object) triples contained in the input text
def extract_spoes(text):

    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    # Extract subjects
    subject_preds = subject_model.predict([token_ids, segment_ids])
    start = np.where(subject_preds[0, :, 0] > 0.6)[0]
    end = np.where(subject_preds[0, :, 1] > 0.5)[0]
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            j = j[0]
            subjects.append((i, j))
    if subjects:
        spoes = []
        token_ids = np.repeat(token_ids, len(subjects), 0)
        segment_ids = np.repeat(segment_ids, len(subjects), 0)
        subjects = np.array(subjects)
        # Feed in the subjects and extract objects and predicates
        object_preds = object_model.predict([token_ids, segment_ids, subjects])
        for subject, object_pred in zip(subjects, object_preds):
            start = np.where(object_pred[:, :, 0] > 0.6)
            end = np.where(object_pred[:, :, 1] > 0.5)
            for _start, predicate1 in zip(*start):
                for _end, predicate2 in zip(*end):
                    if _start <= _end and predicate1 == predicate2:
                        spoes.append(
                            ((mapping[subject[0]][0],
                              mapping[subject[1]][-1]), predicate1,
                             (mapping[_start][0], mapping[_end][-1]))
                        )
                        break
        return [(text[s[0]:s[1] + 1], id2predicate[p], text[o[0]:o[1] + 1])
                for s, p, o, in spoes]
    else:
        return []
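

# Illustrative single-text usage (the expected output below follows the README
# example; actual results depend on the trained weights loaded above):
#
#     extract_spoes("关羽去世后,逐渐被神化,被民间尊为“关公”,又称美髯公。")
#     # -> [('关羽', '别名', '关公'), ('关羽', '别名', '美髯公')]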


if __name__ == '__main__':
    # new_text = "亦庄开发区也称北京经济技术开发区,位于大兴区东北部地区;筹建于1991年,1992年开始建设并对外招商,1994年8月25日被国务院批准为国家级经济技术开发区。"
    # print(extract_spoes(new_text))
    with open("./data/sougouqa_webqa_corpus.json", "r", encoding="utf-8") as f:
        content = json.loads(f.read())

    # Model prediction
    g = open("sougouqa_webqa_predict.json", "a", encoding="utf-8")
    i = 0
    for line in content:
        if not line["spo"][0][0]:  # only records whose first triple still has an empty subject
            i += 1
            new_text = line["text"]
            spo = [list(_) for _ in extract_spoes(new_text)]
            print(i, spo, new_text)
            g.write(json.dumps({"spo": spo, "text": new_text},
                               ensure_ascii=False)+"\n")

    g.close()
--------------------------------------------------------------------------------
/model_train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import numpy as np
from bert4keras.backend import keras, K, batch_gather
from bert4keras.layers import Loss
from bert4keras.layers import LayerNormalization
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from keras.layers import Input, Dense, Lambda, Reshape
from keras.models import Model
from tqdm import tqdm

maxlen = 200
batch_size = 16
config_path = './chinese-RoBERTa-wwm-ext/bert_config.json'
checkpoint_path = './chinese-RoBERTa-wwm-ext/bert_model.ckpt'
dict_path = './chinese-RoBERTa-wwm-ext/vocab.txt'


def load_data(filename):
    """Load the data.
    Format of a single returned example: {'text': text, 'spo_list': [(s, p, o)]}
    """
    print("loading data for {}".format(filename))
    D = []
    with open(filename, "r", encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            D.append({
                'text': l['text'],
                'spo_list': [(spo[0], spo[1], spo[2]) for spo in l['spo']]
            })
    return D


# Load the datasets
train_data = load_data('./data/alias_train.json')
valid_data = load_data('./data/alias_test.json')
predicate2id, id2predicate = {}, {}

with open('./data/alias_schema', "r", encoding="utf-8") as f:
    for l in f:
        l = json.loads(l)
        if l['predicate'] not in predicate2id:
            id2predicate[len(predicate2id)] = l['predicate']
            predicate2id[l['predicate']] = len(predicate2id)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def search(pattern, sequence):
    """Search for the sub-sequence pattern inside sequence.
    Return the first index if found, otherwise -1.
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1


class data_generator(DataGenerator):
    """Data generator
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(d['text'], maxlen=maxlen)
            # Group the triples as {s: [(o, p)]}
            spoes = {}
            for s, p, o in d['spo_list']:
                s = tokenizer.encode(s)[0][1:-1]
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                s_idx = search(s, token_ids)
                o_idx = search(o, token_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels
                subject_labels = np.zeros((len(token_ids), 2))
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
                # Randomly pick one subject for this sample
                start, end = np.array(list(spoes.keys())).T
                start = np.random.choice(start)
                end = np.random.choice(end[end >= start])
                subject_ids = (start, end)
                # Object labels for the chosen subject
                object_labels = np.zeros((len(token_ids), len(predicate2id), 2))
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                # Build the batch
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
                if len(batch_token_ids) == self.batch_size or is_end:
                    batch_token_ids = sequence_padding(batch_token_ids)
                    batch_segment_ids = sequence_padding(batch_segment_ids)
                    batch_subject_labels = sequence_padding(
                        batch_subject_labels
                    )
                    batch_subject_ids = np.array(batch_subject_ids)
                    batch_object_labels = sequence_padding(batch_object_labels)
                    yield [
                        batch_token_ids, batch_segment_ids,
                        batch_subject_labels, batch_subject_ids,
                        batch_object_labels
                    ], None
                    batch_token_ids, batch_segment_ids = [], []
                    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []


def extract_subject(inputs):
    """Extract the subject's vector representation from output according to subject_ids.
    """
    output, subject_ids = inputs
    start = batch_gather(output, subject_ids[:, :1])
    end = batch_gather(output, subject_ids[:, 1:])
    subject = K.concatenate([start, end], 2)
    return subject[:, 0]


# Additional inputs (label tensors and the chosen subject span)
subject_labels = Input(shape=(None, 2), name='Subject-Labels')
subject_ids = Input(shape=(2,), name='Subject-Ids')
object_labels = Input(shape=(None, len(predicate2id), 2), name='Object-Labels')

# Load the pre-trained model
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False
)

# Predict subjects
output = Dense(
    units=2, activation='sigmoid', kernel_initializer=bert.initializer
)(bert.model.output)
subject_preds = Lambda(lambda x: x**2)(output)

subject_model = Model(bert.model.inputs, subject_preds)

# Predict objects given the subject
# Conditional Layer Normalization injects the subject into object prediction
output = bert.model.layers[-2].get_output_at(-1)
subject = Lambda(extract_subject)([output, subject_ids])
output = LayerNormalization(conditional=True)([output, subject])
output = Dense(
    units=len(predicate2id) * 2,
    activation='sigmoid',
    kernel_initializer=bert.initializer
)(output)
output = Lambda(lambda x: x**4)(output)
object_preds = Reshape((-1, len(predicate2id), 2))(output)

object_model = Model(bert.model.inputs + [subject_ids], object_preds)


class TotalLoss(Loss):
    """Sum of subject_loss and object_loss; both are binary cross-entropy.
    """
    def compute_loss(self, inputs, mask=None):
        subject_labels, object_labels = inputs[:2]
        subject_preds, object_preds, _ = inputs[2:]
        if mask[4] is None:
            mask = 1.0
        else:
            mask = K.cast(mask[4], K.floatx())
        # subject loss
        subject_loss = K.binary_crossentropy(subject_labels, subject_preds)
        subject_loss = K.mean(subject_loss, 2)
        subject_loss = K.sum(subject_loss * mask) / K.sum(mask)
        # object loss
        object_loss = K.binary_crossentropy(object_labels, object_preds)
        object_loss = K.sum(K.mean(object_loss, 3), 2)
        object_loss = K.sum(object_loss * mask) / K.sum(mask)
        # total loss
        return subject_loss + object_loss


subject_preds, object_preds = TotalLoss([2, 3])([
    subject_labels, object_labels, subject_preds, object_preds,
    bert.model.output
])
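
# TotalLoss([2, 3]) registers the combined loss on the graph through bert4keras's
# Loss layer and passes inputs 2 and 3 (subject_preds, object_preds) through as
# its outputs, which is why train_model below is compiled without an explicit
# loss argument.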

# Training model
train_model = Model(
    bert.model.inputs + [subject_labels, subject_ids, object_labels],
    [subject_preds, object_preds]
)

AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(lr=1e-5)
train_model.compile(optimizer=optimizer)


def extract_spoes(text):
    """Extract the triples contained in the input text.
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    # Extract subjects
    subject_preds = subject_model.predict([token_ids, segment_ids])
    start = np.where(subject_preds[0, :, 0] > 0.6)[0]
    end = np.where(subject_preds[0, :, 1] > 0.5)[0]
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            j = j[0]
            subjects.append((i, j))
    if subjects:
        spoes = []
        token_ids = np.repeat(token_ids, len(subjects), 0)
        segment_ids = np.repeat(segment_ids, len(subjects), 0)
        subjects = np.array(subjects)
        # Feed in the subjects and extract objects and predicates
        object_preds = object_model.predict([token_ids, segment_ids, subjects])
        for subject, object_pred in zip(subjects, object_preds):
            start = np.where(object_pred[:, :, 0] > 0.6)
            end = np.where(object_pred[:, :, 1] > 0.5)
            for _start, predicate1 in zip(*start):
                for _end, predicate2 in zip(*end):
                    if _start <= _end and predicate1 == predicate2:
                        spoes.append(
                            ((mapping[subject[0]][0],
                              mapping[subject[1]][-1]), predicate1,
                             (mapping[_start][0], mapping[_end][-1]))
                        )
                        break
        return [(text[s[0]:s[1] + 1], id2predicate[p], text[o[0]:o[1] + 1])
                for s, p, o, in spoes]
    else:
        return []


class SPO(tuple):
    """Class used to store a triple.
    It behaves almost exactly like a tuple, but overrides __hash__ and __eq__
    so that comparing two triples for equivalence is more tolerant.
    """
    def __init__(self, spo):
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return self.spox.__hash__()

    def __eq__(self, spo):
        return self.spox == spo.spox


def evaluate(data):
    """Evaluation function: computes f1, precision and recall.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('dev_pred.json', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        R = set([SPO(spo) for spo in extract_spoes(d['text'])])
        T = set([SPO(spo) for spo in d['spo_list']])
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d['text'],
            'spo_list': list(T),
            'spo_list_pred': list(R),
            'new': list(R - T),
            'lack': list(T - R),
        },
                       ensure_ascii=False,
                       indent=4)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall


class Evaluator(keras.callbacks.Callback):
    """Evaluate on the test set and save the best model.
    """
    def __init__(self):
        self.best_val_f1 = 0.

    def on_epoch_end(self, epoch, logs=None):
        optimizer.apply_ema_weights()
        f1, precision, recall = evaluate(valid_data)
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            train_model.save_weights('./models/alias_best_model.weights')
        optimizer.reset_old_weights()
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )


if __name__ == '__main__':
    train_generator = data_generator(train_data, batch_size)
    evaluator = Evaluator()

    train_model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=20,
        callbacks=[evaluator]
    )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==1.14.0
tensorflow-gpu==1.14.0
Keras==2.2.4
bert4keras==0.10.7
--------------------------------------------------------------------------------