├── README.md
├── PeopleDaily_CRF.py
├── PeopleDaily_GlobalPointer.py
├── CMeEE_CRF.py
├── CMeEE_GlobalPointer.py
├── CLUENER_CRF.py
└── CLUENER_GlobalPointer.py

/README.md:
--------------------------------------------------------------------------------
1 | # GlobalPointer
2 | A global-pointer formulation that handles nested and non-nested (flat) NER in a unified way.
3 | 
4 | ## Introduction
5 | 
6 | - Blog post: https://kexue.fm/archives/8373
7 | 
8 | ## Results
9 | 
10 | ### People's Daily NER
11 | 
12 | |  | Dev F1 | Test F1 | Training speed | Inference speed |
13 | | :-: | :-: | :-: | :-: | :-: |
14 | | CRF | 96.39% | 95.46% | 1x | 1x |
15 | | GlobalPointer (w/o RoPE) | 54.35% | 62.59% | 1.61x | 1.13x |
16 | | GlobalPointer (w/ RoPE) | 96.25% | 95.51% | 1.56x | 1.11x |
17 | 
18 | ### CLUENER
19 | 
20 | |  | Dev F1 | Test F1 | Training speed | Inference speed |
21 | | :-: | :-: | :-: | :-: | :-: |
22 | | CRF | 79.51% | 78.70% | 1x | 1x |
23 | | GlobalPointer | 80.03% | 79.44% | 1.22x | 1x |
24 | 
25 | ### CMeEE
26 | 
27 | |  | Dev F1 | Test F1 | Training speed | Inference speed |
28 | | :-: | :-: | :-: | :-: | :-: |
29 | | CRF | 63.81% | 64.39% | 1x | 1x |
30 | | GlobalPointer | 64.84% | 65.98% | 1.52x | 1.13x |
31 | 
32 | ## Environment
33 | 
34 | Requires `bert4keras >= 0.10.6`. The experiments were run with tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6.
35 | 
36 | ## Contact
37 | QQ group: 808623966; for the WeChat group, add the bot account spaces_ac_cn
38 | 
--------------------------------------------------------------------------------
/PeopleDaily_CRF.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # 用CRF做中文命名实体识别
3 | # 数据集 http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
4 | 
5 | import numpy as np
6 | from bert4keras.backend import keras, K
7 | from bert4keras.models import build_transformer_model
8 | from bert4keras.tokenizers import Tokenizer
9 | from bert4keras.optimizers import Adam
10 | from bert4keras.snippets import sequence_padding, DataGenerator
11 | from bert4keras.snippets import open, ViterbiDecoder, to_array
12 | from bert4keras.layers import ConditionalRandomField
13 | from keras.layers import Dense
14 | from keras.models import Model
15 | from tqdm import tqdm
16 | 
17 | maxlen = 256
18 | epochs = 10
19 | batch_size = 16
20 | learning_rate = 2e-5
21 | crf_lr_multiplier = 1000  # 必要时扩大CRF层的学习率
22 | categories = set()
23 | 
24 | # bert配置
25 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
26 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
27 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
28 | 
29 | 
30 | def load_data(filename):
31 |     """加载数据
32 |     单条格式:[text, (start, end, label), (start, end, label), ...],
33 |     意味着text[start:end + 1]是类型为label的实体。
34 |     """
35 |     D = []
36 |     with open(filename, encoding='utf-8') as f:
37 |         f = f.read()
38 |         for l in f.split('\n\n'):
39 |             if not l:
40 |                 continue
41 |             d = ['']
42 |             for i, c in enumerate(l.split('\n')):
43 |                 char, flag = c.split(' ')
44 |                 d[0] += char
45 |                 if flag[0] == 'B':
46 |                     d.append([i, i, flag[2:]])
47 |                     categories.add(flag[2:])
48 |                 elif flag[0] == 'I':
49 |                     d[-1][1] = i
50 |             D.append(d)
51 |     return D
52 | 
53 | 
54 | # 标注数据
55 | train_data = load_data('/root/ner/china-people-daily-ner-corpus/example.train')
56 | valid_data = load_data('/root/ner/china-people-daily-ner-corpus/example.dev')
57 | test_data = load_data('/root/ner/china-people-daily-ner-corpus/example.test')
58 | categories = list(sorted(categories))
59 | 
60 | # 建立分词器
61 | tokenizer = Tokenizer(dict_path, do_lower_case=True)
62 | 
63 | 
64 | class data_generator(DataGenerator):
65 |     """数据生成器
66 |     """
67 |     def __iter__(self, random=False):
68 |         batch_token_ids,
batch_segment_ids, batch_labels = [], [], [] 69 | for is_end, d in self.sample(random): 70 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 71 | mapping = tokenizer.rematch(d[0], tokens) 72 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 73 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 74 | token_ids = tokenizer.tokens_to_ids(tokens) 75 | segment_ids = [0] * len(token_ids) 76 | labels = np.zeros(len(token_ids)) 77 | for start, end, label in d[1:]: 78 | if start in start_mapping and end in end_mapping: 79 | start = start_mapping[start] 80 | end = end_mapping[end] 81 | labels[start] = categories.index(label) * 2 + 1 82 | labels[start + 1:end + 1] = categories.index(label) * 2 + 2 83 | batch_token_ids.append(token_ids) 84 | batch_segment_ids.append(segment_ids) 85 | batch_labels.append(labels) 86 | if len(batch_token_ids) == self.batch_size or is_end: 87 | batch_token_ids = sequence_padding(batch_token_ids) 88 | batch_segment_ids = sequence_padding(batch_segment_ids) 89 | batch_labels = sequence_padding(batch_labels) 90 | yield [batch_token_ids, batch_segment_ids], batch_labels 91 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 92 | 93 | 94 | model = build_transformer_model(config_path, checkpoint_path) 95 | output = Dense(len(categories) * 2 + 1)(model.output) 96 | CRF = ConditionalRandomField(lr_multiplier=crf_lr_multiplier) 97 | output = CRF(output) 98 | 99 | model = Model(model.input, output) 100 | model.summary() 101 | 102 | model.compile( 103 | loss=CRF.sparse_loss, 104 | optimizer=Adam(learning_rate), 105 | metrics=[CRF.sparse_accuracy] 106 | ) 107 | 108 | 109 | class NamedEntityRecognizer(ViterbiDecoder): 110 | """命名实体识别器 111 | """ 112 | def recognize(self, text): 113 | tokens = tokenizer.tokenize(text, maxlen=512) 114 | mapping = tokenizer.rematch(text, tokens) 115 | token_ids = tokenizer.tokens_to_ids(tokens) 116 | segment_ids = [0] * len(token_ids) 117 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 118 | nodes = model.predict([token_ids, segment_ids])[0] 119 | labels = self.decode(nodes) 120 | entities, starting = [], False 121 | for i, label in enumerate(labels): 122 | if label > 0: 123 | if label % 2 == 1: 124 | starting = True 125 | entities.append([[i], categories[(label - 1) // 2]]) 126 | elif starting: 127 | entities[-1][0].append(i) 128 | else: 129 | starting = False 130 | else: 131 | starting = False 132 | return [(mapping[w[0]][0], mapping[w[-1]][-1], l) for w, l in entities] 133 | 134 | 135 | NER = NamedEntityRecognizer(trans=K.eval(CRF.trans), starts=[0], ends=[0]) 136 | 137 | 138 | def evaluate(data): 139 | """评测函数 140 | """ 141 | X, Y, Z = 1e-10, 1e-10, 1e-10 142 | for d in tqdm(data, ncols=100): 143 | R = set(NER.recognize(d[0])) 144 | T = set([tuple(i) for i in d[1:]]) 145 | X += len(R & T) 146 | Y += len(R) 147 | Z += len(T) 148 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 149 | return f1, precision, recall 150 | 151 | 152 | class Evaluator(keras.callbacks.Callback): 153 | """评估与保存 154 | """ 155 | def __init__(self): 156 | self.best_val_f1 = 0 157 | 158 | def on_epoch_end(self, epoch, logs=None): 159 | trans = K.eval(CRF.trans) 160 | NER.trans = trans 161 | print(NER.trans) 162 | f1, precision, recall = evaluate(valid_data) 163 | # 保存最优 164 | if f1 >= self.best_val_f1: 165 | self.best_val_f1 = f1 166 | model.save_weights('./best_model_peopledaily_crf.weights') 167 | print( 168 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 169 | (f1, precision, recall, 
self.best_val_f1) 170 | ) 171 | f1, precision, recall = evaluate(test_data) 172 | print( 173 | 'test: f1: %.5f, precision: %.5f, recall: %.5f\n' % 174 | (f1, precision, recall) 175 | ) 176 | 177 | 178 | if __name__ == '__main__': 179 | 180 | evaluator = Evaluator() 181 | train_generator = data_generator(train_data, batch_size) 182 | 183 | model.fit( 184 | train_generator.forfit(), 185 | steps_per_epoch=len(train_generator), 186 | epochs=epochs, 187 | callbacks=[evaluator] 188 | ) 189 | 190 | else: 191 | 192 | model.load_weights('./best_model_peopledaily_crf.weights') 193 | NER.trans = K.eval(CRF.trans) 194 | -------------------------------------------------------------------------------- /PeopleDaily_GlobalPointer.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 用GlobalPointer做中文命名实体识别 3 | # 数据集 http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz 4 | 5 | import numpy as np 6 | from bert4keras.backend import keras, K 7 | from bert4keras.backend import multilabel_categorical_crossentropy 8 | from bert4keras.layers import GlobalPointer 9 | from bert4keras.models import build_transformer_model 10 | from bert4keras.tokenizers import Tokenizer 11 | from bert4keras.optimizers import Adam 12 | from bert4keras.snippets import sequence_padding, DataGenerator 13 | from bert4keras.snippets import open, to_array 14 | from keras.models import Model 15 | from tqdm import tqdm 16 | 17 | maxlen = 256 18 | epochs = 10 19 | batch_size = 16 20 | learning_rate = 2e-5 21 | categories = set() 22 | 23 | # bert配置 24 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 25 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 26 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 27 | 28 | 29 | def load_data(filename): 30 | """加载数据 31 | 单条格式:[text, (start, end, label), (start, end, label), ...], 32 | 意味着text[start:end + 1]是类型为label的实体。 33 | """ 34 | D = [] 35 | with open(filename, encoding='utf-8') as f: 36 | f = f.read() 37 | for l in f.split('\n\n'): 38 | if not l: 39 | continue 40 | d = [''] 41 | for i, c in enumerate(l.split('\n')): 42 | char, flag = c.split(' ') 43 | d[0] += char 44 | if flag[0] == 'B': 45 | d.append([i, i, flag[2:]]) 46 | categories.add(flag[2:]) 47 | elif flag[0] == 'I': 48 | d[-1][1] = i 49 | D.append(d) 50 | return D 51 | 52 | 53 | # 标注数据 54 | train_data = load_data('/root/ner/china-people-daily-ner-corpus/example.train') 55 | valid_data = load_data('/root/ner/china-people-daily-ner-corpus/example.dev') 56 | test_data = load_data('/root/ner/china-people-daily-ner-corpus/example.test') 57 | categories = list(sorted(categories)) 58 | 59 | # 建立分词器 60 | tokenizer = Tokenizer(dict_path, do_lower_case=True) 61 | 62 | 63 | class data_generator(DataGenerator): 64 | """数据生成器 65 | """ 66 | def __iter__(self, random=False): 67 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 68 | for is_end, d in self.sample(random): 69 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 70 | mapping = tokenizer.rematch(d[0], tokens) 71 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 72 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 73 | token_ids = tokenizer.tokens_to_ids(tokens) 74 | segment_ids = [0] * len(token_ids) 75 | labels = np.zeros((len(categories), maxlen, maxlen)) 76 | for start, end, label in d[1:]: 77 | if start in start_mapping and end in end_mapping: 78 | start = start_mapping[start] 79 | end = end_mapping[end] 80 | 
label = categories.index(label) 81 | labels[label, start, end] = 1 82 | batch_token_ids.append(token_ids) 83 | batch_segment_ids.append(segment_ids) 84 | batch_labels.append(labels[:, :len(token_ids), :len(token_ids)]) 85 | if len(batch_token_ids) == self.batch_size or is_end: 86 | batch_token_ids = sequence_padding(batch_token_ids) 87 | batch_segment_ids = sequence_padding(batch_segment_ids) 88 | batch_labels = sequence_padding(batch_labels, seq_dims=3) 89 | yield [batch_token_ids, batch_segment_ids], batch_labels 90 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 91 | 92 | 93 | def global_pointer_crossentropy(y_true, y_pred): 94 | """给GlobalPointer设计的交叉熵 95 | """ 96 | bh = K.prod(K.shape(y_pred)[:2]) 97 | y_true = K.reshape(y_true, (bh, -1)) 98 | y_pred = K.reshape(y_pred, (bh, -1)) 99 | return K.mean(multilabel_categorical_crossentropy(y_true, y_pred)) 100 | 101 | 102 | def global_pointer_f1_score(y_true, y_pred): 103 | """给GlobalPointer设计的F1 104 | """ 105 | y_pred = K.cast(K.greater(y_pred, 0), K.floatx()) 106 | return 2 * K.sum(y_true * y_pred) / K.sum(y_true + y_pred) 107 | 108 | 109 | model = build_transformer_model(config_path, checkpoint_path) 110 | output = GlobalPointer(len(categories), 64)(model.output) 111 | 112 | model = Model(model.input, output) 113 | model.summary() 114 | 115 | model.compile( 116 | loss=global_pointer_crossentropy, 117 | optimizer=Adam(learning_rate), 118 | metrics=[global_pointer_f1_score] 119 | ) 120 | 121 | 122 | class NamedEntityRecognizer(object): 123 | """命名实体识别器 124 | """ 125 | def recognize(self, text, threshold=0): 126 | tokens = tokenizer.tokenize(text, maxlen=512) 127 | mapping = tokenizer.rematch(text, tokens) 128 | token_ids = tokenizer.tokens_to_ids(tokens) 129 | segment_ids = [0] * len(token_ids) 130 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 131 | scores = model.predict([token_ids, segment_ids])[0] 132 | scores[:, [0, -1]] -= np.inf 133 | scores[:, :, [0, -1]] -= np.inf 134 | entities = [] 135 | for l, start, end in zip(*np.where(scores > threshold)): 136 | entities.append( 137 | (mapping[start][0], mapping[end][-1], categories[l]) 138 | ) 139 | return entities 140 | 141 | 142 | NER = NamedEntityRecognizer() 143 | 144 | 145 | def evaluate(data): 146 | """评测函数 147 | """ 148 | X, Y, Z = 1e-10, 1e-10, 1e-10 149 | for d in tqdm(data, ncols=100): 150 | R = set(NER.recognize(d[0])) 151 | T = set([tuple(i) for i in d[1:]]) 152 | X += len(R & T) 153 | Y += len(R) 154 | Z += len(T) 155 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 156 | return f1, precision, recall 157 | 158 | 159 | class Evaluator(keras.callbacks.Callback): 160 | """评估与保存 161 | """ 162 | def __init__(self): 163 | self.best_val_f1 = 0 164 | 165 | def on_epoch_end(self, epoch, logs=None): 166 | f1, precision, recall = evaluate(valid_data) 167 | # 保存最优 168 | if f1 >= self.best_val_f1: 169 | self.best_val_f1 = f1 170 | model.save_weights('./best_model_peopledaily_globalpointer.weights') 171 | print( 172 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 173 | (f1, precision, recall, self.best_val_f1) 174 | ) 175 | f1, precision, recall = evaluate(test_data) 176 | print( 177 | 'test: f1: %.5f, precision: %.5f, recall: %.5f\n' % 178 | (f1, precision, recall) 179 | ) 180 | 181 | 182 | if __name__ == '__main__': 183 | 184 | evaluator = Evaluator() 185 | train_generator = data_generator(train_data, batch_size) 186 | 187 | model.fit( 188 | train_generator.forfit(), 189 | steps_per_epoch=len(train_generator), 190 | 
epochs=epochs, 191 | callbacks=[evaluator] 192 | ) 193 | 194 | else: 195 | 196 | model.load_weights('./best_model_peopledaily_globalpointer.weights') 197 | -------------------------------------------------------------------------------- /CMeEE_CRF.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # 用CRF做中文命名实体识别 3 | # 数据集 https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 4 | 5 | import json 6 | import numpy as np 7 | from bert4keras.backend import keras, K 8 | from bert4keras.models import build_transformer_model 9 | from bert4keras.tokenizers import Tokenizer 10 | from bert4keras.optimizers import Adam 11 | from bert4keras.snippets import sequence_padding, DataGenerator 12 | from bert4keras.snippets import open, ViterbiDecoder, to_array 13 | from bert4keras.layers import ConditionalRandomField 14 | from keras.layers import Dense 15 | from keras.models import Model 16 | from tqdm import tqdm 17 | 18 | maxlen = 256 19 | epochs = 10 20 | batch_size = 16 21 | learning_rate = 2e-5 22 | crf_lr_multiplier = 1000 # 必要时扩大CRF层的学习率 23 | categories = set() 24 | 25 | # bert配置 26 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 27 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 28 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 29 | 30 | 31 | def load_data(filename): 32 | """加载数据 33 | 单条格式:[text, (start, end, label), (start, end, label), ...], 34 | 意味着text[start:end + 1]是类型为label的实体。 35 | """ 36 | D = [] 37 | for d in json.load(open(filename)): 38 | D.append([d['text']]) 39 | for e in d['entities']: 40 | start, end, label = e['start_idx'], e['end_idx'], e['type'] 41 | if start <= end: 42 | D[-1].append((start, end, label)) 43 | categories.add(label) 44 | return D 45 | 46 | 47 | # 标注数据 48 | train_data = load_data('/root/ner/CMeEE/CMeEE_train.json') 49 | valid_data = load_data('/root/ner/CMeEE/CMeEE_dev.json') 50 | categories = list(sorted(categories)) 51 | 52 | # 建立分词器 53 | tokenizer = Tokenizer(dict_path, do_lower_case=True) 54 | 55 | 56 | class data_generator(DataGenerator): 57 | """数据生成器 58 | """ 59 | def __iter__(self, random=False): 60 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 61 | for is_end, d in self.sample(random): 62 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 63 | mapping = tokenizer.rematch(d[0], tokens) 64 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 65 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 66 | token_ids = tokenizer.tokens_to_ids(tokens) 67 | segment_ids = [0] * len(token_ids) 68 | labels = np.zeros(len(token_ids)) 69 | for start, end, label in d[1:]: 70 | if start in start_mapping and end in end_mapping: 71 | start = start_mapping[start] 72 | end = end_mapping[end] 73 | labels[start] = categories.index(label) * 2 + 1 74 | labels[start + 1:end + 1] = categories.index(label) * 2 + 2 75 | batch_token_ids.append(token_ids) 76 | batch_segment_ids.append(segment_ids) 77 | batch_labels.append(labels) 78 | if len(batch_token_ids) == self.batch_size or is_end: 79 | batch_token_ids = sequence_padding(batch_token_ids) 80 | batch_segment_ids = sequence_padding(batch_segment_ids) 81 | batch_labels = sequence_padding(batch_labels) 82 | yield [batch_token_ids, batch_segment_ids], batch_labels 83 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 84 | 85 | 86 | model = build_transformer_model(config_path, checkpoint_path) 87 | output = Dense(len(categories) * 2 + 
1)(model.output) 88 | CRF = ConditionalRandomField(lr_multiplier=crf_lr_multiplier) 89 | output = CRF(output) 90 | 91 | model = Model(model.input, output) 92 | model.summary() 93 | 94 | model.compile( 95 | loss=CRF.sparse_loss, 96 | optimizer=Adam(learning_rate), 97 | metrics=[CRF.sparse_accuracy] 98 | ) 99 | 100 | 101 | class NamedEntityRecognizer(ViterbiDecoder): 102 | """命名实体识别器 103 | """ 104 | def recognize(self, text): 105 | tokens = tokenizer.tokenize(text, maxlen=512) 106 | mapping = tokenizer.rematch(text, tokens) 107 | token_ids = tokenizer.tokens_to_ids(tokens) 108 | segment_ids = [0] * len(token_ids) 109 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 110 | nodes = model.predict([token_ids, segment_ids])[0] 111 | labels = self.decode(nodes) 112 | entities, starting = [], False 113 | for i, label in enumerate(labels): 114 | if label > 0: 115 | if label % 2 == 1: 116 | starting = True 117 | entities.append([[i], categories[(label - 1) // 2]]) 118 | elif starting: 119 | entities[-1][0].append(i) 120 | else: 121 | starting = False 122 | else: 123 | starting = False 124 | return [(mapping[w[0]][0], mapping[w[-1]][-1], l) for w, l in entities] 125 | 126 | 127 | NER = NamedEntityRecognizer(trans=K.eval(CRF.trans), starts=[0], ends=[0]) 128 | 129 | 130 | def evaluate(data): 131 | """评测函数 132 | """ 133 | X, Y, Z = 1e-10, 1e-10, 1e-10 134 | for d in tqdm(data, ncols=100): 135 | R = set(NER.recognize(d[0])) 136 | T = set([tuple(i) for i in d[1:]]) 137 | X += len(R & T) 138 | Y += len(R) 139 | Z += len(T) 140 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 141 | return f1, precision, recall 142 | 143 | 144 | class Evaluator(keras.callbacks.Callback): 145 | """评估与保存 146 | """ 147 | def __init__(self): 148 | self.best_val_f1 = 0 149 | 150 | def on_epoch_end(self, epoch, logs=None): 151 | trans = K.eval(CRF.trans) 152 | NER.trans = trans 153 | print(NER.trans) 154 | f1, precision, recall = evaluate(valid_data) 155 | # 保存最优 156 | if f1 >= self.best_val_f1: 157 | self.best_val_f1 = f1 158 | model.save_weights('./best_model_cmeee_crf.weights') 159 | print( 160 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 161 | (f1, precision, recall, self.best_val_f1) 162 | ) 163 | 164 | 165 | def predict_to_file(in_file, out_file): 166 | """预测到文件 167 | 可以提交到 https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 168 | """ 169 | data = json.load(open(in_file)) 170 | for d in tqdm(data, ncols=100): 171 | d['entities'] = [] 172 | entities = NER.recognize(d['text']) 173 | for e in entities: 174 | d['entities'].append({ 175 | 'start_idx': e[0], 176 | 'end_idx': e[1], 177 | 'type': e[2] 178 | }) 179 | json.dump( 180 | data, 181 | open(out_file, 'w', encoding='utf-8'), 182 | indent=4, 183 | ensure_ascii=False 184 | ) 185 | 186 | 187 | if __name__ == '__main__': 188 | 189 | evaluator = Evaluator() 190 | train_generator = data_generator(train_data, batch_size) 191 | 192 | model.fit( 193 | train_generator.forfit(), 194 | steps_per_epoch=len(train_generator), 195 | epochs=epochs, 196 | callbacks=[evaluator] 197 | ) 198 | 199 | else: 200 | 201 | model.load_weights('./best_model_cmeee_crf.weights') 202 | NER.trans = K.eval(CRF.trans) 203 | # predict_to_file('/root/ner/CMeEE/CMeEE_test.json', 'CMeEE_test.json') 204 | -------------------------------------------------------------------------------- /CMeEE_GlobalPointer.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 用GlobalPointer做中文命名实体识别 3 | # 数据集 https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 4 | 5 | import json 6 | import numpy as np 7 | from bert4keras.backend import keras, K 8 | from bert4keras.backend import multilabel_categorical_crossentropy 9 | from bert4keras.layers import GlobalPointer 10 | from bert4keras.models import build_transformer_model 11 | from bert4keras.tokenizers import Tokenizer 12 | from bert4keras.optimizers import Adam 13 | from bert4keras.snippets import sequence_padding, DataGenerator 14 | from bert4keras.snippets import open, to_array 15 | from keras.models import Model 16 | from tqdm import tqdm 17 | 18 | maxlen = 256 19 | epochs = 10 20 | batch_size = 16 21 | learning_rate = 2e-5 22 | categories = set() 23 | 24 | # bert配置 25 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 26 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 27 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 28 | 29 | 30 | def load_data(filename): 31 | """加载数据 32 | 单条格式:[text, (start, end, label), (start, end, label), ...], 33 | 意味着text[start:end + 1]是类型为label的实体。 34 | """ 35 | D = [] 36 | for d in json.load(open(filename)): 37 | D.append([d['text']]) 38 | for e in d['entities']: 39 | start, end, label = e['start_idx'], e['end_idx'], e['type'] 40 | if start <= end: 41 | D[-1].append((start, end, label)) 42 | categories.add(label) 43 | return D 44 | 45 | 46 | # 标注数据 47 | train_data = load_data('/root/ner/CMeEE/CMeEE_train.json') 48 | valid_data = load_data('/root/ner/CMeEE/CMeEE_dev.json') 49 | categories = list(sorted(categories)) 50 | 51 | # 建立分词器 52 | tokenizer = Tokenizer(dict_path, do_lower_case=True) 53 | 54 | 55 | class data_generator(DataGenerator): 56 | """数据生成器 57 | """ 58 | def __iter__(self, random=False): 59 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 60 | for is_end, d in self.sample(random): 61 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 62 | mapping = tokenizer.rematch(d[0], tokens) 63 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 64 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 65 | token_ids = tokenizer.tokens_to_ids(tokens) 66 | segment_ids = [0] * len(token_ids) 67 | labels = np.zeros((len(categories), maxlen, maxlen)) 68 | for start, end, label in d[1:]: 69 | if start in start_mapping and end in end_mapping: 70 | start = start_mapping[start] 71 | end = end_mapping[end] 72 | label = categories.index(label) 73 | labels[label, start, end] = 1 74 | batch_token_ids.append(token_ids) 75 | batch_segment_ids.append(segment_ids) 76 | batch_labels.append(labels[:, :len(token_ids), :len(token_ids)]) 77 | if len(batch_token_ids) == self.batch_size or is_end: 78 | batch_token_ids = sequence_padding(batch_token_ids) 79 | batch_segment_ids = sequence_padding(batch_segment_ids) 80 | batch_labels = sequence_padding(batch_labels, seq_dims=3) 81 | yield [batch_token_ids, batch_segment_ids], batch_labels 82 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 83 | 84 | 85 | def global_pointer_crossentropy(y_true, y_pred): 86 | """给GlobalPointer设计的交叉熵 87 | """ 88 | bh = K.prod(K.shape(y_pred)[:2]) 89 | y_true = K.reshape(y_true, (bh, -1)) 90 | y_pred = K.reshape(y_pred, (bh, -1)) 91 | return K.mean(multilabel_categorical_crossentropy(y_true, y_pred)) 92 | 93 | 94 | def global_pointer_f1_score(y_true, y_pred): 95 | """给GlobalPointer设计的F1 96 | """ 97 | y_pred = K.cast(K.greater(y_pred, 0), K.floatx()) 98 | return 2 * K.sum(y_true 
* y_pred) / K.sum(y_true + y_pred) 99 | 100 | 101 | model = build_transformer_model(config_path, checkpoint_path) 102 | output = GlobalPointer(len(categories), 64)(model.output) 103 | 104 | model = Model(model.input, output) 105 | model.summary() 106 | 107 | model.compile( 108 | loss=global_pointer_crossentropy, 109 | optimizer=Adam(learning_rate), 110 | metrics=[global_pointer_f1_score] 111 | ) 112 | 113 | 114 | class NamedEntityRecognizer(object): 115 | """命名实体识别器 116 | """ 117 | def recognize(self, text, threshold=0): 118 | tokens = tokenizer.tokenize(text, maxlen=512) 119 | mapping = tokenizer.rematch(text, tokens) 120 | token_ids = tokenizer.tokens_to_ids(tokens) 121 | segment_ids = [0] * len(token_ids) 122 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 123 | scores = model.predict([token_ids, segment_ids])[0] 124 | scores[:, [0, -1]] -= np.inf 125 | scores[:, :, [0, -1]] -= np.inf 126 | entities = [] 127 | for l, start, end in zip(*np.where(scores > threshold)): 128 | entities.append( 129 | (mapping[start][0], mapping[end][-1], categories[l]) 130 | ) 131 | return entities 132 | 133 | 134 | NER = NamedEntityRecognizer() 135 | 136 | 137 | def evaluate(data): 138 | """评测函数 139 | """ 140 | X, Y, Z = 1e-10, 1e-10, 1e-10 141 | for d in tqdm(data, ncols=100): 142 | R = set(NER.recognize(d[0])) 143 | T = set([tuple(i) for i in d[1:]]) 144 | X += len(R & T) 145 | Y += len(R) 146 | Z += len(T) 147 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 148 | return f1, precision, recall 149 | 150 | 151 | class Evaluator(keras.callbacks.Callback): 152 | """评估与保存 153 | """ 154 | def __init__(self): 155 | self.best_val_f1 = 0 156 | 157 | def on_epoch_end(self, epoch, logs=None): 158 | f1, precision, recall = evaluate(valid_data) 159 | # 保存最优 160 | if f1 >= self.best_val_f1: 161 | self.best_val_f1 = f1 162 | model.save_weights('./best_model_cmeee_globalpointer.weights') 163 | print( 164 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 165 | (f1, precision, recall, self.best_val_f1) 166 | ) 167 | 168 | 169 | def predict_to_file(in_file, out_file): 170 | """预测到文件 171 | 可以提交到 https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 172 | """ 173 | data = json.load(open(in_file)) 174 | for d in tqdm(data, ncols=100): 175 | d['entities'] = [] 176 | entities = NER.recognize(d['text']) 177 | for e in entities: 178 | d['entities'].append({ 179 | 'start_idx': e[0], 180 | 'end_idx': e[1], 181 | 'type': e[2] 182 | }) 183 | json.dump( 184 | data, 185 | open(out_file, 'w', encoding='utf-8'), 186 | indent=4, 187 | ensure_ascii=False 188 | ) 189 | 190 | 191 | if __name__ == '__main__': 192 | 193 | evaluator = Evaluator() 194 | train_generator = data_generator(train_data, batch_size) 195 | 196 | model.fit( 197 | train_generator.forfit(), 198 | steps_per_epoch=len(train_generator), 199 | epochs=epochs, 200 | callbacks=[evaluator] 201 | ) 202 | 203 | else: 204 | 205 | model.load_weights('./best_model_cmeee_globalpointer.weights') 206 | # predict_to_file('/root/ner/CMeEE/CMeEE_test.json', 'CMeEE_test.json') 207 | -------------------------------------------------------------------------------- /CLUENER_CRF.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 用CRF做中文命名实体识别 3 | # 数据集 https://github.com/CLUEbenchmark/CLUENER2020 4 | 5 | import json 6 | import numpy as np 7 | from bert4keras.backend import keras, K 8 | from bert4keras.models import build_transformer_model 9 | from bert4keras.tokenizers import Tokenizer 10 | from bert4keras.optimizers import Adam 11 | from bert4keras.snippets import sequence_padding, DataGenerator 12 | from bert4keras.snippets import open, ViterbiDecoder, to_array 13 | from bert4keras.layers import ConditionalRandomField 14 | from keras.layers import Dense 15 | from keras.models import Model 16 | from tqdm import tqdm 17 | 18 | maxlen = 256 19 | epochs = 10 20 | batch_size = 16 21 | learning_rate = 2e-5 22 | crf_lr_multiplier = 1000 # 必要时扩大CRF层的学习率 23 | categories = set() 24 | 25 | # bert配置 26 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 27 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 28 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 29 | 30 | 31 | def load_data(filename): 32 | """加载数据 33 | 单条格式:[text, (start, end, label), (start, end, label), ...], 34 | 意味着text[start:end + 1]是类型为label的实体。 35 | """ 36 | D = [] 37 | with open(filename, encoding='utf-8') as f: 38 | for l in f: 39 | l = json.loads(l) 40 | d = [l['text']] 41 | for k, v in l['label'].items(): 42 | categories.add(k) 43 | for spans in v.values(): 44 | for start, end in spans: 45 | d.append((start, end, k)) 46 | D.append(d) 47 | return D 48 | 49 | 50 | # 标注数据 51 | train_data = load_data('/root/ner/cluener/train.json') 52 | valid_data = load_data('/root/ner/cluener/dev.json') 53 | categories = list(sorted(categories)) 54 | 55 | # 建立分词器 56 | tokenizer = Tokenizer(dict_path, do_lower_case=True) 57 | 58 | 59 | class data_generator(DataGenerator): 60 | """数据生成器 61 | """ 62 | def __iter__(self, random=False): 63 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 64 | for is_end, d in self.sample(random): 65 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 66 | mapping = tokenizer.rematch(d[0], tokens) 67 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 68 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 69 | token_ids = tokenizer.tokens_to_ids(tokens) 70 | segment_ids = [0] * len(token_ids) 71 | labels = np.zeros(len(token_ids)) 72 | for start, end, label in d[1:]: 73 | if start in start_mapping and end in end_mapping: 74 | start = start_mapping[start] 75 | end = end_mapping[end] 76 | labels[start] = categories.index(label) * 2 + 1 77 | labels[start + 1:end + 1] = categories.index(label) * 2 + 2 78 | batch_token_ids.append(token_ids) 79 | batch_segment_ids.append(segment_ids) 80 | batch_labels.append(labels) 81 | if len(batch_token_ids) == self.batch_size or is_end: 82 | batch_token_ids = sequence_padding(batch_token_ids) 83 | batch_segment_ids = sequence_padding(batch_segment_ids) 84 | batch_labels = sequence_padding(batch_labels) 85 | yield [batch_token_ids, batch_segment_ids], batch_labels 86 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 87 | 88 | 89 | model = build_transformer_model(config_path, checkpoint_path) 90 | output = Dense(len(categories) * 2 + 1)(model.output) 91 | CRF = ConditionalRandomField(lr_multiplier=crf_lr_multiplier) 92 | output = CRF(output) 93 | 94 | model = Model(model.input, output) 95 | model.summary() 96 | 97 | model.compile( 98 | loss=CRF.sparse_loss, 99 | optimizer=Adam(learning_rate), 100 | metrics=[CRF.sparse_accuracy] 101 | ) 102 | 103 | 104 | class 
NamedEntityRecognizer(ViterbiDecoder): 105 | """命名实体识别器 106 | """ 107 | def recognize(self, text): 108 | tokens = tokenizer.tokenize(text, maxlen=512) 109 | mapping = tokenizer.rematch(text, tokens) 110 | token_ids = tokenizer.tokens_to_ids(tokens) 111 | segment_ids = [0] * len(token_ids) 112 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 113 | nodes = model.predict([token_ids, segment_ids])[0] 114 | labels = self.decode(nodes) 115 | entities, starting = [], False 116 | for i, label in enumerate(labels): 117 | if label > 0: 118 | if label % 2 == 1: 119 | starting = True 120 | entities.append([[i], categories[(label - 1) // 2]]) 121 | elif starting: 122 | entities[-1][0].append(i) 123 | else: 124 | starting = False 125 | else: 126 | starting = False 127 | return [(mapping[w[0]][0], mapping[w[-1]][-1], l) for w, l in entities] 128 | 129 | 130 | NER = NamedEntityRecognizer(trans=K.eval(CRF.trans), starts=[0], ends=[0]) 131 | 132 | 133 | def evaluate(data): 134 | """评测函数 135 | """ 136 | X, Y, Z = 1e-10, 1e-10, 1e-10 137 | for d in tqdm(data, ncols=100): 138 | R = set(NER.recognize(d[0])) 139 | T = set([tuple(i) for i in d[1:]]) 140 | X += len(R & T) 141 | Y += len(R) 142 | Z += len(T) 143 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 144 | return f1, precision, recall 145 | 146 | 147 | class Evaluator(keras.callbacks.Callback): 148 | """评估与保存 149 | """ 150 | def __init__(self): 151 | self.best_val_f1 = 0 152 | 153 | def on_epoch_end(self, epoch, logs=None): 154 | trans = K.eval(CRF.trans) 155 | NER.trans = trans 156 | print(NER.trans) 157 | f1, precision, recall = evaluate(valid_data) 158 | # 保存最优 159 | if f1 >= self.best_val_f1: 160 | self.best_val_f1 = f1 161 | model.save_weights('./best_model_cluener_crf.weights') 162 | print( 163 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 164 | (f1, precision, recall, self.best_val_f1) 165 | ) 166 | 167 | 168 | def predict_to_file(in_file, out_file): 169 | """预测到文件 170 | 可以提交到 https://www.cluebenchmarks.com/ner.html 171 | """ 172 | fw = open(out_file, 'w', encoding='utf-8') 173 | with open(in_file) as fr: 174 | for l in tqdm(fr): 175 | l = json.loads(l) 176 | l['label'] = {} 177 | for start, end, label in NER.recognize(l['text']): 178 | if label not in l['label']: 179 | l['label'][label] = {} 180 | entity = l['text'][start:end + 1] 181 | if entity not in l['label'][label]: 182 | l['label'][label][entity] = [] 183 | l['label'][label][entity].append([start, end]) 184 | l = json.dumps(l, ensure_ascii=False) 185 | fw.write(l + '\n') 186 | fw.close() 187 | 188 | 189 | if __name__ == '__main__': 190 | 191 | evaluator = Evaluator() 192 | train_generator = data_generator(train_data, batch_size) 193 | 194 | model.fit( 195 | train_generator.forfit(), 196 | steps_per_epoch=len(train_generator), 197 | epochs=epochs, 198 | callbacks=[evaluator] 199 | ) 200 | 201 | else: 202 | 203 | model.load_weights('./best_model_cluener_crf.weights') 204 | NER.trans = K.eval(CRF.trans) 205 | # predict_to_file('/root/ner/cluener/test.json', 'cluener_test.json') 206 | -------------------------------------------------------------------------------- /CLUENER_GlobalPointer.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 用GlobalPointer做中文命名实体识别 3 | # 数据集 https://github.com/CLUEbenchmark/CLUENER2020 4 | 5 | import json 6 | import numpy as np 7 | from bert4keras.backend import keras, K 8 | from bert4keras.backend import multilabel_categorical_crossentropy 9 | from bert4keras.layers import GlobalPointer 10 | from bert4keras.models import build_transformer_model 11 | from bert4keras.tokenizers import Tokenizer 12 | from bert4keras.optimizers import Adam 13 | from bert4keras.snippets import sequence_padding, DataGenerator 14 | from bert4keras.snippets import open, to_array 15 | from keras.models import Model 16 | from tqdm import tqdm 17 | 18 | maxlen = 256 19 | epochs = 10 20 | batch_size = 16 21 | learning_rate = 2e-5 22 | categories = set() 23 | 24 | # bert配置 25 | config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json' 26 | checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt' 27 | dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt' 28 | 29 | 30 | def load_data(filename): 31 | """加载数据 32 | 单条格式:[text, (start, end, label), (start, end, label), ...], 33 | 意味着text[start:end + 1]是类型为label的实体。 34 | """ 35 | D = [] 36 | with open(filename, encoding='utf-8') as f: 37 | for l in f: 38 | l = json.loads(l) 39 | d = [l['text']] 40 | for k, v in l['label'].items(): 41 | categories.add(k) 42 | for spans in v.values(): 43 | for start, end in spans: 44 | d.append((start, end, k)) 45 | D.append(d) 46 | return D 47 | 48 | 49 | # 标注数据 50 | train_data = load_data('/root/ner/cluener/train.json') 51 | valid_data = load_data('/root/ner/cluener/dev.json') 52 | categories = list(sorted(categories)) 53 | 54 | # 建立分词器 55 | tokenizer = Tokenizer(dict_path, do_lower_case=True) 56 | 57 | 58 | class data_generator(DataGenerator): 59 | """数据生成器 60 | """ 61 | def __iter__(self, random=False): 62 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 63 | for is_end, d in self.sample(random): 64 | tokens = tokenizer.tokenize(d[0], maxlen=maxlen) 65 | mapping = tokenizer.rematch(d[0], tokens) 66 | start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} 67 | end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} 68 | token_ids = tokenizer.tokens_to_ids(tokens) 69 | segment_ids = [0] * len(token_ids) 70 | labels = np.zeros((len(categories), maxlen, maxlen)) 71 | for start, end, label in d[1:]: 72 | if start in start_mapping and end in end_mapping: 73 | start = start_mapping[start] 74 | end = end_mapping[end] 75 | label = categories.index(label) 76 | labels[label, start, end] = 1 77 | batch_token_ids.append(token_ids) 78 | batch_segment_ids.append(segment_ids) 79 | batch_labels.append(labels[:, :len(token_ids), :len(token_ids)]) 80 | if len(batch_token_ids) == self.batch_size or is_end: 81 | batch_token_ids = sequence_padding(batch_token_ids) 82 | batch_segment_ids = sequence_padding(batch_segment_ids) 83 | batch_labels = sequence_padding(batch_labels, seq_dims=3) 84 | yield [batch_token_ids, batch_segment_ids], batch_labels 85 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 86 | 87 | 88 | def global_pointer_crossentropy(y_true, y_pred): 89 | """给GlobalPointer设计的交叉熵 90 | """ 91 | bh = K.prod(K.shape(y_pred)[:2]) 92 | y_true = K.reshape(y_true, (bh, -1)) 93 | y_pred = K.reshape(y_pred, (bh, -1)) 94 | return K.mean(multilabel_categorical_crossentropy(y_true, y_pred)) 95 | 96 | 97 | def global_pointer_f1_score(y_true, y_pred): 98 | """给GlobalPointer设计的F1 99 | """ 100 | y_pred = K.cast(K.greater(y_pred, 0), K.floatx()) 101 | return 2 * 
K.sum(y_true * y_pred) / K.sum(y_true + y_pred) 102 | 103 | 104 | model = build_transformer_model(config_path, checkpoint_path) 105 | output = GlobalPointer(len(categories), 64)(model.output) 106 | 107 | model = Model(model.input, output) 108 | model.summary() 109 | 110 | model.compile( 111 | loss=global_pointer_crossentropy, 112 | optimizer=Adam(learning_rate), 113 | metrics=[global_pointer_f1_score] 114 | ) 115 | 116 | 117 | class NamedEntityRecognizer(object): 118 | """命名实体识别器 119 | """ 120 | def recognize(self, text, threshold=0): 121 | tokens = tokenizer.tokenize(text, maxlen=512) 122 | mapping = tokenizer.rematch(text, tokens) 123 | token_ids = tokenizer.tokens_to_ids(tokens) 124 | segment_ids = [0] * len(token_ids) 125 | token_ids, segment_ids = to_array([token_ids], [segment_ids]) 126 | scores = model.predict([token_ids, segment_ids])[0] 127 | scores[:, [0, -1]] -= np.inf 128 | scores[:, :, [0, -1]] -= np.inf 129 | entities = [] 130 | for l, start, end in zip(*np.where(scores > threshold)): 131 | entities.append( 132 | (mapping[start][0], mapping[end][-1], categories[l]) 133 | ) 134 | return entities 135 | 136 | 137 | NER = NamedEntityRecognizer() 138 | 139 | 140 | def evaluate(data): 141 | """评测函数 142 | """ 143 | X, Y, Z = 1e-10, 1e-10, 1e-10 144 | for d in tqdm(data, ncols=100): 145 | R = set(NER.recognize(d[0])) 146 | T = set([tuple(i) for i in d[1:]]) 147 | X += len(R & T) 148 | Y += len(R) 149 | Z += len(T) 150 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 151 | return f1, precision, recall 152 | 153 | 154 | class Evaluator(keras.callbacks.Callback): 155 | """评估与保存 156 | """ 157 | def __init__(self): 158 | self.best_val_f1 = 0 159 | 160 | def on_epoch_end(self, epoch, logs=None): 161 | f1, precision, recall = evaluate(valid_data) 162 | # 保存最优 163 | if f1 >= self.best_val_f1: 164 | self.best_val_f1 = f1 165 | model.save_weights('./best_model_cluener_globalpointer.weights') 166 | print( 167 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 168 | (f1, precision, recall, self.best_val_f1) 169 | ) 170 | 171 | 172 | def predict_to_file(in_file, out_file): 173 | """预测到文件 174 | 可以提交到 https://www.cluebenchmarks.com/ner.html 175 | """ 176 | fw = open(out_file, 'w', encoding='utf-8') 177 | with open(in_file) as fr: 178 | for l in tqdm(fr): 179 | l = json.loads(l) 180 | l['label'] = {} 181 | for start, end, label in NER.recognize(l['text']): 182 | if label not in l['label']: 183 | l['label'][label] = {} 184 | entity = l['text'][start:end + 1] 185 | if entity not in l['label'][label]: 186 | l['label'][label][entity] = [] 187 | l['label'][label][entity].append([start, end]) 188 | l = json.dumps(l, ensure_ascii=False) 189 | fw.write(l + '\n') 190 | fw.close() 191 | 192 | 193 | if __name__ == '__main__': 194 | 195 | evaluator = Evaluator() 196 | train_generator = data_generator(train_data, batch_size) 197 | 198 | model.fit( 199 | train_generator.forfit(), 200 | steps_per_epoch=len(train_generator), 201 | epochs=epochs, 202 | callbacks=[evaluator] 203 | ) 204 | 205 | else: 206 | 207 | model.load_weights('./best_model_cluener_globalpointer.weights') 208 | # predict_to_file('/root/ner/cluener/test.json', 'cluener_test.json') 209 | --------------------------------------------------------------------------------
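
The scripts above use `bert4keras.layers.GlobalPointer` as a ready-made layer. For readers who want to see what that layer computes, below is a minimal NumPy sketch of the span-scoring idea from the blog post linked in the README: for each entity type, the encoder output is projected to per-token "start" and "end" vectors, RoPE is (optionally) applied so the score becomes position-aware, and every (start, end) pair is scored with a dot product; spans scoring above the threshold 0 are kept at decoding time. The function names (`rope`, `simple_global_pointer`) and the exact scaling are illustrative assumptions, not part of bert4keras.

```python
# Illustrative sketch only -- the scripts above rely on bert4keras.layers.GlobalPointer,
# not on this function.
import numpy as np


def rope(x):
    """Apply Rotary Position Embedding to an array of shape (seq_len, head_size).
    head_size is assumed to be even."""
    seq_len, head_size = x.shape
    pos = np.arange(seq_len)[:, None]                           # (seq_len, 1)
    freqs = 10000 ** (-np.arange(0, head_size, 2) / head_size)  # (head_size/2,)
    angles = pos * freqs                                        # (seq_len, head_size/2)
    sin = np.repeat(np.sin(angles), 2, axis=-1)                 # (seq_len, head_size)
    cos = np.repeat(np.cos(angles), 2, axis=-1)
    # pair-wise rotation: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin)
    x_rot = np.stack([-x[:, 1::2], x[:, ::2]], axis=-1).reshape(x.shape)
    return x * cos + x_rot * sin


def simple_global_pointer(hidden, wq, wk, use_rope=True):
    """Score every (start, end) span for every entity type.

    hidden: (seq_len, hidden_size) encoder outputs (e.g. BERT's last layer).
    wq, wk: (num_types, hidden_size, head_size) per-type projection matrices.
    Returns scores of shape (num_types, seq_len, seq_len); scores[l, i, j] is the
    logit that the span from token i to token j is an entity of type l.
    """
    num_types, _, head_size = wq.shape
    seq_len = hidden.shape[0]
    scores = np.empty((num_types, seq_len, seq_len))
    for l in range(num_types):
        q = hidden @ wq[l]                      # (seq_len, head_size) "start" vectors
        k = hidden @ wk[l]                      # (seq_len, head_size) "end" vectors
        if use_rope:                            # inject relative position information
            q, k = rope(q), rope(k)
        scores[l] = q @ k.T / head_size ** 0.5  # all start/end pairs at once
    return scores


if __name__ == '__main__':
    # Toy example: 3 entity types, 8 tokens, 16-dim encoder output, head_size = 4.
    rng = np.random.default_rng(0)
    hidden = rng.normal(size=(8, 16))
    wq = rng.normal(size=(3, 16, 4))
    wk = rng.normal(size=(3, 16, 4))
    scores = simple_global_pointer(hidden, wq, wk)
    print(scores.shape)  # (3, 8, 8); spans with scores > 0 would be kept when decoding
```

The real layer additionally masks padding positions and spans with end < start; the training scripts then pair its (num_types, seq_len, seq_len) output with `multilabel_categorical_crossentropy`, which is what lets a sentence carry any number of entities per type, nested or not.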