├── .idea └── modules.xml ├── README.md ├── ckpt └── .gitkeep ├── conlleval.py ├── data └── README.md ├── data_helper.py ├── imgs ├── bert_bilstm_crf.png ├── demo.png └── struct.png ├── log └── .gitkeep ├── models ├── BERT_BiLSTM_CRF.py ├── __init__.py ├── base_config.py └── rnncell.py ├── result └── .gitkeep ├── run.py ├── train_val_test.py └── utils.py /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于BERT+BiLSTM+CRF实现中文命名实体识别 2 | ## 1、目录结构 3 | 15 | 16 | ## 2、数据 17 | 21 | 22 | ## 3、运行 23 | 下载bert到项目路径
24 | 创建bert_model目录,将预训练好的BERT模型(如chinese_L-12_H-768_A-12)放到该目录下并解压
25 | 具体结构如下: bert_model/chinese_L-12_H-768_A-12/ 目录下应包含 vocab.txt、bert_config.json 以及 bert_model.ckpt.* 检查点文件 26 | 27 | python3 run.py --mode xxx
28 | xxx: train/test/demo,默认为demo 29 | 30 | ## 4、效果 31 | 训练过程: 32 | 33 | 34 | 单句测试: 35 | 36 | 37 | ## 5、参考 38 | [1] https://github.com/yumath/bertNER
39 | [2] https://github.com/google-research/bert 40 | -------------------------------------------------------------------------------- /ckpt/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/ckpt/.gitkeep -------------------------------------------------------------------------------- /conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | import sys 13 | import re 14 | import codecs 15 | from collections import defaultdict, namedtuple 16 | 17 | ANY_SPACE = '' 18 | 19 | 20 | class FormatError(Exception): 21 | pass 22 | 23 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 24 | 25 | 26 | class EvalCounts(object): 27 | def __init__(self): 28 | self.correct_chunk = 0 # number of correctly identified chunks 29 | self.correct_tags = 0 # number of correct chunk tags 30 | self.found_correct = 0 # number of chunks in corpus 31 | self.found_guessed = 0 # number of identified chunks 32 | self.token_counter = 0 # token counter (ignores sentence breaks) 33 | 34 | # counts by type 35 | self.t_correct_chunk = defaultdict(int) 36 | self.t_found_correct = defaultdict(int) 37 | self.t_found_guessed = defaultdict(int) 38 | 39 | 40 | def parse_args(argv): 41 | import argparse 42 | parser = argparse.ArgumentParser( 43 | description='evaluate tagging results using CoNLL criteria', 44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 45 | ) 46 | arg = parser.add_argument 47 | 
arg('-b', '--boundary', metavar='STR', default='-X-', 48 | help='sentence boundary') 49 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 50 | help='character delimiting items in input') 51 | arg('-o', '--otag', metavar='CHAR', default='O', 52 | help='alternative outside tag') 53 | arg('file', nargs='?', default=None) 54 | return parser.parse_args(argv) 55 | 56 | 57 | def parse_tag(t): 58 | m = re.match(r'^([^-]*)-(.*)$', t) 59 | return m.groups() if m else (t, '') 60 | 61 | 62 | def evaluate(iterable, options=None): 63 | if options is None: 64 | options = parse_args([]) # use defaults 65 | 66 | counts = EvalCounts() 67 | num_features = None # number of features per line 68 | in_correct = False # currently processed chunks is correct until now 69 | last_correct = 'O' # previous chunk tag in corpus 70 | last_correct_type = '' # type of previously identified chunk tag 71 | last_guessed = 'O' # previously identified chunk tag 72 | last_guessed_type = '' # type of previous chunk tag in corpus 73 | 74 | for line in iterable: 75 | line = line.rstrip('\r\n') 76 | 77 | if options.delimiter == ANY_SPACE: 78 | features = line.split() 79 | else: 80 | features = line.split(options.delimiter) 81 | 82 | if num_features is None: 83 | num_features = len(features) 84 | elif num_features != len(features) and len(features) != 0: 85 | raise FormatError('unexpected number of features: %d (%d)' % 86 | (len(features), num_features)) 87 | 88 | if len(features) == 0 or features[0] == options.boundary: 89 | features = [options.boundary, 'O', 'O'] 90 | if len(features) < 3: 91 | raise FormatError('unexpected number of features in line %s' % line) 92 | 93 | guessed, guessed_type = parse_tag(features.pop()) 94 | correct, correct_type = parse_tag(features.pop()) 95 | first_item = features.pop(0) 96 | 97 | if first_item == options.boundary: 98 | guessed = 'O' 99 | 100 | end_correct = end_of_chunk(last_correct, correct, 101 | last_correct_type, correct_type) 102 | end_guessed = 
end_of_chunk(last_guessed, guessed, 103 | last_guessed_type, guessed_type) 104 | start_correct = start_of_chunk(last_correct, correct, 105 | last_correct_type, correct_type) 106 | start_guessed = start_of_chunk(last_guessed, guessed, 107 | last_guessed_type, guessed_type) 108 | 109 | if in_correct: 110 | if (end_correct and end_guessed and 111 | last_guessed_type == last_correct_type): 112 | in_correct = False 113 | counts.correct_chunk += 1 114 | counts.t_correct_chunk[last_correct_type] += 1 115 | elif (end_correct != end_guessed or guessed_type != correct_type): 116 | in_correct = False 117 | 118 | if start_correct and start_guessed and guessed_type == correct_type: 119 | in_correct = True 120 | 121 | if start_correct: 122 | counts.found_correct += 1 123 | counts.t_found_correct[correct_type] += 1 124 | if start_guessed: 125 | counts.found_guessed += 1 126 | counts.t_found_guessed[guessed_type] += 1 127 | if first_item != options.boundary: 128 | if correct == guessed and guessed_type == correct_type: 129 | counts.correct_tags += 1 130 | counts.token_counter += 1 131 | 132 | last_guessed = guessed 133 | last_correct = correct 134 | last_guessed_type = guessed_type 135 | last_correct_type = correct_type 136 | 137 | if in_correct: 138 | counts.correct_chunk += 1 139 | counts.t_correct_chunk[last_correct_type] += 1 140 | 141 | return counts 142 | 143 | 144 | def uniq(iterable): 145 | seen = set() 146 | return [i for i in iterable if not (i in seen or seen.add(i))] 147 | 148 | 149 | def calculate_metrics(correct, guessed, total): 150 | tp, fp, fn = correct, guessed-correct, total-correct 151 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 152 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 153 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 154 | return Metrics(tp, fp, fn, p, r, f) 155 | 156 | 157 | def metrics(counts): 158 | c = counts 159 | overall = calculate_metrics( 160 | c.correct_chunk, c.found_guessed, c.found_correct 161 | ) 162 | by_type = {} 163 | for t in 
uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 164 | by_type[t] = calculate_metrics( 165 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 166 | ) 167 | return overall, by_type 168 | 169 | 170 | def report(counts, out=None): 171 | if out is None: 172 | out = sys.stdout 173 | 174 | overall, by_type = metrics(counts) 175 | 176 | c = counts 177 | out.write('processed %d tokens with %d phrases; ' % 178 | (c.token_counter, c.found_correct)) 179 | out.write('found: %d phrases; correct: %d.\n' % 180 | (c.found_guessed, c.correct_chunk)) 181 | 182 | if c.token_counter > 0: 183 | out.write('accuracy: %6.2f%%; ' % 184 | (100.*c.correct_tags/c.token_counter)) 185 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 186 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 187 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 188 | 189 | for i, m in sorted(by_type.items()): 190 | out.write('%17s: ' % i) 191 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 192 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 193 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 194 | 195 | 196 | def report_notprint(counts, out=None): 197 | if out is None: 198 | out = sys.stdout 199 | 200 | overall, by_type = metrics(counts) 201 | 202 | c = counts 203 | final_report = [] 204 | line = [] 205 | line.append('processed %d tokens with %d phrases; ' % 206 | (c.token_counter, c.found_correct)) 207 | line.append('found: %d phrases; correct: %d.\n' % 208 | (c.found_guessed, c.correct_chunk)) 209 | final_report.append("".join(line)) 210 | 211 | if c.token_counter > 0: 212 | line = [] 213 | line.append('accuracy: %6.2f%%; ' % 214 | (100.*c.correct_tags/c.token_counter)) 215 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 216 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 217 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 218 | final_report.append("".join(line)) 219 | 220 | for i, m in sorted(by_type.items()): 221 | line = 
[] 222 | line.append('%17s: ' % i) 223 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 224 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 225 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 226 | final_report.append("".join(line)) 227 | return final_report 228 | 229 | 230 | def end_of_chunk(prev_tag, tag, prev_type, type_): 231 | # check if a chunk ended between the previous and current word 232 | # arguments: previous and current chunk tags, previous and current types 233 | chunk_end = False 234 | 235 | if prev_tag == 'E': chunk_end = True 236 | if prev_tag == 'S': chunk_end = True 237 | 238 | if prev_tag == 'B' and tag == 'B': chunk_end = True 239 | if prev_tag == 'B' and tag == 'S': chunk_end = True 240 | if prev_tag == 'B' and tag == 'O': chunk_end = True 241 | if prev_tag == 'I' and tag == 'B': chunk_end = True 242 | if prev_tag == 'I' and tag == 'S': chunk_end = True 243 | if prev_tag == 'I' and tag == 'O': chunk_end = True 244 | 245 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_: 246 | chunk_end = True 247 | 248 | # these chunks are assumed to have length 1 249 | if prev_tag == ']': chunk_end = True 250 | if prev_tag == '[': chunk_end = True 251 | 252 | return chunk_end 253 | 254 | 255 | def start_of_chunk(prev_tag, tag, prev_type, type_): 256 | # check if a chunk started between the previous and current word 257 | # arguments: previous and current chunk tags, previous and current types 258 | chunk_start = False 259 | 260 | if tag == 'B': chunk_start = True 261 | if tag == 'S': chunk_start = True 262 | 263 | if prev_tag == 'E' and tag == 'E': chunk_start = True 264 | if prev_tag == 'E' and tag == 'I': chunk_start = True 265 | if prev_tag == 'S' and tag == 'E': chunk_start = True 266 | if prev_tag == 'S' and tag == 'I': chunk_start = True 267 | if prev_tag == 'O' and tag == 'E': chunk_start = True 268 | if prev_tag == 'O' and tag == 'I': chunk_start = True 269 | 270 | if tag != 'O' and tag != '.' 
and prev_type != type_: 271 | chunk_start = True 272 | 273 | # these chunks are assumed to have length 1 274 | if tag == '[': chunk_start = True 275 | if tag == ']': chunk_start = True 276 | 277 | return chunk_start 278 | 279 | 280 | def return_report(input_file): 281 | with codecs.open(input_file, "r", "utf8") as f: 282 | counts = evaluate(f) 283 | return report_notprint(counts) 284 | 285 | 286 | def main(argv): 287 | args = parse_args(argv[1:]) 288 | 289 | if args.file is None: 290 | counts = evaluate(sys.stdin, args) 291 | else: 292 | with open(args.file) as f: 293 | counts = evaluate(f, args) 294 | report(counts) 295 | 296 | if __name__ == '__main__': 297 | sys.exit(main(sys.argv)) -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | 链接:https://pan.baidu.com/s/1NSWtWQSNr7xgBvyMLJM66Q 提取码:32et -------------------------------------------------------------------------------- /data_helper.py: -------------------------------------------------------------------------------- 1 | # Author:duguiming 2 | # Description: 数据处理 3 | # Date: 2020-4-15 4 | import codecs 5 | import math 6 | import random 7 | 8 | from utils import create_dico, create_mapping, zero_digits 9 | from bert import tokenization 10 | 11 | tokenizer = tokenization.FullTokenizer(vocab_file='bert_model/chinese_L-12_H-768_A-12/vocab.txt', 12 | do_lower_case=True) 13 | 14 | 15 | def load_sentences(path, lower, zeros): 16 | """ 17 | Load sentences. A line must contain at least a word and its tag. 18 | Sentences are separated by empty lines. 
19 | """ 20 | sentences = [] 21 | sentence = [] 22 | num = 0 23 | for line in codecs.open(path, 'r', 'utf8'): 24 | num += 1 25 | line = zero_digits(line.rstrip()) if zeros else line.rstrip() 26 | if not line: 27 | if len(sentence) > 0: 28 | if 'DOCSTART' not in sentence[0][0]: 29 | sentences.append(sentence) 30 | sentence = [] 31 | else: 32 | if line[0] == " ": 33 | line = "$" + line[1:] 34 | word = line.split() 35 | else: 36 | word = line.split() 37 | assert len(word) >= 2, print([word[0]]) 38 | sentence.append(word) 39 | if len(sentence) > 0: 40 | if 'DOCSTART' not in sentence[0][0]: 41 | sentences.append(sentence) 42 | return sentences 43 | 44 | 45 | def tag_mapping(sentences): 46 | """ 47 | Create a dictionary and a mapping of tags, sorted by frequency. 48 | """ 49 | tags = [[char[-1] for char in s] for s in sentences] 50 | 51 | dico = create_dico(tags) 52 | dico['[SEP]'] = len(dico) + 1 53 | dico['[CLS]'] = len(dico) + 2 54 | 55 | tag_to_id, id_to_tag = create_mapping(dico) 56 | print("Found %i unique named entity tags" % len(dico)) 57 | return dico, tag_to_id, id_to_tag 58 | 59 | 60 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line): 61 | """ 62 | 将一个样本进行分析,然后将字转化为id, 标签转化为lb 63 | """ 64 | text_list = char_line.split(' ') 65 | label_list = label_line.split(' ') 66 | 67 | tokens = [] 68 | labels = [] 69 | for i, word in enumerate(text_list): 70 | token = tokenizer.tokenize(word) 71 | tokens.extend(token) 72 | label_1 = label_list[i] 73 | for m in range(len(token)): 74 | if m == 0: 75 | labels.append(label_1) 76 | else: 77 | labels.append("X") 78 | # 序列截断 79 | if len(tokens) >= max_seq_length - 1: 80 | tokens = tokens[0:(max_seq_length - 2)] 81 | labels = labels[0:(max_seq_length - 2)] 82 | ntokens = [] 83 | segment_ids = [] 84 | label_ids = [] 85 | ntokens.append("[CLS]") 86 | segment_ids.append(0) 87 | # append("O") or append("[CLS]") not sure! 
88 | label_ids.append(tag_to_id["[CLS]"]) 89 | for i, token in enumerate(tokens): 90 | ntokens.append(token) 91 | segment_ids.append(0) 92 | label_ids.append(tag_to_id[labels[i]]) 93 | ntokens.append("[SEP]") 94 | segment_ids.append(0) 95 | # append("O") or append("[SEP]") not sure! 96 | label_ids.append(tag_to_id["[SEP]"]) 97 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 98 | input_mask = [1] * len(input_ids) 99 | 100 | # padding 101 | while len(input_ids) < max_seq_length: 102 | input_ids.append(0) 103 | input_mask.append(0) 104 | segment_ids.append(0) 105 | # we don't concerned about it! 106 | label_ids.append(0) 107 | ntokens.append("**NULL**") 108 | 109 | return input_ids, input_mask, segment_ids, label_ids 110 | 111 | 112 | def prepare_dataset(sentences, max_seq_length, tag_to_id, lower=False, train=True): 113 | """ 114 | Prepare the dataset. Return a list of lists of dictionaries containing: 115 | - word indexes 116 | - word char indexes 117 | - tag indexes 118 | """ 119 | def f(x): 120 | return x.lower() if lower else x 121 | data = [] 122 | for s in sentences: 123 | string = [w[0].strip() for w in s] 124 | char_line = ' '.join(string) # 使用空格把汉字拼起来 125 | text = tokenization.convert_to_unicode(char_line) 126 | 127 | if train: 128 | tags = [w[-1] for w in s] 129 | else: 130 | tags = ['O' for _ in string] 131 | 132 | labels = ' '.join(tags) # 使用空格把标签拼起来 133 | labels = tokenization.convert_to_unicode(labels) 134 | 135 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text, 136 | tag_to_id=tag_to_id, 137 | max_seq_length=max_seq_length, 138 | tokenizer=tokenizer, 139 | label_line=labels) 140 | data.append([string, segment_ids, ids, mask, label_ids]) 141 | 142 | return data 143 | 144 | 145 | class BatchManager(object): 146 | 147 | def __init__(self, data, batch_size): 148 | self.batch_data = self.sort_and_pad(data, batch_size) 149 | self.len_data = len(self.batch_data) 150 | 151 | def sort_and_pad(self, data, batch_size): 152 | 
num_batch = int(math.ceil(len(data) /batch_size)) 153 | sorted_data = sorted(data, key=lambda x: len(x[0])) 154 | batch_data = list() 155 | for i in range(num_batch): 156 | batch_data.append(self.arrange_batch(sorted_data[int(i*batch_size) : int((i+1)*batch_size)])) 157 | return batch_data 158 | 159 | @staticmethod 160 | def arrange_batch(batch): 161 | ''' 162 | 把batch整理为一个[5, ]的数组 163 | :param batch: 164 | :return: 165 | ''' 166 | strings = [] 167 | segment_ids = [] 168 | chars = [] 169 | mask = [] 170 | targets = [] 171 | for string, seg_ids, char, msk, target in batch: 172 | strings.append(string) 173 | segment_ids.append(seg_ids) 174 | chars.append(char) 175 | mask.append(msk) 176 | targets.append(target) 177 | return [strings, segment_ids, chars, mask, targets] 178 | 179 | @staticmethod 180 | def pad_data(data): 181 | strings = [] 182 | chars = [] 183 | segs = [] 184 | targets = [] 185 | max_length = max([len(sentence[0]) for sentence in data]) 186 | for line in data: 187 | string, segment_ids, char, seg, target = line 188 | padding = [0] * (max_length - len(string)) 189 | strings.append(string + padding) 190 | chars.append(char + padding) 191 | segs.append(seg + padding) 192 | targets.append(target + padding) 193 | return [strings, chars, segs, targets] 194 | 195 | def iter_batch(self, shuffle=False): 196 | if shuffle: 197 | random.shuffle(self.batch_data) 198 | for idx in range(self.len_data): 199 | yield self.batch_data[idx] 200 | 201 | 202 | def input_from_line(line, max_seq_length, tag_to_id): 203 | """ 204 | Take sentence data and return an input for 205 | the training or the evaluation function. 
206 | """ 207 | string = [w[0].strip() for w in line] 208 | # chars = [char_to_id[f(w) if f(w) in char_to_id else ''] 209 | # for w in string] 210 | char_line = ' '.join(string) # 使用空格把汉字拼起来 211 | text = tokenization.convert_to_unicode(char_line) 212 | 213 | tags = ['O' for _ in string] 214 | 215 | labels = ' '.join(tags) # 使用空格把标签拼起来 216 | labels = tokenization.convert_to_unicode(labels) 217 | 218 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text, 219 | tag_to_id=tag_to_id, 220 | max_seq_length=max_seq_length, 221 | tokenizer=tokenizer, 222 | label_line=labels) 223 | import numpy as np 224 | segment_ids = np.reshape(segment_ids,(1, max_seq_length)) 225 | ids = np.reshape(ids, (1, max_seq_length)) 226 | mask = np.reshape(mask, (1, max_seq_length)) 227 | label_ids = np.reshape(label_ids, (1, max_seq_length)) 228 | return [string, segment_ids, ids, mask, label_ids] -------------------------------------------------------------------------------- /imgs/bert_bilstm_crf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/bert_bilstm_crf.png -------------------------------------------------------------------------------- /imgs/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/demo.png -------------------------------------------------------------------------------- /imgs/struct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/struct.png -------------------------------------------------------------------------------- /log/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/log/.gitkeep -------------------------------------------------------------------------------- /models/BERT_BiLSTM_CRF.py: -------------------------------------------------------------------------------- 1 | # Author: duguiming 2 | # Description: 模型 3 | # Date: 2020-4-15 4 | import tensorflow as tf 5 | from tensorflow.contrib.crf import crf_log_likelihood 6 | from tensorflow.contrib.layers.python.layers import initializers 7 | 8 | from models.base_config import BaseConfig 9 | from bert import modeling 10 | import models.rnncell as rnn 11 | 12 | 13 | class Config(BaseConfig): 14 | batch_size = 128 15 | epoch = 100 16 | print_per_batch = 100 17 | clip = 5 18 | dropout_keep_prob = 0.5 19 | lr = 0.001 20 | optimizer = 'adam' 21 | zeros = False 22 | lower = True 23 | 24 | num_tags = None 25 | lstm_dim = 200 26 | max_seq_len = 128 27 | max_epoch = 100 28 | steps_check = 100 29 | 30 | 31 | class BertBiLSTMCrf(object): 32 | def __init__(self, config): 33 | self.config = config 34 | # add placeholders for the model 35 | self.input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_ids") 36 | self.input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_mask") 37 | self.segment_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="segment_ids") 38 | self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Targets") 39 | # dropout keep prob 40 | self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout") 41 | self.global_step = tf.Variable(0, trainable=False) 42 | self.best_dev_f1 = tf.Variable(0.0, trainable=False) 43 | 44 | self.initializer = initializers.xavier_initializer() 45 | 46 | self.bert_bilstm_crf() 47 | 48 | def bert_bilstm_crf(self): 49 | # parameter 50 | used = tf.sign(tf.abs(self.input_ids)) 51 | length = 
tf.reduce_sum(used, reduction_indices=1) 52 | self.lengths = tf.cast(length, tf.int32) 53 | self.batch_size = tf.shape(self.input_ids)[0] 54 | self.num_steps = tf.shape(self.input_ids)[-1] 55 | # bert embedding 56 | embedding = self.bert_embedding() 57 | # dropout 58 | lstm_inputs = tf.nn.dropout(embedding, self.dropout) 59 | # bi-directional lstm layer 60 | lstm_outputs = self.biLSTM_layer(lstm_inputs, self.config.lstm_dim, self.lengths) 61 | # logits for tags 62 | self.logits = self.project_layer(lstm_outputs) 63 | # loss of the model 64 | self.loss = self.loss_layer(self.logits, self.lengths) 65 | 66 | # bert模型参数初始化的地方 67 | init_checkpoint = self.config.init_checkpoint 68 | # 获取模型中所有的训练参数。 69 | tvars = tf.trainable_variables() 70 | # 加载BERT模型 71 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, 72 | init_checkpoint) 73 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 74 | print("**** Trainable Variables ****") 75 | # 打印加载模型的参数 76 | train_vars = [] 77 | for var in tvars: 78 | init_string = "" 79 | if var.name in initialized_variable_names: 80 | init_string = ", *INIT_FROM_CKPT*" 81 | else: 82 | train_vars.append(var) 83 | print(" name = %s, shape = %s%s", var.name, var.shape, 84 | init_string) 85 | 86 | optimizer = self.config.optimizer 87 | if optimizer == "adam": 88 | self.opt = tf.train.AdamOptimizer(self.config.lr) 89 | else: 90 | raise KeyError 91 | grads = tf.gradients(self.loss, train_vars) 92 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 93 | 94 | self.train_op = self.opt.apply_gradients( 95 | zip(grads, train_vars), global_step=self.global_step) 96 | 97 | def bert_embedding(self): 98 | # load bert embedding 99 | bert_config = modeling.BertConfig.from_json_file(self.config.bert_config_path) # 配置文件地址。 100 | model = modeling.BertModel( 101 | config=bert_config, 102 | is_training=True, 103 | input_ids=self.input_ids, 104 | input_mask=self.input_mask, 105 | 
token_type_ids=self.segment_ids, 106 | use_one_hot_embeddings=False) 107 | embedding = model.get_sequence_output() 108 | return embedding 109 | 110 | def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths, name=None): 111 | """ 112 | :param lstm_inputs: [batch_size, num_steps, emb_size] 113 | :return: [batch_size, num_steps, 2*lstm_dim] 114 | """ 115 | with tf.variable_scope("char_BiLSTM" if not name else name): 116 | lstm_cell = {} 117 | for direction in ["forward", "backward"]: 118 | with tf.variable_scope(direction): 119 | lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell( 120 | lstm_dim, 121 | use_peepholes=True, 122 | initializer=self.initializer, 123 | state_is_tuple=True) 124 | outputs, final_states = tf.nn.bidirectional_dynamic_rnn( 125 | lstm_cell["forward"], 126 | lstm_cell["backward"], 127 | lstm_inputs, 128 | dtype=tf.float32, 129 | sequence_length=lengths) 130 | return tf.concat(outputs, axis=2) 131 | 132 | def project_layer(self, lstm_outputs, name=None): 133 | """ 134 | hidden layer between lstm layer and logits 135 | :param lstm_outputs: [batch_size, num_steps, emb_size] 136 | :return: [batch_size, num_steps, num_tags] 137 | """ 138 | with tf.variable_scope("project" if not name else name): 139 | with tf.variable_scope("hidden"): 140 | W = tf.get_variable("W", shape=[self.config.lstm_dim * 2, self.config.lstm_dim], 141 | dtype=tf.float32, initializer=self.initializer) 142 | 143 | b = tf.get_variable("b", shape=[self.config.lstm_dim], dtype=tf.float32, 144 | initializer=tf.zeros_initializer()) 145 | output = tf.reshape(lstm_outputs, shape=[-1, self.config.lstm_dim * 2]) 146 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 147 | 148 | # project to score of tags 149 | with tf.variable_scope("logits"): 150 | W = tf.get_variable("W", shape=[self.config.lstm_dim, self.config.num_tags], 151 | dtype=tf.float32, initializer=self.initializer) 152 | 153 | b = tf.get_variable("b", shape=[self.config.num_tags], dtype=tf.float32, 154 | 
initializer=tf.zeros_initializer()) 155 | 156 | pred = tf.nn.xw_plus_b(hidden, W, b) 157 | 158 | return tf.reshape(pred, [-1, self.num_steps, self.config.num_tags]) 159 | 160 | def loss_layer(self, project_logits, lengths, name=None): 161 | """ 162 | calculate crf loss 163 | :param project_logits: [1, num_steps, num_tags] 164 | :return: scalar loss 165 | """ 166 | with tf.variable_scope("crf_loss" if not name else name): 167 | small = -1000.0 168 | # pad logits for crf loss 169 | start_logits = tf.concat( 170 | [small * tf.ones(shape=[self.batch_size, 1, self.config.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])], 171 | axis=-1) 172 | pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32) 173 | logits = tf.concat([project_logits, pad_logits], axis=-1) 174 | logits = tf.concat([start_logits, logits], axis=1) 175 | targets = tf.concat( 176 | [tf.cast(self.config.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1) 177 | 178 | self.trans = tf.get_variable( 179 | "transitions", 180 | shape=[self.config.num_tags + 1, self.config.num_tags + 1], 181 | initializer=self.initializer) 182 | log_likelihood, self.trans = crf_log_likelihood( 183 | inputs=logits, 184 | tag_indices=targets, 185 | transition_params=self.trans, 186 | sequence_lengths=lengths + 1) 187 | return tf.reduce_mean(-log_likelihood) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/models/__init__.py -------------------------------------------------------------------------------- /models/base_config.py: -------------------------------------------------------------------------------- 1 | # AUthor: duguiming 2 | # Description: 基础的配置项目 3 | # Date: 2020-04-15 4 | import os 5 | 6 | base_dir = 
os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 7 | 8 | 9 | class BaseConfig(object): 10 | # bert模型参数初始化的地方 11 | init_checkpoint = "bert_model/chinese_L-12_H-768_A-12/bert_model.ckpt" 12 | bert_config_path = "bert_model/chinese_L-12_H-768_A-12/bert_config.json" 13 | 14 | # 数据的路径 15 | train_path = os.path.join(base_dir, 'data', 'train.txt') 16 | dev_path = os.path.join(base_dir, 'data', 'dev.txt') 17 | test_path = os.path.join(base_dir, 'data', 'test.txt') 18 | 19 | # 存放结果的路径 20 | map_file = os.path.join(base_dir, 'maps.pkl') 21 | result_path = os.path.join(base_dir, 'result') 22 | ckpt_path = os.path.join(base_dir, 'ckpt') 23 | log_path = os.path.join(base_dir, 'log') 24 | log_file = os.path.join(log_path, 'train.log') 25 | checkpoint_path = os.path.join(ckpt_path, 'ner.ckpt') 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /models/rnncell.py: -------------------------------------------------------------------------------- 1 | """Module for constructing RNN Cells.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import collections 7 | import math 8 | import tensorflow as tf 9 | from tensorflow.contrib.compiler import jit 10 | from tensorflow.contrib.layers.python.layers import layers 11 | from tensorflow.python.framework import dtypes 12 | from tensorflow.python.framework import op_def_registry 13 | from tensorflow.python.framework import ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import clip_ops 16 | from tensorflow.python.ops import init_ops 17 | from tensorflow.python.ops import math_ops 18 | from tensorflow.python.ops import nn_ops 19 | from tensorflow.python.ops import random_ops 20 | from tensorflow.python.ops import rnn_cell_impl 21 | from tensorflow.python.ops import variable_scope as vs 22 | from tensorflow.python.platform import tf_logging as logging 23 | from 
tensorflow.python.util import nest 24 | 25 | 26 | def _get_concat_variable(name, shape, dtype, num_shards): 27 | """Get a sharded variable concatenated into one tensor.""" 28 | sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) 29 | if len(sharded_variable) == 1: 30 | return sharded_variable[0] 31 | 32 | concat_name = name + "/concat" 33 | concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" 34 | for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): 35 | if value.name == concat_full_name: 36 | return value 37 | 38 | concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) 39 | ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, 40 | concat_variable) 41 | return concat_variable 42 | 43 | 44 | def _get_sharded_variable(name, shape, dtype, num_shards): 45 | """Get a list of sharded variables with the given dtype.""" 46 | if num_shards > shape[0]: 47 | raise ValueError("Too many shards: shape=%s, num_shards=%d" % 48 | (shape, num_shards)) 49 | unit_shard_size = int(math.floor(shape[0] / num_shards)) 50 | remaining_rows = shape[0] - unit_shard_size * num_shards 51 | 52 | shards = [] 53 | for i in range(num_shards): 54 | current_size = unit_shard_size 55 | if i < remaining_rows: 56 | current_size += 1 57 | shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:], 58 | dtype=dtype)) 59 | return shards 60 | 61 | 62 | class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): 63 | """Long short-term memory unit (LSTM) recurrent network cell. 64 | 65 | The default non-peephole implementation is based on: 66 | 67 | http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf 68 | 69 | S. Hochreiter and J. Schmidhuber. 70 | "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 71 | 72 | The peephole implementation is based on: 73 | 74 | https://research.google.com/pubs/archive/43905.pdf 75 | 76 | Hasim Sak, Andrew Senior, and Francoise Beaufays. 
77 | "Long short-term memory recurrent neural network architectures for 78 | large scale acoustic modeling." INTERSPEECH, 2014. 79 | 80 | The coupling of input and forget gate is based on: 81 | 82 | http://arxiv.org/pdf/1503.04069.pdf 83 | 84 | Greff et al. "LSTM: A Search Space Odyssey" 85 | 86 | The class uses optional peep-hole connections, and an optional projection 87 | layer. 88 | """ 89 | 90 | def __init__(self, num_units, use_peepholes=False, 91 | initializer=None, num_proj=None, proj_clip=None, 92 | num_unit_shards=1, num_proj_shards=1, 93 | forget_bias=1.0, state_is_tuple=True, 94 | activation=math_ops.tanh, reuse=None): 95 | """Initialize the parameters for an LSTM cell. 96 | 97 | Args: 98 | num_units: int, The number of units in the LSTM cell 99 | use_peepholes: bool, set True to enable diagonal/peephole connections. 100 | initializer: (optional) The initializer to use for the weight and 101 | projection matrices. 102 | num_proj: (optional) int, The output dimensionality for the projection 103 | matrices. If None, no projection is performed. 104 | proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 105 | provided, then the projected values are clipped elementwise to within 106 | `[-proj_clip, proj_clip]`. 107 | num_unit_shards: How to split the weight matrix. If >1, the weight 108 | matrix is stored across num_unit_shards. 109 | num_proj_shards: How to split the projection matrix. If >1, the 110 | projection matrix is stored across num_proj_shards. 111 | forget_bias: Biases of the forget gate are initialized by default to 1 112 | in order to reduce the scale of forgetting at the beginning of 113 | the training. 114 | state_is_tuple: If True, accepted and returned states are 2-tuples of 115 | the `c_state` and `m_state`. By default (False), they are concatenated 116 | along the column axis. This default behavior will soon be deprecated. 117 | activation: Activation function of the inner states. 
118 | reuse: (optional) Python boolean describing whether to reuse variables 119 | in an existing scope. If not `True`, and the existing scope already has 120 | the given variables, an error is raised. 121 | """ 122 | super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) 123 | if not state_is_tuple: 124 | logging.warn( 125 | "%s: Using a concatenated state is slower and will soon be " 126 | "deprecated. Use state_is_tuple=True.", self) 127 | self._num_units = num_units 128 | self._use_peepholes = use_peepholes 129 | self._initializer = initializer 130 | self._num_proj = num_proj 131 | self._proj_clip = proj_clip 132 | self._num_unit_shards = num_unit_shards 133 | self._num_proj_shards = num_proj_shards 134 | self._forget_bias = forget_bias 135 | self._state_is_tuple = state_is_tuple 136 | self._activation = activation 137 | self._reuse = reuse 138 | 139 | if num_proj: 140 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 141 | if state_is_tuple else num_units + num_proj) 142 | self._output_size = num_proj 143 | else: 144 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units) 145 | if state_is_tuple else 2 * num_units) 146 | self._output_size = num_units 147 | 148 | @property 149 | def state_size(self): 150 | return self._state_size 151 | 152 | @property 153 | def output_size(self): 154 | return self._output_size 155 | 156 | def call(self, inputs, state): 157 | """Run one step of LSTM. 158 | 159 | Args: 160 | inputs: input Tensor, 2D, batch x num_units. 161 | state: if `state_is_tuple` is False, this must be a state Tensor, 162 | `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a 163 | tuple of state Tensors, both `2-D`, with column sizes `c_state` and 164 | `m_state`. 165 | scope: VariableScope for the created subgraph; defaults to "LSTMCell". 
166 | 167 | Returns: 168 | A tuple containing: 169 | - A `2-D, [batch x output_dim]`, Tensor representing the output of the 170 | LSTM after reading `inputs` when previous state was `state`. 171 | Here output_dim is: 172 | num_proj if num_proj was set, 173 | num_units otherwise. 174 | - Tensor(s) representing the new state of LSTM after reading `inputs` when 175 | the previous state was `state`. Same type and shape(s) as `state`. 176 | 177 | Raises: 178 | ValueError: If input size cannot be inferred from inputs via 179 | static shape inference. 180 | """ 181 | sigmoid = math_ops.sigmoid 182 | 183 | num_proj = self._num_units if self._num_proj is None else self._num_proj 184 | 185 | if self._state_is_tuple: 186 | (c_prev, m_prev) = state 187 | else: 188 | c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 189 | m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 190 | 191 | dtype = inputs.dtype 192 | input_size = inputs.get_shape().with_rank(2)[1] 193 | 194 | if input_size.value is None: 195 | raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 196 | 197 | # Input gate weights 198 | self.w_xi = tf.get_variable("_w_xi", [input_size.value, self._num_units]) 199 | self.w_hi = tf.get_variable("_w_hi", [self._num_units, self._num_units]) 200 | self.w_ci = tf.get_variable("_w_ci", [self._num_units, self._num_units]) 201 | # Output gate weights 202 | self.w_xo = tf.get_variable("_w_xo", [input_size.value, self._num_units]) 203 | self.w_ho = tf.get_variable("_w_ho", [self._num_units, self._num_units]) 204 | self.w_co = tf.get_variable("_w_co", [self._num_units, self._num_units]) 205 | 206 | # Cell weights 207 | self.w_xc = tf.get_variable("_w_xc", [input_size.value, self._num_units]) 208 | self.w_hc = tf.get_variable("_w_hc", [self._num_units, self._num_units]) 209 | 210 | # Initialize the bias vectors 211 | self.b_i = tf.get_variable("_b_i", [self._num_units], initializer=init_ops.zeros_initializer()) 212 | self.b_c = 
tf.get_variable("_b_c", [self._num_units], initializer=init_ops.zeros_initializer()) 213 | self.b_o = tf.get_variable("_b_o", [self._num_units], initializer=init_ops.zeros_initializer()) 214 | 215 | i_t = sigmoid(math_ops.matmul(inputs, self.w_xi) + 216 | math_ops.matmul(m_prev, self.w_hi) + 217 | math_ops.matmul(c_prev, self.w_ci) + 218 | self.b_i) 219 | c_t = ((1 - i_t) * c_prev + i_t * self._activation(math_ops.matmul(inputs, self.w_xc) + 220 | math_ops.matmul(m_prev, self.w_hc) + self.b_c)) 221 | 222 | o_t = sigmoid(math_ops.matmul(inputs, self.w_xo) + 223 | math_ops.matmul(m_prev, self.w_ho) + 224 | math_ops.matmul(c_t, self.w_co) + 225 | self.b_o) 226 | 227 | h_t = o_t * self._activation(c_t) 228 | 229 | new_state = (rnn_cell_impl.LSTMStateTuple(c_t, h_t) if self._state_is_tuple else 230 | array_ops.concat([c_t, h_t], 1)) 231 | return h_t, new_state -------------------------------------------------------------------------------- /result/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/result/.gitkeep -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # Author: duguiming 2 | # Description: 运行程序 3 | # Date: 2020-4-15 4 | import os 5 | import pickle 6 | import argparse 7 | 8 | from models.BERT_BiLSTM_CRF import Config, BertBiLSTMCrf 9 | from data_helper import tag_mapping, load_sentences, prepare_dataset, BatchManager 10 | from train_val_test import train, test, demo 11 | from utils import make_path 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Chinese NER') 15 | parser.add_argument('--mode', type=str, required=True, help='train test or demo') 16 | args = parser.parse_args() 17 | 18 | 19 | if __name__ == "__main__": 20 | mode = args.mode 21 | config = Config() 22 | 23 | # load 
data 24 | train_sentences = load_sentences(config.train_path, config.lower, config.zeros) 25 | dev_sentences = load_sentences(config.dev_path, config.lower, config.zeros) 26 | test_sentences = load_sentences(config.test_path, config.lower, config.zeros) 27 | 28 | # tags dict 29 | if not os.path.isfile(config.map_file): 30 | # Create a dictionary and a mapping for tags 31 | _t, tag_to_id, id_to_tag = tag_mapping(train_sentences) 32 | with open(config.map_file, "wb") as f: 33 | pickle.dump([tag_to_id, id_to_tag], f) 34 | else: 35 | with open(config.map_file, "rb") as f: 36 | tag_to_id, id_to_tag = pickle.load(f) 37 | 38 | config.num_tags = len(tag_to_id) 39 | 40 | train_data = prepare_dataset( 41 | train_sentences, config.max_seq_len, tag_to_id, config.lower 42 | ) 43 | dev_data = prepare_dataset( 44 | dev_sentences, config.max_seq_len, tag_to_id, config.lower 45 | ) 46 | test_data = prepare_dataset( 47 | test_sentences, config.max_seq_len, tag_to_id, config.lower 48 | ) 49 | print("%i / %i / %i sentences in train / dev / test." 
% ( 50 | len(train_data), 0, len(test_data))) 51 | 52 | train_manager = BatchManager(train_data, config.batch_size) 53 | dev_manager = BatchManager(dev_data, config.batch_size) 54 | test_manager = BatchManager(test_data, config.batch_size) 55 | 56 | model = BertBiLSTMCrf(config) 57 | make_path(config) 58 | 59 | if mode == "train": 60 | train(model, config, train_manager, dev_manager, id_to_tag) 61 | elif mode == "test": 62 | test(model, config, test_manager, id_to_tag) 63 | else: 64 | demo(model, config, id_to_tag, tag_to_id) 65 | -------------------------------------------------------------------------------- /train_val_test.py: -------------------------------------------------------------------------------- 1 | # Author: duguiming 2 | # Description: 训练、验证和测试 3 | # Date:2020-4-25 4 | import tensorflow as tf 5 | import numpy as np 6 | from tensorflow.contrib.crf import viterbi_decode 7 | 8 | from utils import get_logger, test_ner, bio_to_json 9 | from data_helper import input_from_line 10 | 11 | 12 | def get_feed_dict(model, is_train, batch, config): 13 | """ 14 | :param is_train: Flag, True for train batch 15 | :param batch: list train/evaluate data 16 | :return: structured data to feed 17 | """ 18 | _, segment_ids, chars, mask, tags = batch 19 | feed_dict = { 20 | model.input_ids: np.asarray(chars), 21 | model.input_mask: np.asarray(mask), 22 | model.segment_ids: np.asarray(segment_ids), 23 | model.dropout: 1.0, 24 | } 25 | if is_train: 26 | feed_dict[model.targets] = np.asarray(tags) 27 | feed_dict[model.dropout] = config.dropout_keep_prob 28 | return feed_dict 29 | 30 | 31 | def decode(logits, lengths, matrix, config): 32 | """ 33 | :param logits: [batch_size, num_steps, num_tags]float32, logits 34 | :param lengths: [batch_size]int32, real length of each sequence 35 | :param matrix: transaction matrix for inference 36 | :return: 37 | """ 38 | # inference final labels usa viterbi Algorithm 39 | paths = [] 40 | small = -1000.0 41 | start = 
np.asarray([[small]*config.num_tags +[0]]) 42 | for score, length in zip(logits, lengths): 43 | score = score[:length] 44 | pad = small * np.ones([length, 1]) 45 | logits = np.concatenate([score, pad], axis=1) 46 | logits = np.concatenate([start, logits], axis=0) 47 | path, _ = viterbi_decode(logits, matrix) 48 | 49 | paths.append(path[1:]) 50 | return paths 51 | 52 | 53 | def evaluate_(sess, model, data_manager, id_to_tag, config): 54 | """ 55 | :param sess: session to run the model 56 | :param data: list of data 57 | :param id_to_tag: index to tag name 58 | :return: evaluate result 59 | """ 60 | results = [] 61 | trans = model.trans.eval() 62 | for batch in data_manager.iter_batch(): 63 | strings = batch[0] 64 | labels = batch[-1] 65 | feed_dict = get_feed_dict(model, False, batch, config) 66 | lengths, scores = sess.run([model.lengths, model.logits], feed_dict) 67 | batch_paths = decode(scores, lengths, trans, config) 68 | for i in range(len(strings)): 69 | result = [] 70 | string = strings[i][:lengths[i]] 71 | gold = [id_to_tag[int(x)] for x in labels[i][1:lengths[i]]] 72 | pred = [id_to_tag[int(x)] for x in batch_paths[i][1:lengths[i]]] 73 | for char, gold, pred in zip(string, gold, pred): 74 | result.append(" ".join([char, gold, pred])) 75 | results.append(result) 76 | return results 77 | 78 | 79 | def evaluate(sess, model, name, data_manager, id_to_tag, logger, config): 80 | logger.info("evaluate:{}".format(name)) 81 | ner_results = evaluate_(sess, model, data_manager, id_to_tag, config) 82 | eval_lines = test_ner(ner_results, config.result_path) 83 | for line in eval_lines: 84 | logger.info(line) 85 | f1 = float(eval_lines[1].strip().split()[-1]) 86 | 87 | if name == "dev": 88 | best_test_f1 = model.best_dev_f1.eval() 89 | if f1 > best_test_f1: 90 | tf.assign(model.best_dev_f1, f1).eval() 91 | logger.info("new best dev f1 score:{:>.3f}".format(f1)) 92 | return f1 > best_test_f1 93 | 94 | 95 | def train(model, config, train_manager, dev_manager, id_to_tag): 
96 | logger = get_logger(config.log_file) 97 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 98 | # limit GPU memory 99 | tf_config = tf.ConfigProto() 100 | tf_config.gpu_options.allow_growth = True 101 | steps_per_epoch = train_manager.len_data 102 | with tf.Session(config=tf_config) as sess: 103 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path) 104 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 105 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path) 106 | model.saver.restore(sess, ckpt.model_checkpoint_path) 107 | else: 108 | logger.info("Created model with fresh parameters.") 109 | sess.run(tf.global_variables_initializer()) 110 | 111 | logger.info("start training") 112 | loss = [] 113 | for i in range(config.epoch): 114 | for batch in train_manager.iter_batch(shuffle=True): 115 | feed_dict = get_feed_dict(model, True, batch, config) 116 | global_step, batch_loss, _ = sess.run([model.global_step, model.loss, model.train_op], feed_dict) 117 | 118 | loss.append(batch_loss) 119 | if global_step % config.print_per_batch == 0: 120 | iteration = global_step // steps_per_epoch + 1 121 | logger.info("iteration:{} step:{}/{}, " 122 | "NER loss:{:>9.6f}".format( 123 | iteration, global_step % steps_per_epoch, steps_per_epoch, np.mean(loss))) 124 | loss = [] 125 | best = evaluate(sess, model, "dev", dev_manager, id_to_tag, logger, config) 126 | if best: 127 | saver.save(sess, config.checkpoint_path, global_step=global_step) 128 | 129 | 130 | def test(model, config, test_manager, id_to_tag): 131 | logger = get_logger(config.log_file) 132 | tf_config = tf.ConfigProto() 133 | tf_config.gpu_options.allow_growth = True 134 | with tf.Session(config=tf_config) as sess: 135 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path) 136 | sess.run(tf.global_variables_initializer()) 137 | sess.run(tf.local_variables_initializer()) 138 | saver = tf.train.Saver() 139 | if ckpt and 
tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 140 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path) 141 | # saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta') 142 | # saver.restore(session, tf.train.latest_checkpoint("ckpt/")) 143 | saver.restore(sess, ckpt.model_checkpoint_path) 144 | evaluate(sess, model, 'test', test_manager, id_to_tag, logger, config) 145 | 146 | 147 | def demo(model, config, id_to_tag, tag_to_id): 148 | logger = get_logger(config.log_file) 149 | tf_config = tf.ConfigProto() 150 | tf_config.gpu_options.allow_growth = True 151 | with tf.Session(config=tf_config) as sess: 152 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path) 153 | sess.run(tf.global_variables_initializer()) 154 | sess.run(tf.local_variables_initializer()) 155 | saver = tf.train.Saver() 156 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 157 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path) 158 | # saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta') 159 | # saver.restore(session, tf.train.latest_checkpoint("ckpt/")) 160 | saver.restore(sess, ckpt.model_checkpoint_path) 161 | while True: 162 | line = input("input sentence, please:") 163 | inputs = input_from_line(line, config.max_seq_len, tag_to_id) 164 | trans = model.trans.eval(sess) 165 | feed_dict = get_feed_dict(model, False, inputs, config) 166 | lengths, scores = sess.run([model.lengths, model.logits], feed_dict) 167 | batch_paths = decode(scores, lengths, trans, config) 168 | tags = [id_to_tag[idx] for idx in batch_paths[0]] 169 | result = bio_to_json(inputs[0], tags[1:-1]) 170 | print(result['entities']) 171 | 172 | 173 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Author:duguiming 2 | # Description: 工具文件 3 | # Date: 2020-4-15 4 | import os 5 | import re 6 | import codecs 7 | import 
logging 8 | 9 | from conlleval import return_report 10 | 11 | 12 | def create_dico(item_list): 13 | """ 14 | Create a dictionary of items from a list of list of items. 15 | """ 16 | assert type(item_list) is list 17 | dico = {} 18 | for items in item_list: 19 | for item in items: 20 | if item not in dico: 21 | dico[item] = 1 22 | else: 23 | dico[item] += 1 24 | return dico 25 | 26 | 27 | def test_ner(results, path): 28 | """ 29 | Run perl script to evaluate model 30 | """ 31 | output_file = os.path.join(path, "ner_predict.utf8") 32 | with codecs.open(output_file, "w", 'utf8') as f: 33 | to_write = [] 34 | for block in results: 35 | for line in block: 36 | to_write.append(line + "\n") 37 | to_write.append("\n") 38 | f.writelines(to_write) 39 | eval_lines = return_report(output_file) 40 | return eval_lines 41 | 42 | 43 | def create_mapping(dico): 44 | """ 45 | Create a mapping (item to ID / ID to item) from a dictionary. 46 | Items are ordered by decreasing frequency. 47 | """ 48 | sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0])) 49 | id_to_item = {i: v[0] for i, v in enumerate(sorted_items)} 50 | item_to_id = {v: k for k, v in id_to_item.items()} 51 | return item_to_id, id_to_item 52 | 53 | 54 | def zero_digits(s): 55 | """ 56 | Replace every digit in a string by a zero. 
57 | """ 58 | return re.sub('\d', '0', s) 59 | 60 | 61 | def get_logger(log_file): 62 | logger = logging.getLogger(log_file) 63 | logger.setLevel(logging.DEBUG) 64 | fh = logging.FileHandler(log_file) 65 | fh.setLevel(logging.DEBUG) 66 | ch = logging.StreamHandler() 67 | ch.setLevel(logging.INFO) 68 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 69 | ch.setFormatter(formatter) 70 | fh.setFormatter(formatter) 71 | logger.addHandler(ch) 72 | logger.addHandler(fh) 73 | return logger 74 | 75 | 76 | def make_path(params): 77 | """ 78 | Make folders for training and evaluation 79 | """ 80 | if not os.path.isdir(params.result_path): 81 | os.makedirs(params.result_path) 82 | if not os.path.isdir(params.ckpt_path): 83 | os.makedirs(params.ckpt_path) 84 | if not os.path.isdir(params.log_path): 85 | os.makedirs(params.log_path) 86 | 87 | 88 | def bio_to_json(string, tags): 89 | item = {"string": string, "entities": []} 90 | entity_name = "" 91 | entity_start = 0 92 | iCount = 0 93 | entity_tag = "" 94 | 95 | for c_idx in range(len(tags)): 96 | c, tag = string[c_idx], tags[c_idx] 97 | if c_idx < len(tags)-1: 98 | tag_next = tags[c_idx+1] 99 | else: 100 | tag_next = '' 101 | 102 | if tag[0] == 'B': 103 | entity_tag = tag[2:] 104 | entity_name = c 105 | entity_start = iCount 106 | if tag_next[2:] != entity_tag: 107 | item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]}) 108 | elif tag[0] == "I": 109 | if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O': 110 | tags[c_idx] = 'O' 111 | pass 112 | else: 113 | entity_name = entity_name + c 114 | if tag_next[2:] != entity_tag: 115 | item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag}) 116 | entity_name = '' 117 | iCount += 1 118 | return item --------------------------------------------------------------------------------