├── README.md
└── el.py

/README.md:
--------------------------------------------------------------------------------
# el-2019-final
Source code of the "科学空间" (kexue.fm) team for the 2019 Baidu entity linking competition ( https://biendata.com/competition/ccks_2019_el/ ).

An entity recognition and entity linking model based on BiLSTM, Attention and hand-crafted features.

## Environment
Python 2.7 + Keras 2.2.4 + Tensorflow 1.8. The dependency that matters most is Python 2.7: if you use Python 3, a few lines need to be changed; which lines exactly is for you to work out yourself, I am not your debugger.

Welcome down the Keras rabbit hole. Life is short, I use Keras~

## Detailed write-up
https://kexue.fm/archives/6919

## Contact
QQ group: 67729435; for the WeChat group, add the bot account spaces_ac_cn.
--------------------------------------------------------------------------------
/el.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-

import json
from tqdm import tqdm
import os
import numpy as np
from random import choice
from itertools import groupby
from gensim.models import Word2Vec
import pyhanlp
from nlp_zero import Trie, DAG  # pip install nlp_zero
import re


mode = 0
min_count = 2
char_size = 128
num_features = 3


word2vec = Word2Vec.load('../../kg/word2vec_baike/word2vec_baike')
id2word = {i + 1: j for i, j in enumerate(word2vec.wv.index2word)}
word2id = {j: i for i, j in id2word.items()}
word2vec = word2vec.wv.syn0
word_size = word2vec.shape[1]
word2vec = np.concatenate([np.zeros((1, word_size)), word2vec])


def tokenize(s):
    """If pyhanlp does not work well for you, rewrite this tokenize
    function with your own word segmenter.
    """
    return [i.word for i in pyhanlp.HanLP.segment(s)]


def sent2vec(S):
    """S format: [[w1, w2]]
    """
    V = []
    for s in S:
        V.append([])
        for w in s:
            for _ in w:
                V[-1].append(word2id.get(w, 0))
    V = seq_padding(V)
    V = word2vec[V]
    return V


id2kb = {}
with open('../ccks2019_el/kb_data') as f:
    for l in tqdm(f):
        _ = json.loads(l)
        subject_id = _['subject_id']
        subject_alias = list(set([_['subject']] + _.get('alias', [])))
        subject_alias = [alias.lower() for alias in subject_alias]
        object_regex = set(
            [i['object'] for i in _['data'] if len(i['object']) <= 10]
        )
        object_regex = sorted(object_regex, key=lambda s: -len(s))
        object_regex = [re.escape(i) for i in object_regex]
        object_regex = re.compile('|'.join(object_regex))  # regex built up front, used to check whether an object value appears in the query
        _['data'].append({
            'predicate': u'名称',
            'object': u'、'.join(subject_alias)
        })
        subject_desc = '\n'.join(
            u'%s:%s' % (i['predicate'], i['object']) for i in _['data']
        )
        subject_desc = subject_desc.lower()
        id2kb[subject_id] = {
            'subject_alias': subject_alias,
            'subject_desc': subject_desc,
            'object_regex': object_regex
        }


kb2id = {}
trie = Trie()  # trie built from all entity names in the knowledge base

for i, j in id2kb.items():
    for k in j['subject_alias']:
        if k not in kb2id:
            kb2id[k] = []
            trie[k.strip(u'《》')] = 1
        kb2id[k].append(i)


def search_subjects(text_in):
    """Maximum matching over the text using the entity trie.
    """
    R = trie.search(text_in)
    dag = DAG(len(text_in))
    for i, j in R:
        dag[(i, j)] = -1
    S = {}
    for i, j in dag.optimal_path():
        if text_in[i:j] in kb2id:
            S[(i, j)] = text_in[i:j]
    return S


train_data = []

with open('../ccks2019_el/train.json') as f:
    for l in tqdm(f):
        _ = json.loads(l)
        train_data.append({
            'text': _['text'],
            'mention_data': [
                (x['mention'], int(x['offset']), x['kb_id'])
                for x in _['mention_data'] if x['kb_id'] != 'NIL'
            ]
        })
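# For reference, each line of train.json is assumed to look roughly like
# (field values below are made up for illustration):
#     {"text": "...", "mention_data": [{"mention": "...", "offset": "3", "kb_id": "10001"}, ...]}
# so every item appended to train_data has the form
#     {'text': ..., 'mention_data': [(mention, offset, kb_id), ...]}
# with NIL mentions (those without a KB entry) filtered out.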
if not os.path.exists('../all_chars_me.json'):
    chars = {}
    for d in tqdm(iter(id2kb.values())):
        for c in d['subject_desc']:
            chars[c] = chars.get(c, 0) + 1
    for d in tqdm(iter(train_data)):
        for c in d['text'].lower():
            chars[c] = chars.get(c, 0) + 1
    chars = {i: j for i, j in chars.items() if j >= min_count}
    id2char = {i + 2: j for i, j in enumerate(chars)}
    char2id = {j: i for i, j in id2char.items()}
    json.dump([id2char, char2id], open('../all_chars_me.json', 'w'))
else:
    id2char, char2id = json.load(open('../all_chars_me.json'))


# Prune the dictionary with training-set statistics to improve the precision of maximum matching
words_to_pred = {}
words_to_remove = {}
A, B, C = 1e-10, 1e-10, 1e-10
for d in train_data:
    R = set([(v, k[0]) for k, v in search_subjects(d['text']).items()])
    T = set([tuple(md[:2]) for md in d['mention_data']])
    A += len(R & T)
    B += len(R)
    C += len(T)
    R = set([i[0] for i in R])
    T = set([i[0] for i in T])
    for w in T:
        words_to_pred[w] = words_to_pred.get(w, 0) + 1
    for w in R - T:
        words_to_remove[w] = words_to_remove.get(w, 0) + 1


words = set(list(words_to_pred) + list(words_to_remove))
words = {
    i: (words_to_remove.get(i, 0) + 1.0) / (words_to_pred.get(i, 0) + 1.0)
    for i in words
}
words = {i: j for i, j in words.items() if j >= 5}

for w in words:
    del kb2id[w]
    trie[w] = 0


if not os.path.exists('../random_order_train.json'):
    random_order = range(len(train_data))
    np.random.shuffle(random_order)
    json.dump(random_order, open('../random_order_train.json', 'w'), indent=4)
else:
    random_order = json.load(open('../random_order_train.json'))


dev_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 == mode]
train_data = [train_data[j] for i, j in enumerate(random_order) if i % 9 != mode]


subjects = {}

for d in train_data:
    for md in d['mention_data']:
        if md[0] not in subjects:
            subjects[md[0]] = {}
        subjects[md[0]][md[2]] = subjects[md[0]].get(md[2], 0) + 1


candidate_links = {}

for k, v in subjects.items():
    for i, j in v.items():
        if j < 2:
            del v[i]
    if v:
        _ = set(v.keys()) & set(kb2id.get(k, []))
        if _:
            candidate_links[k] = list(_)


test_data = []

with open('../ccks2019_el/develop.json') as f:
    for l in tqdm(f):
        _ = json.loads(l)
        test_data.append(_)


def seq_padding(X, padding=0):
    L = [len(x) for x in X]
    ML = max(L)
    return np.array([
        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x
        for x in X
    ])


def isin_feature(text_a, text_b):
    y = np.zeros(len(''.join(text_a)))
    text_b = set(text_b)
    i = 0
    for w in text_a:
        if w in text_b:
            for c in w:
                y[i] = 1
                i += 1
        else:
            i += len(w)
    return y


def is_match_objects(text, object_regex):
    y = np.zeros(len(text))
    for i in object_regex.finditer(text):
        y[i.start():i.end()] = 1
    return y
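# Worked example of the two overlap features above (toy values, for
# illustration only):
#     isin_feature(u'abcd', u'bd')                 -> [0, 1, 0, 1]
#     isin_feature([u'ab', u'cd'], [u'cd', u'ef']) -> [0, 0, 1, 1]
#     is_match_objects(u'abcd', re.compile(u'cd')) -> [0, 0, 1, 1]
# i.e. isin_feature marks the character positions whose unit (character or
# word) also occurs in text_b, and is_match_objects marks the spans matched
# by a KB entry's precompiled object regex.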
class data_generator:
    def __init__(self, data, batch_size=64):
        self.data = data
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1
    def __len__(self):
        return self.steps
    def __iter__(self):
        while True:
            idxs = range(len(self.data))
            np.random.shuffle(idxs)
            X1, X2, X1V, X2V, S1, S2, PRES1, PRES2, Y, T = (
                [], [], [], [], [], [], [], [], [], []
            )
            for i in idxs:
                d = self.data[i]
                text = d['text'].lower()
                text_words = tokenize(text)
                text = ''.join(text_words)
                x1 = [char2id.get(c, 1) for c in text]
                s1, s2 = np.zeros(len(text)), np.zeros(len(text))
                mds = {}
                for md in d['mention_data']:
                    md = (md[0].lower(), md[1], md[2])
                    if md[0] in kb2id:
                        j1 = md[1]
                        j2 = j1 + len(md[0])
                        s1[j1] = 1
                        s2[j2 - 1] = 1
                        mds[(j1, j2)] = (md[0], md[2])
                if mds:
                    j1, j2 = choice(mds.keys())
                    y1 = np.zeros(len(text))
                    y1[j1:j2] = 1
                    x2 = choice(kb2id[mds[(j1, j2)][0]])
                    if x2 == mds[(j1, j2)][1]:
                        t = [1]
                    else:
                        t = [0]
                    object_regex = id2kb[x2]['object_regex']
                    x2 = id2kb[x2]['subject_desc']
                    x2_words = tokenize(x2)
                    x2 = ''.join(x2_words)
                    y2 = isin_feature(text, x2)
                    y3 = isin_feature(text_words, x2_words)
                    y4 = is_match_objects(text, object_regex)
                    y = np.vstack([y1, y2, y3, y4]).T
                    x2 = [char2id.get(c, 1) for c in x2]
                    pre_subjects = search_subjects(d['text'])
                    pres1, pres2 = np.zeros(len(text)), np.zeros(len(text))
                    for j1, j2 in pre_subjects:
                        pres1[j1] = 1
                        pres2[j2 - 1] = 1
                    X1.append(x1)
                    X2.append(x2)
                    X1V.append(text_words)
                    X2V.append(x2_words)
                    S1.append(s1)
                    S2.append(s2)
                    PRES1.append(pres1)
                    PRES2.append(pres2)
                    Y.append(y)
                    T.append(t)
                    if len(X1) == self.batch_size or i == idxs[-1]:
                        X1 = seq_padding(X1)
                        X2 = seq_padding(X2)
                        X1V = sent2vec(X1V)
                        X2V = sent2vec(X2V)
                        S1 = seq_padding(S1)
                        S2 = seq_padding(S2)
                        PRES1 = seq_padding(PRES1)
                        PRES2 = seq_padding(PRES2)
                        Y = seq_padding(Y, np.zeros(1 + num_features))
                        T = seq_padding(T)
                        yield [X1, X2, X1V, X2V, S1, S2, PRES1, PRES2, Y, T], None
                        X1, X2, X1V, X2V, S1, S2, PRES1, PRES2, Y, T = (
                            [], [], [], [], [], [], [], [], [], []
                        )
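# Contents of each batch yielded above (all sequences padded to the longest
# item in the batch):
#   X1, X2       : char ids of the query text / of a sampled candidate's description
#   X1V, X2V     : pre-trained word vectors repeated once per character (via sent2vec)
#   S1, S2       : 0/1 labels for gold mention start / end positions
#   PRES1, PRES2 : 0/1 flags for candidate starts / ends proposed by search_subjects
#   Y            : per-character features for the sampled mention: [selected-span
#                  indicator, char overlap with the description, word overlap,
#                  object-regex match]
#   T            : 1 if the sampled KB candidate is the gold link, else 0
# The second element of the yielded tuple is None because the losses are
# attached to the model via add_loss rather than passed as Keras targets.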
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.callbacks import Callback
from keras.optimizers import Adam


class Attention(Layer):
    """Multi-head attention.
    """
    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.out_dim = nb_head * size_per_head
        super(Attention, self).__init__(**kwargs)
    def build(self, input_shape):
        q_in_dim = input_shape[0][-1]
        k_in_dim = input_shape[1][-1]
        v_in_dim = input_shape[2][-1]
        self.q_kernel = self.add_weight(
            name='q_kernel',
            shape=(q_in_dim, self.out_dim),
            initializer='glorot_normal')
        self.k_kernel = self.add_weight(
            name='k_kernel',
            shape=(k_in_dim, self.out_dim),
            initializer='glorot_normal')
        self.v_kernel = self.add_weight(
            name='v_kernel',
            shape=(v_in_dim, self.out_dim),
            initializer='glorot_normal')
    def mask(self, x, mask, mode='mul'):
        if mask is None:
            return x
        for _ in range(K.ndim(x) - K.ndim(mask)):
            mask = K.expand_dims(mask, K.ndim(mask))
        if mode == 'mul':
            return x * mask
        else:
            return x - (1 - mask) * 1e10
    def call(self, inputs):
        q, k, v = inputs[:3]
        v_mask, q_mask = None, None
        if len(inputs) > 3:
            v_mask = inputs[3]
            if len(inputs) > 4:
                q_mask = inputs[4]
        # linear projections
        qw = K.dot(q, self.q_kernel)
        kw = K.dot(k, self.k_kernel)
        vw = K.dot(v, self.v_kernel)
        # reshape to separate the heads
        qw = K.reshape(qw, (-1, K.shape(qw)[1], self.nb_head, self.size_per_head))
        kw = K.reshape(kw, (-1, K.shape(kw)[1], self.nb_head, self.size_per_head))
        vw = K.reshape(vw, (-1, K.shape(vw)[1], self.nb_head, self.size_per_head))
        # move the head axis forward
        qw = K.permute_dimensions(qw, (0, 2, 1, 3))
        kw = K.permute_dimensions(kw, (0, 2, 1, 3))
        vw = K.permute_dimensions(vw, (0, 2, 1, 3))
        # attention
        a = K.batch_dot(qw, kw, [3, 3]) / self.size_per_head**0.5
        a = K.permute_dimensions(a, (0, 3, 2, 1))
        a = self.mask(a, v_mask, 'add')
        a = K.permute_dimensions(a, (0, 3, 2, 1))
        a = K.softmax(a)
        # assemble the output
        o = K.batch_dot(a, vw, [3, 2])
        o = K.permute_dimensions(o, (0, 2, 1, 3))
        o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
        o = self.mask(o, q_mask, 'mul')
        return o
    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.out_dim)


def seq_maxpool(x):
    """seq has shape [None, seq_len, s_size] and mask has shape
    [None, seq_len, 1]; padded positions are suppressed first,
    then max pooling is taken over the time axis.
    """
    seq, mask = x
    seq -= (1 - mask) * 1e10
    return K.max(seq, 1)


class MyBidirectional:
    """Hand-rolled bidirectional RNN wrapper that takes an explicit mask,
    so the forward and backward passes stay aligned.
    """
    def __init__(self, layer):
        self.forward_layer = layer.__class__.from_config(layer.get_config())
        self.backward_layer = layer.__class__.from_config(layer.get_config())
        self.forward_layer.name = 'forward_' + self.forward_layer.name
        self.backward_layer.name = 'backward_' + self.backward_layer.name
    def reverse_sequence(self, inputs):
        """mask has shape [batch_size, seq_len, 1] here.
        """
        x, mask = inputs
        seq_len = K.round(K.sum(mask, 1)[:, 0])
        seq_len = K.cast(seq_len, 'int32')
        return K.tf.reverse_sequence(x, seq_len, seq_dim=1)
    def __call__(self, inputs):
        x, mask = inputs
        x_forward = self.forward_layer(x)
        x_backward = Lambda(self.reverse_sequence)([x, mask])
        x_backward = self.backward_layer(x_backward)
        x_backward = Lambda(self.reverse_sequence)([x_backward, mask])
        x = Concatenate()([x_forward, x_backward])
        x = Lambda(lambda x: x[0] * x[1])([x, mask])
        return x
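# Masking convention shared by Attention, seq_maxpool and MyBidirectional:
# masks have shape [batch_size, seq_len, 1], with 1 at real tokens and 0 at
# padding (see x1_mask / x2_mask below). mode='mul' zeroes out padded
# positions, while mode='add' pushes them down by about 1e10 so they vanish
# after the softmax; seq_maxpool subtracts 1e10 at padded positions for the
# same reason before taking the max over the time axis.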
x1_in = Input(shape=(None,))
x2_in = Input(shape=(None,))
x1v_in = Input(shape=(None, word_size))
x2v_in = Input(shape=(None, word_size))
s1_in = Input(shape=(None,))
s2_in = Input(shape=(None,))
pres1_in = Input(shape=(None,))
pres2_in = Input(shape=(None,))
y_in = Input(shape=(None, 1 + num_features))
t_in = Input(shape=(1,))

x1, x2, x1v, x2v, s1, s2, pres1, pres2, y, t = (
    x1_in, x2_in, x1v_in, x2v_in, s1_in, s2_in, pres1_in, pres2_in, y_in, t_in
)

x1_mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(x1)
x2_mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(x2)

embedding = Embedding(len(id2char) + 2, char_size)
dense = Dense(char_size, use_bias=False)

x1 = embedding(x1)
x1v = dense(x1v)
x1 = Add()([x1, x1v])
x1 = Dropout(0.2)(x1)

pres1 = Lambda(lambda x: K.expand_dims(x, 2))(pres1)
pres2 = Lambda(lambda x: K.expand_dims(x, 2))(pres2)
x1 = Concatenate()([x1, pres1, pres2])
x1 = Lambda(lambda x: x[0] * x[1])([x1, x1_mask])

x1 = MyBidirectional(CuDNNLSTM(char_size // 2, return_sequences=True))([x1, x1_mask])

h = Conv1D(char_size, 3, activation='relu', padding='same')(x1)
ps1 = Dense(1, activation='sigmoid')(h)
ps2 = Dense(1, activation='sigmoid')(h)
ps1 = Lambda(lambda x: x[0] * x[1])([ps1, pres1])  # multiplying by the max-matching flags keeps only entities proposed by the dictionary match
ps2 = Lambda(lambda x: x[0] * x[1])([ps2, pres2])  # multiplying by the max-matching flags keeps only entities proposed by the dictionary match

s_model = Model([x1_in, x1v_in, pres1_in, pres2_in], [ps1, ps2])


x1 = Concatenate()([x1, y])
x1 = MyBidirectional(CuDNNLSTM(char_size // 2, return_sequences=True))([x1, x1_mask])
ys = Lambda(lambda x: K.sum(x[0] * x[1][..., :1], 1) / K.sum(x[1][..., :1], 1))([x1, y])

x2 = embedding(x2)
x2v = dense(x2v)
x2 = Add()([x2, x2v])
x2 = Dropout(0.2)(x2)
x2 = Lambda(lambda x: x[0] * x[1])([x2, x2_mask])
x2 = MyBidirectional(CuDNNLSTM(char_size // 2, return_sequences=True))([x2, x2_mask])

x12 = Attention(8, 16)([x1, x2, x2, x2_mask, x1_mask])
x12 = Lambda(seq_maxpool)([x12, x1_mask])
x21 = Attention(8, 16)([x2, x1, x1, x1_mask, x2_mask])
x21 = Lambda(seq_maxpool)([x21, x2_mask])
x = Concatenate()([x12, x21, ys])
x = Dropout(0.2)(x)
x = Dense(char_size, activation='relu')(x)
pt = Dense(1, activation='sigmoid')(x)

t_model = Model([x1_in, x2_in, x1v_in, x2v_in, pres1_in, pres2_in, y_in], pt)


train_model = Model(
    [x1_in, x2_in, x1v_in, x2v_in, s1_in, s2_in, pres1_in, pres2_in, y_in, t_in],
    [ps1, ps2, pt]
)

s1 = K.expand_dims(s1, 2)
s2 = K.expand_dims(s2, 2)
s1_loss = K.binary_crossentropy(s1, ps1)
s1_loss = K.sum(s1_loss * x1_mask) / K.sum(x1_mask)
s2_loss = K.binary_crossentropy(s2, ps2)
s2_loss = K.sum(s2_loss * x1_mask) / K.sum(x1_mask)
pt_loss = K.mean(K.binary_crossentropy(t, pt))
loss = s1_loss + s2_loss + pt_loss

train_model.add_loss(loss)
train_model.compile(optimizer=Adam(1e-3))
train_model.summary()
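# The three models built above, for orientation:
#   s_model     : mention detector; given the query text and the maximum-matching
#                 flags it predicts start/end probabilities ps1/ps2, already
#                 filtered to dictionary candidates by the multiplications above.
#   t_model     : linker; given the query, one candidate entity description and
#                 the hand-crafted features y, it outputs the matching score pt.
#   train_model : wraps both and is trained jointly with
#                 loss = masked BCE(s1, ps1) + masked BCE(s2, ps2) + BCE(t, pt),
#                 attached via add_loss, which is why compile() gets no loss argument.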
class ExponentialMovingAverage:
    """Exponential moving average of the model weights.
    Usage: after model.compile and before the first training step,
    create the object and then call its inject method.
    """
    def __init__(self, model, momentum=0.9999):
        self.momentum = momentum
        self.model = model
        self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights]
    def inject(self):
        """Append the EMA update ops to model.metrics_updates.
        """
        self.initialize()
        for w1, w2 in zip(self.ema_weights, self.model.weights):
            op = K.moving_average_update(w1, w2, self.momentum)
            self.model.metrics_updates.append(op)
    def initialize(self):
        """Initialize ema_weights with the current model weights.
        """
        self.old_weights = K.batch_get_value(self.model.weights)
        K.batch_set_value(zip(self.ema_weights, self.old_weights))
    def apply_ema_weights(self):
        """Back up the current weights, then load the averaged weights into the model.
        """
        self.old_weights = K.batch_get_value(self.model.weights)
        ema_weights = K.batch_get_value(self.ema_weights)
        K.batch_set_value(zip(self.model.weights, ema_weights))
    def reset_old_weights(self):
        """Restore the model to the backed-up weights.
        """
        K.batch_set_value(zip(self.model.weights, self.old_weights))


EMAer = ExponentialMovingAverage(train_model)
EMAer.inject()


def extract_items(text_in):
    text_words = tokenize(text_in)
    text_old = ''.join(text_words)
    text_in = text_old.lower()
    _x1 = [char2id.get(c, 1) for c in text_in]
    _x1 = np.array([_x1])
    _x1v = sent2vec([text_words])
    pre_subjects = search_subjects(text_in)
    _pres1, _pres2 = np.zeros([1, len(text_in)]), np.zeros([1, len(text_in)])
    for j1, j2 in pre_subjects:
        _pres1[0, j1] = 1
        _pres2[0, j2 - 1] = 1
    _k1, _k2 = s_model.predict([_x1, _x1v, _pres1, _pres2])
    _k1, _k2 = _k1[0, :, 0], _k2[0, :, 0]
    _k1, _k2 = np.where(_k1 > 0.4)[0], np.where(_k2 > 0.4)[0]
    _subjects = []
    for i in _k1:
        j = _k2[_k2 >= i]
        if len(j) > 0:
            j = j[0]
            _subject = text_in[i:j + 1]
            _subjects.append((_subject, i, j + 1))
    if _subjects:
        R = []
        _X2, _X2V, _Y = [], [], []
        _S, _IDXS = [], {}
        for _s in _subjects:
            _y1 = np.zeros(len(text_in))
            _y1[_s[1]:_s[2]] = 1
            if _s[0] in candidate_links:
                _IDXS[_s] = candidate_links.get(_s[0], [])
            else:
                _IDXS[_s] = kb2id.get(_s[0], [])
            for i in _IDXS[_s]:
                object_regex = id2kb[i]['object_regex']
                _x2 = id2kb[i]['subject_desc']
                _x2_words = tokenize(_x2)
                _x2 = ''.join(_x2_words)
                _y2 = isin_feature(text_in, _x2)
                _y3 = isin_feature(text_words, _x2_words)
                _y4 = is_match_objects(text_in, object_regex)
                _y = np.vstack([_y1, _y2, _y3, _y4]).T
                _x2 = [char2id.get(c, 1) for c in _x2]
                _X2.append(_x2)
                _X2V.append(_x2_words)
                _Y.append(_y)
                _S.append(_s)
        if _X2:
            _X2 = seq_padding(_X2)
            _X2V = sent2vec(_X2V)
            _Y = seq_padding(_Y, np.zeros(1 + num_features))
            _X1 = np.repeat(_x1, len(_X2), 0)
            _X1V = np.repeat(_x1v, len(_X2), 0)
            _PRES1 = np.repeat(_pres1, len(_X2), 0)
            _PRES2 = np.repeat(_pres2, len(_X2), 0)
            scores = t_model.predict([_X1, _X2, _X1V, _X2V, _PRES1, _PRES2, _Y])[:, 0]
            for k, v in groupby(zip(_S, scores), key=lambda s: s[0]):
                ks = _IDXS[k]
                vs = [j[1] for j in v]
                if np.max(vs) < 0.1:
                    continue
                kbid = ks[np.argmax(vs)]
                R.append((text_old[k[1]:k[2]], k[1], kbid))
        return R
    else:
        return []


def test(test_data):
    F = open('result.json', 'w')
    for d in tqdm(iter(test_data)):
        d['mention_data'] = [
            dict(zip(['mention', 'offset', 'kb_id'], [md[0], str(md[1]), md[2]]))
            for md in set(extract_items(d['text']))
        ]
        F.write(json.dumps(d, ensure_ascii=False).encode('utf-8') + '\n')
    F.close()
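# test() writes one JSON object per line to result.json; each record keeps the
# fields of the corresponding develop.json line and replaces 'mention_data'
# with entries of the form (values illustrative only):
#     {"mention": "...", "offset": "3", "kb_id": "10001"}
# i.e. the offset is serialised back to a string, mirroring the input format.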
class Evaluate(Callback):
    def __init__(self):
        self.F1 = []
        self.best = 0.0
    def on_epoch_end(self, epoch, logs=None):
        EMAer.apply_ema_weights()
        f1, precision, recall = self.evaluate()
        self.F1.append(f1)
        if f1 > self.best:
            self.best = f1
            train_model.save_weights('best_model.weights')
        print 'f1: %.4f, precision: %.4f, recall: %.4f, best f1: %.4f\n' % (f1, precision, recall, self.best)
        EMAer.reset_old_weights()
    def evaluate(self):
        A, B, C = 1e-10, 1e-10, 1e-10
        F = open('dev_pred.json', 'w')
        pbar = tqdm()
        for d in dev_data:
            R = set(extract_items(d['text']))
            T = set(d['mention_data'])
            A += len(R & T)
            B += len(R)
            C += len(T)
            s = json.dumps(
                {
                    'text': d['text'],
                    'mention_data': list(T),
                    'mention_data_pred': list(R),
                    'new': list(R - T),
                    'lack': list(T - R)
                },
                ensure_ascii=False,
                indent=4)
            F.write(s.encode('utf-8') + '\n')
            pbar.update(1)
            f1, pr, rc = 2 * A / (B + C), A / B, A / C
            pbar.set_description('< f1: %.4f, precision: %.4f, recall: %.4f >' % (f1, pr, rc))
        F.close()
        pbar.close()
        return (2 * A / (B + C), A / B, A / C)


evaluator = Evaluate()
train_D = data_generator(train_data)

if __name__ == '__main__':
    train_model.fit_generator(
        train_D.__iter__(),
        steps_per_epoch=len(train_D),
        epochs=120,
        callbacks=[evaluator]
    )
else:
    train_model.load_weights('best_model.weights')
--------------------------------------------------------------------------------