├── .idea
│   └── modules.xml
├── README.md
├── ckpt
│   └── .gitkeep
├── conlleval.py
├── data
│   └── README.md
├── data_helper.py
├── imgs
│   ├── bert_bilstm_crf.png
│   ├── demo.png
│   └── struct.png
├── log
│   └── .gitkeep
├── models
│   ├── BERT_BiLSTM_CRF.py
│   ├── __init__.py
│   ├── base_config.py
│   └── rnncell.py
├── result
│   └── .gitkeep
├── run.py
├── train_val_test.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese Named Entity Recognition with BERT + BiLSTM + CRF
2 | ## 1. Directory structure
3 | 
4 | - data: training datasets
5 | - models: model definitions
6 | - result: evaluation results
7 | - ckpt: saved model checkpoints
8 | - log: training logs
9 | - conlleval.py: CoNLL evaluation script for scoring model output
10 | - data_helper.py: data processing
11 | - run.py: entry point
12 | - train_val_test.py: training, validation, and testing
13 | - utils.py: assorted helper functions
14 | 
15 | 
16 | ## 2. Data
17 | 
18 | - A collection of open-source NER datasets
19 | - The dataset used in this project is available from the cloud-drive link in data/README.md
20 | 
21 | 
22 | ## 3. Running
23 | Download the BERT code (google-research/bert) into the project directory.
24 | Create a bert_model directory and unpack the pre-trained BERT model under it.
25 | The expected layout is described below. Then run:
26 | 
27 |     python3 run.py --mode xxx
28 | xxx: train / test / demo; --mode is required, and any value other than train or test runs the demo
29 | 
30 | ## 4. Results
31 | Training:
32 | 
33 | 
34 | Single-sentence demo:
35 | 
36 | 
37 | ## 5. References
38 | [1] https://github.com/yumath/bertNER
39 | [2] https://github.com/google-research/bert
40 |
--------------------------------------------------------------------------------
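A note on §3 above: the pre-trained model directory is hard-coded in models/base_config.py and data_helper.py as bert_model/chinese_L-12_H-768_A-12/. After unpacking the Google Chinese BERT release there, the layout should look roughly like the sketch below (checkpoint file names follow standard TensorFlow naming; this is an assumption, not an exhaustive listing):

```
bert_model/
└── chinese_L-12_H-768_A-12/
    ├── bert_config.json
    ├── bert_model.ckpt.data-00000-of-00001
    ├── bert_model.ckpt.index
    ├── bert_model.ckpt.meta
    └── vocab.txt
```

The bert package itself (imported in the code as `from bert import modeling, tokenization`) comes from reference [2].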
/ckpt/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/ckpt/.gitkeep
--------------------------------------------------------------------------------
/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | import sys
13 | import re
14 | import codecs
15 | from collections import defaultdict, namedtuple
16 |
17 | ANY_SPACE = ''
18 |
19 |
20 | class FormatError(Exception):
21 | pass
22 |
23 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
24 |
25 |
26 | class EvalCounts(object):
27 | def __init__(self):
28 | self.correct_chunk = 0 # number of correctly identified chunks
29 | self.correct_tags = 0 # number of correct chunk tags
30 | self.found_correct = 0 # number of chunks in corpus
31 | self.found_guessed = 0 # number of identified chunks
32 | self.token_counter = 0 # token counter (ignores sentence breaks)
33 |
34 | # counts by type
35 | self.t_correct_chunk = defaultdict(int)
36 | self.t_found_correct = defaultdict(int)
37 | self.t_found_guessed = defaultdict(int)
38 |
39 |
40 | def parse_args(argv):
41 | import argparse
42 | parser = argparse.ArgumentParser(
43 | description='evaluate tagging results using CoNLL criteria',
44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
45 | )
46 | arg = parser.add_argument
47 | arg('-b', '--boundary', metavar='STR', default='-X-',
48 | help='sentence boundary')
49 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
50 | help='character delimiting items in input')
51 | arg('-o', '--otag', metavar='CHAR', default='O',
52 | help='alternative outside tag')
53 | arg('file', nargs='?', default=None)
54 | return parser.parse_args(argv)
55 |
56 |
57 | def parse_tag(t):
58 | m = re.match(r'^([^-]*)-(.*)$', t)
59 | return m.groups() if m else (t, '')
60 |
61 |
62 | def evaluate(iterable, options=None):
63 | if options is None:
64 | options = parse_args([]) # use defaults
65 |
66 | counts = EvalCounts()
67 | num_features = None # number of features per line
68 |     in_correct = False        # currently processed chunk is correct until now
69 | last_correct = 'O' # previous chunk tag in corpus
70 | last_correct_type = '' # type of previously identified chunk tag
71 | last_guessed = 'O' # previously identified chunk tag
72 | last_guessed_type = '' # type of previous chunk tag in corpus
73 |
74 | for line in iterable:
75 | line = line.rstrip('\r\n')
76 |
77 | if options.delimiter == ANY_SPACE:
78 | features = line.split()
79 | else:
80 | features = line.split(options.delimiter)
81 |
82 | if num_features is None:
83 | num_features = len(features)
84 | elif num_features != len(features) and len(features) != 0:
85 | raise FormatError('unexpected number of features: %d (%d)' %
86 | (len(features), num_features))
87 |
88 | if len(features) == 0 or features[0] == options.boundary:
89 | features = [options.boundary, 'O', 'O']
90 | if len(features) < 3:
91 | raise FormatError('unexpected number of features in line %s' % line)
92 |
93 | guessed, guessed_type = parse_tag(features.pop())
94 | correct, correct_type = parse_tag(features.pop())
95 | first_item = features.pop(0)
96 |
97 | if first_item == options.boundary:
98 | guessed = 'O'
99 |
100 | end_correct = end_of_chunk(last_correct, correct,
101 | last_correct_type, correct_type)
102 | end_guessed = end_of_chunk(last_guessed, guessed,
103 | last_guessed_type, guessed_type)
104 | start_correct = start_of_chunk(last_correct, correct,
105 | last_correct_type, correct_type)
106 | start_guessed = start_of_chunk(last_guessed, guessed,
107 | last_guessed_type, guessed_type)
108 |
109 | if in_correct:
110 | if (end_correct and end_guessed and
111 | last_guessed_type == last_correct_type):
112 | in_correct = False
113 | counts.correct_chunk += 1
114 | counts.t_correct_chunk[last_correct_type] += 1
115 | elif (end_correct != end_guessed or guessed_type != correct_type):
116 | in_correct = False
117 |
118 | if start_correct and start_guessed and guessed_type == correct_type:
119 | in_correct = True
120 |
121 | if start_correct:
122 | counts.found_correct += 1
123 | counts.t_found_correct[correct_type] += 1
124 | if start_guessed:
125 | counts.found_guessed += 1
126 | counts.t_found_guessed[guessed_type] += 1
127 | if first_item != options.boundary:
128 | if correct == guessed and guessed_type == correct_type:
129 | counts.correct_tags += 1
130 | counts.token_counter += 1
131 |
132 | last_guessed = guessed
133 | last_correct = correct
134 | last_guessed_type = guessed_type
135 | last_correct_type = correct_type
136 |
137 | if in_correct:
138 | counts.correct_chunk += 1
139 | counts.t_correct_chunk[last_correct_type] += 1
140 |
141 | return counts
142 |
143 |
144 | def uniq(iterable):
145 | seen = set()
146 | return [i for i in iterable if not (i in seen or seen.add(i))]
147 |
148 |
149 | def calculate_metrics(correct, guessed, total):
150 | tp, fp, fn = correct, guessed-correct, total-correct
151 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
152 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
153 | f = 0 if p + r == 0 else 2 * p * r / (p + r)
154 | return Metrics(tp, fp, fn, p, r, f)
155 |
156 |
157 | def metrics(counts):
158 | c = counts
159 | overall = calculate_metrics(
160 | c.correct_chunk, c.found_guessed, c.found_correct
161 | )
162 | by_type = {}
163 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
164 | by_type[t] = calculate_metrics(
165 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
166 | )
167 | return overall, by_type
168 |
169 |
170 | def report(counts, out=None):
171 | if out is None:
172 | out = sys.stdout
173 |
174 | overall, by_type = metrics(counts)
175 |
176 | c = counts
177 | out.write('processed %d tokens with %d phrases; ' %
178 | (c.token_counter, c.found_correct))
179 | out.write('found: %d phrases; correct: %d.\n' %
180 | (c.found_guessed, c.correct_chunk))
181 |
182 | if c.token_counter > 0:
183 | out.write('accuracy: %6.2f%%; ' %
184 | (100.*c.correct_tags/c.token_counter))
185 | out.write('precision: %6.2f%%; ' % (100.*overall.prec))
186 | out.write('recall: %6.2f%%; ' % (100.*overall.rec))
187 | out.write('FB1: %6.2f\n' % (100.*overall.fscore))
188 |
189 | for i, m in sorted(by_type.items()):
190 | out.write('%17s: ' % i)
191 | out.write('precision: %6.2f%%; ' % (100.*m.prec))
192 | out.write('recall: %6.2f%%; ' % (100.*m.rec))
193 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
194 |
195 |
196 | def report_notprint(counts, out=None):
197 | if out is None:
198 | out = sys.stdout
199 |
200 | overall, by_type = metrics(counts)
201 |
202 | c = counts
203 | final_report = []
204 | line = []
205 | line.append('processed %d tokens with %d phrases; ' %
206 | (c.token_counter, c.found_correct))
207 | line.append('found: %d phrases; correct: %d.\n' %
208 | (c.found_guessed, c.correct_chunk))
209 | final_report.append("".join(line))
210 |
211 | if c.token_counter > 0:
212 | line = []
213 | line.append('accuracy: %6.2f%%; ' %
214 | (100.*c.correct_tags/c.token_counter))
215 | line.append('precision: %6.2f%%; ' % (100.*overall.prec))
216 | line.append('recall: %6.2f%%; ' % (100.*overall.rec))
217 | line.append('FB1: %6.2f\n' % (100.*overall.fscore))
218 | final_report.append("".join(line))
219 |
220 | for i, m in sorted(by_type.items()):
221 | line = []
222 | line.append('%17s: ' % i)
223 | line.append('precision: %6.2f%%; ' % (100.*m.prec))
224 | line.append('recall: %6.2f%%; ' % (100.*m.rec))
225 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
226 | final_report.append("".join(line))
227 | return final_report
228 |
229 |
230 | def end_of_chunk(prev_tag, tag, prev_type, type_):
231 | # check if a chunk ended between the previous and current word
232 | # arguments: previous and current chunk tags, previous and current types
233 | chunk_end = False
234 |
235 | if prev_tag == 'E': chunk_end = True
236 | if prev_tag == 'S': chunk_end = True
237 |
238 | if prev_tag == 'B' and tag == 'B': chunk_end = True
239 | if prev_tag == 'B' and tag == 'S': chunk_end = True
240 | if prev_tag == 'B' and tag == 'O': chunk_end = True
241 | if prev_tag == 'I' and tag == 'B': chunk_end = True
242 | if prev_tag == 'I' and tag == 'S': chunk_end = True
243 | if prev_tag == 'I' and tag == 'O': chunk_end = True
244 |
245 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
246 | chunk_end = True
247 |
248 | # these chunks are assumed to have length 1
249 | if prev_tag == ']': chunk_end = True
250 | if prev_tag == '[': chunk_end = True
251 |
252 | return chunk_end
253 |
254 |
255 | def start_of_chunk(prev_tag, tag, prev_type, type_):
256 | # check if a chunk started between the previous and current word
257 | # arguments: previous and current chunk tags, previous and current types
258 | chunk_start = False
259 |
260 | if tag == 'B': chunk_start = True
261 | if tag == 'S': chunk_start = True
262 |
263 | if prev_tag == 'E' and tag == 'E': chunk_start = True
264 | if prev_tag == 'E' and tag == 'I': chunk_start = True
265 | if prev_tag == 'S' and tag == 'E': chunk_start = True
266 | if prev_tag == 'S' and tag == 'I': chunk_start = True
267 | if prev_tag == 'O' and tag == 'E': chunk_start = True
268 | if prev_tag == 'O' and tag == 'I': chunk_start = True
269 |
270 | if tag != 'O' and tag != '.' and prev_type != type_:
271 | chunk_start = True
272 |
273 | # these chunks are assumed to have length 1
274 | if tag == '[': chunk_start = True
275 | if tag == ']': chunk_start = True
276 |
277 | return chunk_start
278 |
279 |
280 | def return_report(input_file):
281 | with codecs.open(input_file, "r", "utf8") as f:
282 | counts = evaluate(f)
283 | return report_notprint(counts)
284 |
285 |
286 | def main(argv):
287 | args = parse_args(argv[1:])
288 |
289 | if args.file is None:
290 | counts = evaluate(sys.stdin, args)
291 | else:
292 | with open(args.file) as f:
293 | counts = evaluate(f, args)
294 | report(counts)
295 |
296 | if __name__ == '__main__':
297 | sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
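A minimal usage sketch for the evaluator above: evaluate() accepts any iterable of whitespace-delimited "token gold-tag predicted-tag" lines (empty lines mark sentence boundaries), and report() prints overall and per-type precision/recall/FB1. The tokens and tags below are illustrative.

```python
# minimal sketch: score a few hand-written predictions with conlleval
from conlleval import evaluate, report

lines = [
    "海 B-LOC B-LOC",
    "南 I-LOC I-LOC",
    "省 I-LOC O",        # one tagging error
    "",                  # empty line = sentence boundary
    "张 B-PER B-PER",
    "三 I-PER I-PER",
]
counts = evaluate(lines)   # default options: any whitespace as delimiter
report(counts)             # prints accuracy, precision, recall and FB1
```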
/data/README.md:
--------------------------------------------------------------------------------
1 | Link: https://pan.baidu.com/s/1NSWtWQSNr7xgBvyMLJM66Q   Extraction code: 32et
--------------------------------------------------------------------------------
/data_helper.py:
--------------------------------------------------------------------------------
1 | # Author:duguiming
2 | # Description: data processing
3 | # Date: 2020-4-15
4 | import codecs
5 | import math
6 | import random
7 |
8 | from utils import create_dico, create_mapping, zero_digits
9 | from bert import tokenization
10 |
11 | tokenizer = tokenization.FullTokenizer(vocab_file='bert_model/chinese_L-12_H-768_A-12/vocab.txt',
12 | do_lower_case=True)
13 |
14 |
15 | def load_sentences(path, lower, zeros):
16 | """
17 | Load sentences. A line must contain at least a word and its tag.
18 | Sentences are separated by empty lines.
19 | """
20 | sentences = []
21 | sentence = []
22 | num = 0
23 | for line in codecs.open(path, 'r', 'utf8'):
24 | num += 1
25 | line = zero_digits(line.rstrip()) if zeros else line.rstrip()
26 | if not line:
27 | if len(sentence) > 0:
28 | if 'DOCSTART' not in sentence[0][0]:
29 | sentences.append(sentence)
30 | sentence = []
31 | else:
32 | if line[0] == " ":
33 | line = "$" + line[1:]
34 | word = line.split()
35 | else:
36 | word = line.split()
37 | assert len(word) >= 2, print([word[0]])
38 | sentence.append(word)
39 | if len(sentence) > 0:
40 | if 'DOCSTART' not in sentence[0][0]:
41 | sentences.append(sentence)
42 | return sentences
43 |
44 |
45 | def tag_mapping(sentences):
46 | """
47 | Create a dictionary and a mapping of tags, sorted by frequency.
48 | """
49 | tags = [[char[-1] for char in s] for s in sentences]
50 |
51 | dico = create_dico(tags)
52 | dico['[SEP]'] = len(dico) + 1
53 | dico['[CLS]'] = len(dico) + 2
54 |
55 | tag_to_id, id_to_tag = create_mapping(dico)
56 | print("Found %i unique named entity tags" % len(dico))
57 | return dico, tag_to_id, id_to_tag
58 |
59 |
60 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line):
61 | """
62 |     Analyze one example: convert each character to its id and each tag to its label id
63 | """
64 | text_list = char_line.split(' ')
65 | label_list = label_line.split(' ')
66 |
67 | tokens = []
68 | labels = []
69 | for i, word in enumerate(text_list):
70 | token = tokenizer.tokenize(word)
71 | tokens.extend(token)
72 | label_1 = label_list[i]
73 | for m in range(len(token)):
74 | if m == 0:
75 | labels.append(label_1)
76 | else:
77 | labels.append("X")
78 |     # truncate sequences that exceed max_seq_length
79 | if len(tokens) >= max_seq_length - 1:
80 | tokens = tokens[0:(max_seq_length - 2)]
81 | labels = labels[0:(max_seq_length - 2)]
82 | ntokens = []
83 | segment_ids = []
84 | label_ids = []
85 | ntokens.append("[CLS]")
86 | segment_ids.append(0)
87 | # append("O") or append("[CLS]") not sure!
88 | label_ids.append(tag_to_id["[CLS]"])
89 | for i, token in enumerate(tokens):
90 | ntokens.append(token)
91 | segment_ids.append(0)
92 | label_ids.append(tag_to_id[labels[i]])
93 | ntokens.append("[SEP]")
94 | segment_ids.append(0)
95 | # append("O") or append("[SEP]") not sure!
96 | label_ids.append(tag_to_id["[SEP]"])
97 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
98 | input_mask = [1] * len(input_ids)
99 |
100 | # padding
101 | while len(input_ids) < max_seq_length:
102 | input_ids.append(0)
103 | input_mask.append(0)
104 | segment_ids.append(0)
105 |         # padding label; we are not concerned about it
106 | label_ids.append(0)
107 | ntokens.append("**NULL**")
108 |
109 | return input_ids, input_mask, segment_ids, label_ids
110 |
111 |
112 | def prepare_dataset(sentences, max_seq_length, tag_to_id, lower=False, train=True):
113 | """
114 | Prepare the dataset. Return a list of lists of dictionaries containing:
115 | - word indexes
116 | - word char indexes
117 | - tag indexes
118 | """
119 | def f(x):
120 | return x.lower() if lower else x
121 | data = []
122 | for s in sentences:
123 | string = [w[0].strip() for w in s]
124 |         char_line = ' '.join(string)  # join the characters with spaces
125 | text = tokenization.convert_to_unicode(char_line)
126 |
127 | if train:
128 | tags = [w[-1] for w in s]
129 | else:
130 | tags = ['O' for _ in string]
131 |
132 |         labels = ' '.join(tags)  # join the tags with spaces
133 | labels = tokenization.convert_to_unicode(labels)
134 |
135 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text,
136 | tag_to_id=tag_to_id,
137 | max_seq_length=max_seq_length,
138 | tokenizer=tokenizer,
139 | label_line=labels)
140 | data.append([string, segment_ids, ids, mask, label_ids])
141 |
142 | return data
143 |
144 |
145 | class BatchManager(object):
146 |
147 | def __init__(self, data, batch_size):
148 | self.batch_data = self.sort_and_pad(data, batch_size)
149 | self.len_data = len(self.batch_data)
150 |
151 | def sort_and_pad(self, data, batch_size):
152 |         num_batch = int(math.ceil(len(data) / batch_size))
153 | sorted_data = sorted(data, key=lambda x: len(x[0]))
154 | batch_data = list()
155 | for i in range(num_batch):
156 | batch_data.append(self.arrange_batch(sorted_data[int(i*batch_size) : int((i+1)*batch_size)]))
157 | return batch_data
158 |
159 | @staticmethod
160 | def arrange_batch(batch):
161 | '''
162 |         Arrange a batch into a length-5 list: [strings, segment_ids, chars, mask, targets]
163 | :param batch:
164 | :return:
165 | '''
166 | strings = []
167 | segment_ids = []
168 | chars = []
169 | mask = []
170 | targets = []
171 | for string, seg_ids, char, msk, target in batch:
172 | strings.append(string)
173 | segment_ids.append(seg_ids)
174 | chars.append(char)
175 | mask.append(msk)
176 | targets.append(target)
177 | return [strings, segment_ids, chars, mask, targets]
178 |
179 | @staticmethod
180 | def pad_data(data):
181 | strings = []
182 | chars = []
183 | segs = []
184 | targets = []
185 | max_length = max([len(sentence[0]) for sentence in data])
186 | for line in data:
187 | string, segment_ids, char, seg, target = line
188 | padding = [0] * (max_length - len(string))
189 | strings.append(string + padding)
190 | chars.append(char + padding)
191 | segs.append(seg + padding)
192 | targets.append(target + padding)
193 | return [strings, chars, segs, targets]
194 |
195 | def iter_batch(self, shuffle=False):
196 | if shuffle:
197 | random.shuffle(self.batch_data)
198 | for idx in range(self.len_data):
199 | yield self.batch_data[idx]
200 |
201 |
202 | def input_from_line(line, max_seq_length, tag_to_id):
203 | """
204 | Take sentence data and return an input for
205 | the training or the evaluation function.
206 | """
207 | string = [w[0].strip() for w in line]
208 | # chars = [char_to_id[f(w) if f(w) in char_to_id else '']
209 | # for w in string]
210 |     char_line = ' '.join(string)  # join the characters with spaces
211 | text = tokenization.convert_to_unicode(char_line)
212 |
213 | tags = ['O' for _ in string]
214 |
215 |     labels = ' '.join(tags)  # join the tags with spaces
216 | labels = tokenization.convert_to_unicode(labels)
217 |
218 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text,
219 | tag_to_id=tag_to_id,
220 | max_seq_length=max_seq_length,
221 | tokenizer=tokenizer,
222 | label_line=labels)
223 | import numpy as np
224 | segment_ids = np.reshape(segment_ids,(1, max_seq_length))
225 | ids = np.reshape(ids, (1, max_seq_length))
226 | mask = np.reshape(mask, (1, max_seq_length))
227 | label_ids = np.reshape(label_ids, (1, max_seq_length))
228 | return [string, segment_ids, ids, mask, label_ids]
--------------------------------------------------------------------------------
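A sketch of how the pieces above fit together. Assumptions: the bert package and the pre-trained model from README §3 are in place (importing data_helper builds a FullTokenizer from bert_model/chinese_L-12_H-768_A-12/vocab.txt at import time), and data/train.txt is in the one-token-per-line BIO format expected by load_sentences().

```python
# minimal end-to-end sketch of the data pipeline (paths as in models/base_config.py)
from data_helper import load_sentences, tag_mapping, prepare_dataset, BatchManager

sentences = load_sentences('data/train.txt', lower=True, zeros=False)
_, tag_to_id, id_to_tag = tag_mapping(sentences)            # tag vocabulary incl. [CLS]/[SEP]
data = prepare_dataset(sentences, max_seq_length=128, tag_to_id=tag_to_id)
batches = BatchManager(data, batch_size=32)

for strings, segment_ids, input_ids, input_mask, label_ids in batches.iter_batch(shuffle=True):
    # each element is a list of batch_size items; input_ids/label_ids are padded to length 128
    pass
```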
/imgs/bert_bilstm_crf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/bert_bilstm_crf.png
--------------------------------------------------------------------------------
/imgs/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/demo.png
--------------------------------------------------------------------------------
/imgs/struct.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/imgs/struct.png
--------------------------------------------------------------------------------
/log/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/log/.gitkeep
--------------------------------------------------------------------------------
/models/BERT_BiLSTM_CRF.py:
--------------------------------------------------------------------------------
1 | # Author: duguiming
2 | # Description: model definition (BERT + BiLSTM + CRF)
3 | # Date: 2020-4-15
4 | import tensorflow as tf
5 | from tensorflow.contrib.crf import crf_log_likelihood
6 | from tensorflow.contrib.layers.python.layers import initializers
7 |
8 | from models.base_config import BaseConfig
9 | from bert import modeling
10 | import models.rnncell as rnn
11 |
12 |
13 | class Config(BaseConfig):
14 | batch_size = 128
15 | epoch = 100
16 | print_per_batch = 100
17 | clip = 5
18 | dropout_keep_prob = 0.5
19 | lr = 0.001
20 | optimizer = 'adam'
21 | zeros = False
22 | lower = True
23 |
24 | num_tags = None
25 | lstm_dim = 200
26 | max_seq_len = 128
27 | max_epoch = 100
28 | steps_check = 100
29 |
30 |
31 | class BertBiLSTMCrf(object):
32 | def __init__(self, config):
33 | self.config = config
34 | # add placeholders for the model
35 | self.input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_ids")
36 | self.input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_mask")
37 | self.segment_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="segment_ids")
38 | self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Targets")
39 | # dropout keep prob
40 | self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout")
41 | self.global_step = tf.Variable(0, trainable=False)
42 | self.best_dev_f1 = tf.Variable(0.0, trainable=False)
43 |
44 | self.initializer = initializers.xavier_initializer()
45 |
46 | self.bert_bilstm_crf()
47 |
48 | def bert_bilstm_crf(self):
49 | # parameter
50 | used = tf.sign(tf.abs(self.input_ids))
51 | length = tf.reduce_sum(used, reduction_indices=1)
52 | self.lengths = tf.cast(length, tf.int32)
53 | self.batch_size = tf.shape(self.input_ids)[0]
54 | self.num_steps = tf.shape(self.input_ids)[-1]
55 | # bert embedding
56 | embedding = self.bert_embedding()
57 | # dropout
58 | lstm_inputs = tf.nn.dropout(embedding, self.dropout)
59 | # bi-directional lstm layer
60 | lstm_outputs = self.biLSTM_layer(lstm_inputs, self.config.lstm_dim, self.lengths)
61 | # logits for tags
62 | self.logits = self.project_layer(lstm_outputs)
63 | # loss of the model
64 | self.loss = self.loss_layer(self.logits, self.lengths)
65 |
66 |         # checkpoint used to initialize the BERT parameters
67 | init_checkpoint = self.config.init_checkpoint
68 |         # collect all trainable variables in the graph
69 | tvars = tf.trainable_variables()
70 |         # load the pre-trained BERT weights
71 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,
72 | init_checkpoint)
73 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
74 | print("**** Trainable Variables ****")
75 |         # record which variables were initialized from the checkpoint
76 | train_vars = []
77 | for var in tvars:
78 | init_string = ""
79 | if var.name in initialized_variable_names:
80 | init_string = ", *INIT_FROM_CKPT*"
81 | else:
82 | train_vars.append(var)
83 |             print(" name = %s, shape = %s%s" % (var.name, var.shape,
84 |                                                 init_string))
85 |
86 | optimizer = self.config.optimizer
87 | if optimizer == "adam":
88 | self.opt = tf.train.AdamOptimizer(self.config.lr)
89 | else:
90 | raise KeyError
91 | grads = tf.gradients(self.loss, train_vars)
92 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
93 |
94 | self.train_op = self.opt.apply_gradients(
95 | zip(grads, train_vars), global_step=self.global_step)
96 |
97 | def bert_embedding(self):
98 | # load bert embedding
99 |         bert_config = modeling.BertConfig.from_json_file(self.config.bert_config_path)  # path to the BERT config file
100 | model = modeling.BertModel(
101 | config=bert_config,
102 | is_training=True,
103 | input_ids=self.input_ids,
104 | input_mask=self.input_mask,
105 | token_type_ids=self.segment_ids,
106 | use_one_hot_embeddings=False)
107 | embedding = model.get_sequence_output()
108 | return embedding
109 |
110 | def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths, name=None):
111 | """
112 | :param lstm_inputs: [batch_size, num_steps, emb_size]
113 | :return: [batch_size, num_steps, 2*lstm_dim]
114 | """
115 | with tf.variable_scope("char_BiLSTM" if not name else name):
116 | lstm_cell = {}
117 | for direction in ["forward", "backward"]:
118 | with tf.variable_scope(direction):
119 | lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell(
120 | lstm_dim,
121 | use_peepholes=True,
122 | initializer=self.initializer,
123 | state_is_tuple=True)
124 | outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
125 | lstm_cell["forward"],
126 | lstm_cell["backward"],
127 | lstm_inputs,
128 | dtype=tf.float32,
129 | sequence_length=lengths)
130 | return tf.concat(outputs, axis=2)
131 |
132 | def project_layer(self, lstm_outputs, name=None):
133 | """
134 | hidden layer between lstm layer and logits
135 | :param lstm_outputs: [batch_size, num_steps, emb_size]
136 | :return: [batch_size, num_steps, num_tags]
137 | """
138 | with tf.variable_scope("project" if not name else name):
139 | with tf.variable_scope("hidden"):
140 | W = tf.get_variable("W", shape=[self.config.lstm_dim * 2, self.config.lstm_dim],
141 | dtype=tf.float32, initializer=self.initializer)
142 |
143 | b = tf.get_variable("b", shape=[self.config.lstm_dim], dtype=tf.float32,
144 | initializer=tf.zeros_initializer())
145 | output = tf.reshape(lstm_outputs, shape=[-1, self.config.lstm_dim * 2])
146 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b))
147 |
148 | # project to score of tags
149 | with tf.variable_scope("logits"):
150 | W = tf.get_variable("W", shape=[self.config.lstm_dim, self.config.num_tags],
151 | dtype=tf.float32, initializer=self.initializer)
152 |
153 | b = tf.get_variable("b", shape=[self.config.num_tags], dtype=tf.float32,
154 | initializer=tf.zeros_initializer())
155 |
156 | pred = tf.nn.xw_plus_b(hidden, W, b)
157 |
158 | return tf.reshape(pred, [-1, self.num_steps, self.config.num_tags])
159 |
160 | def loss_layer(self, project_logits, lengths, name=None):
161 | """
162 | calculate crf loss
163 | :param project_logits: [1, num_steps, num_tags]
164 | :return: scalar loss
165 | """
166 | with tf.variable_scope("crf_loss" if not name else name):
167 | small = -1000.0
168 | # pad logits for crf loss
169 | start_logits = tf.concat(
170 | [small * tf.ones(shape=[self.batch_size, 1, self.config.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])],
171 | axis=-1)
172 | pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
173 | logits = tf.concat([project_logits, pad_logits], axis=-1)
174 | logits = tf.concat([start_logits, logits], axis=1)
175 | targets = tf.concat(
176 | [tf.cast(self.config.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1)
177 |
178 | self.trans = tf.get_variable(
179 | "transitions",
180 | shape=[self.config.num_tags + 1, self.config.num_tags + 1],
181 | initializer=self.initializer)
182 | log_likelihood, self.trans = crf_log_likelihood(
183 | inputs=logits,
184 | tag_indices=targets,
185 | transition_params=self.trans,
186 | sequence_lengths=lengths + 1)
187 | return tf.reduce_mean(-log_likelihood)
--------------------------------------------------------------------------------
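The CRF in loss_layer() above runs over an expanded label space: each sequence gains a leading artificial start step and an extra start tag, with large negative logits wherever the model must not go. A numpy-only sketch of that shape bookkeeping (it mirrors the concatenations in loss_layer(); the values are dummies):

```python
# numpy sketch of the logits padding performed in loss_layer()
import numpy as np

batch_size, num_steps, num_tags = 2, 5, 7
small = -1000.0
project_logits = np.zeros((batch_size, num_steps, num_tags))   # stand-in for the BiLSTM logits

# a leading "start" step whose only permitted label is the extra start tag
start_logits = np.concatenate(
    [small * np.ones((batch_size, 1, num_tags)), np.zeros((batch_size, 1, 1))], axis=-1)
# an extra label column that real time steps should never take
pad_logits = small * np.ones((batch_size, num_steps, 1))

logits = np.concatenate([project_logits, pad_logits], axis=-1)
logits = np.concatenate([start_logits, logits], axis=1)
print(logits.shape)   # (2, 6, 8): num_steps + 1 time steps, num_tags + 1 labels
```

decode() in train_val_test.py applies the same padding on the numpy side before Viterbi decoding.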
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/models/__init__.py
--------------------------------------------------------------------------------
/models/base_config.py:
--------------------------------------------------------------------------------
1 | # Author: duguiming
2 | # Description: base configuration
3 | # Date: 2020-04-15
4 | import os
5 |
6 | base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
7 |
8 |
9 | class BaseConfig(object):
10 |     # checkpoint used to initialize the BERT parameters
11 | init_checkpoint = "bert_model/chinese_L-12_H-768_A-12/bert_model.ckpt"
12 | bert_config_path = "bert_model/chinese_L-12_H-768_A-12/bert_config.json"
13 |
14 |     # data paths
15 | train_path = os.path.join(base_dir, 'data', 'train.txt')
16 | dev_path = os.path.join(base_dir, 'data', 'dev.txt')
17 | test_path = os.path.join(base_dir, 'data', 'test.txt')
18 |
19 |     # output paths
20 | map_file = os.path.join(base_dir, 'maps.pkl')
21 | result_path = os.path.join(base_dir, 'result')
22 | ckpt_path = os.path.join(base_dir, 'ckpt')
23 | log_path = os.path.join(base_dir, 'log')
24 | log_file = os.path.join(log_path, 'train.log')
25 | checkpoint_path = os.path.join(ckpt_path, 'ner.ckpt')
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/models/rnncell.py:
--------------------------------------------------------------------------------
1 | """Module for constructing RNN Cells."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import collections
7 | import math
8 | import tensorflow as tf
9 | from tensorflow.contrib.compiler import jit
10 | from tensorflow.contrib.layers.python.layers import layers
11 | from tensorflow.python.framework import dtypes
12 | from tensorflow.python.framework import op_def_registry
13 | from tensorflow.python.framework import ops
14 | from tensorflow.python.ops import array_ops
15 | from tensorflow.python.ops import clip_ops
16 | from tensorflow.python.ops import init_ops
17 | from tensorflow.python.ops import math_ops
18 | from tensorflow.python.ops import nn_ops
19 | from tensorflow.python.ops import random_ops
20 | from tensorflow.python.ops import rnn_cell_impl
21 | from tensorflow.python.ops import variable_scope as vs
22 | from tensorflow.python.platform import tf_logging as logging
23 | from tensorflow.python.util import nest
24 |
25 |
26 | def _get_concat_variable(name, shape, dtype, num_shards):
27 | """Get a sharded variable concatenated into one tensor."""
28 | sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
29 | if len(sharded_variable) == 1:
30 | return sharded_variable[0]
31 |
32 | concat_name = name + "/concat"
33 | concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
34 | for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
35 | if value.name == concat_full_name:
36 | return value
37 |
38 | concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
39 | ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
40 | concat_variable)
41 | return concat_variable
42 |
43 |
44 | def _get_sharded_variable(name, shape, dtype, num_shards):
45 | """Get a list of sharded variables with the given dtype."""
46 | if num_shards > shape[0]:
47 | raise ValueError("Too many shards: shape=%s, num_shards=%d" %
48 | (shape, num_shards))
49 | unit_shard_size = int(math.floor(shape[0] / num_shards))
50 | remaining_rows = shape[0] - unit_shard_size * num_shards
51 |
52 | shards = []
53 | for i in range(num_shards):
54 | current_size = unit_shard_size
55 | if i < remaining_rows:
56 | current_size += 1
57 | shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
58 | dtype=dtype))
59 | return shards
60 |
61 |
62 | class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
63 | """Long short-term memory unit (LSTM) recurrent network cell.
64 |
65 | The default non-peephole implementation is based on:
66 |
67 | http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
68 |
69 | S. Hochreiter and J. Schmidhuber.
70 | "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
71 |
72 | The peephole implementation is based on:
73 |
74 | https://research.google.com/pubs/archive/43905.pdf
75 |
76 | Hasim Sak, Andrew Senior, and Francoise Beaufays.
77 | "Long short-term memory recurrent neural network architectures for
78 | large scale acoustic modeling." INTERSPEECH, 2014.
79 |
80 | The coupling of input and forget gate is based on:
81 |
82 | http://arxiv.org/pdf/1503.04069.pdf
83 |
84 | Greff et al. "LSTM: A Search Space Odyssey"
85 |
86 | The class uses optional peep-hole connections, and an optional projection
87 | layer.
88 | """
89 |
90 | def __init__(self, num_units, use_peepholes=False,
91 | initializer=None, num_proj=None, proj_clip=None,
92 | num_unit_shards=1, num_proj_shards=1,
93 | forget_bias=1.0, state_is_tuple=True,
94 | activation=math_ops.tanh, reuse=None):
95 | """Initialize the parameters for an LSTM cell.
96 |
97 | Args:
98 | num_units: int, The number of units in the LSTM cell
99 | use_peepholes: bool, set True to enable diagonal/peephole connections.
100 | initializer: (optional) The initializer to use for the weight and
101 | projection matrices.
102 | num_proj: (optional) int, The output dimensionality for the projection
103 | matrices. If None, no projection is performed.
104 | proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
105 | provided, then the projected values are clipped elementwise to within
106 | `[-proj_clip, proj_clip]`.
107 | num_unit_shards: How to split the weight matrix. If >1, the weight
108 | matrix is stored across num_unit_shards.
109 | num_proj_shards: How to split the projection matrix. If >1, the
110 | projection matrix is stored across num_proj_shards.
111 | forget_bias: Biases of the forget gate are initialized by default to 1
112 | in order to reduce the scale of forgetting at the beginning of
113 | the training.
114 | state_is_tuple: If True, accepted and returned states are 2-tuples of
115 | the `c_state` and `m_state`. By default (False), they are concatenated
116 | along the column axis. This default behavior will soon be deprecated.
117 | activation: Activation function of the inner states.
118 | reuse: (optional) Python boolean describing whether to reuse variables
119 | in an existing scope. If not `True`, and the existing scope already has
120 | the given variables, an error is raised.
121 | """
122 | super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
123 | if not state_is_tuple:
124 | logging.warn(
125 | "%s: Using a concatenated state is slower and will soon be "
126 | "deprecated. Use state_is_tuple=True.", self)
127 | self._num_units = num_units
128 | self._use_peepholes = use_peepholes
129 | self._initializer = initializer
130 | self._num_proj = num_proj
131 | self._proj_clip = proj_clip
132 | self._num_unit_shards = num_unit_shards
133 | self._num_proj_shards = num_proj_shards
134 | self._forget_bias = forget_bias
135 | self._state_is_tuple = state_is_tuple
136 | self._activation = activation
137 | self._reuse = reuse
138 |
139 | if num_proj:
140 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
141 | if state_is_tuple else num_units + num_proj)
142 | self._output_size = num_proj
143 | else:
144 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
145 | if state_is_tuple else 2 * num_units)
146 | self._output_size = num_units
147 |
148 | @property
149 | def state_size(self):
150 | return self._state_size
151 |
152 | @property
153 | def output_size(self):
154 | return self._output_size
155 |
156 | def call(self, inputs, state):
157 | """Run one step of LSTM.
158 |
159 | Args:
160 | inputs: input Tensor, 2D, batch x num_units.
161 | state: if `state_is_tuple` is False, this must be a state Tensor,
162 | `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
163 | tuple of state Tensors, both `2-D`, with column sizes `c_state` and
164 | `m_state`.
165 | scope: VariableScope for the created subgraph; defaults to "LSTMCell".
166 |
167 | Returns:
168 | A tuple containing:
169 | - A `2-D, [batch x output_dim]`, Tensor representing the output of the
170 | LSTM after reading `inputs` when previous state was `state`.
171 | Here output_dim is:
172 | num_proj if num_proj was set,
173 | num_units otherwise.
174 | - Tensor(s) representing the new state of LSTM after reading `inputs` when
175 | the previous state was `state`. Same type and shape(s) as `state`.
176 |
177 | Raises:
178 | ValueError: If input size cannot be inferred from inputs via
179 | static shape inference.
180 | """
181 | sigmoid = math_ops.sigmoid
182 |
183 | num_proj = self._num_units if self._num_proj is None else self._num_proj
184 |
185 | if self._state_is_tuple:
186 | (c_prev, m_prev) = state
187 | else:
188 | c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
189 | m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
190 |
191 | dtype = inputs.dtype
192 | input_size = inputs.get_shape().with_rank(2)[1]
193 |
194 | if input_size.value is None:
195 | raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
196 |
197 | # Input gate weights
198 | self.w_xi = tf.get_variable("_w_xi", [input_size.value, self._num_units])
199 | self.w_hi = tf.get_variable("_w_hi", [self._num_units, self._num_units])
200 | self.w_ci = tf.get_variable("_w_ci", [self._num_units, self._num_units])
201 | # Output gate weights
202 | self.w_xo = tf.get_variable("_w_xo", [input_size.value, self._num_units])
203 | self.w_ho = tf.get_variable("_w_ho", [self._num_units, self._num_units])
204 | self.w_co = tf.get_variable("_w_co", [self._num_units, self._num_units])
205 |
206 | # Cell weights
207 | self.w_xc = tf.get_variable("_w_xc", [input_size.value, self._num_units])
208 | self.w_hc = tf.get_variable("_w_hc", [self._num_units, self._num_units])
209 |
210 | # Initialize the bias vectors
211 | self.b_i = tf.get_variable("_b_i", [self._num_units], initializer=init_ops.zeros_initializer())
212 | self.b_c = tf.get_variable("_b_c", [self._num_units], initializer=init_ops.zeros_initializer())
213 | self.b_o = tf.get_variable("_b_o", [self._num_units], initializer=init_ops.zeros_initializer())
214 |
215 | i_t = sigmoid(math_ops.matmul(inputs, self.w_xi) +
216 | math_ops.matmul(m_prev, self.w_hi) +
217 | math_ops.matmul(c_prev, self.w_ci) +
218 | self.b_i)
219 | c_t = ((1 - i_t) * c_prev + i_t * self._activation(math_ops.matmul(inputs, self.w_xc) +
220 | math_ops.matmul(m_prev, self.w_hc) + self.b_c))
221 |
222 | o_t = sigmoid(math_ops.matmul(inputs, self.w_xo) +
223 | math_ops.matmul(m_prev, self.w_ho) +
224 | math_ops.matmul(c_t, self.w_co) +
225 | self.b_o)
226 |
227 | h_t = o_t * self._activation(c_t)
228 |
229 | new_state = (rnn_cell_impl.LSTMStateTuple(c_t, h_t) if self._state_is_tuple else
230 | array_ops.concat([c_t, h_t], 1))
231 | return h_t, new_state
--------------------------------------------------------------------------------
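For reference, the state update implemented in call() above, written out in plain numpy: the forget gate is coupled to the input gate as (1 - i_t), following Greff et al., and the peephole connections feed the cell state into the input and output gates. Weights are random and biases are omitted (they are zero-initialized in the cell); this is an illustration, not a drop-in replacement.

```python
# numpy sketch of one CoupledInputForgetGateLSTMCell step (peepholes enabled)
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
input_size, num_units = 3, 4
x = rng.normal(size=(1, input_size))     # current input
m_prev = np.zeros((1, num_units))        # previous output h_{t-1}
c_prev = np.zeros((1, num_units))        # previous cell state

# analogues of w_xi/w_hi/w_ci, w_xc/w_hc and w_xo/w_ho/w_co from call()
shapes = {"xi": (input_size, num_units), "hi": (num_units, num_units), "ci": (num_units, num_units),
          "xc": (input_size, num_units), "hc": (num_units, num_units),
          "xo": (input_size, num_units), "ho": (num_units, num_units), "co": (num_units, num_units)}
W = {k: rng.normal(size=s) for k, s in shapes.items()}

i_t = sigmoid(x @ W["xi"] + m_prev @ W["hi"] + c_prev @ W["ci"])           # input gate (peephole on c_prev)
c_t = (1 - i_t) * c_prev + i_t * np.tanh(x @ W["xc"] + m_prev @ W["hc"])   # coupled input/forget gate
o_t = sigmoid(x @ W["xo"] + m_prev @ W["ho"] + c_t @ W["co"])              # output gate (peephole on c_t)
h_t = o_t * np.tanh(c_t)                                                   # cell output
```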
/result/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duguiming111/NER-BERT-BiLSTM-CRF-/f327980e8a4535f791b8a7090a2207677ef338d9/result/.gitkeep
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # Author: duguiming
2 | # Description: entry point
3 | # Date: 2020-4-15
4 | import os
5 | import pickle
6 | import argparse
7 |
8 | from models.BERT_BiLSTM_CRF import Config, BertBiLSTMCrf
9 | from data_helper import tag_mapping, load_sentences, prepare_dataset, BatchManager
10 | from train_val_test import train, test, demo
11 | from utils import make_path
12 |
13 |
14 | parser = argparse.ArgumentParser(description='Chinese NER')
15 | parser.add_argument('--mode', type=str, required=True, help='train test or demo')
16 | args = parser.parse_args()
17 |
18 |
19 | if __name__ == "__main__":
20 | mode = args.mode
21 | config = Config()
22 |
23 | # load data
24 | train_sentences = load_sentences(config.train_path, config.lower, config.zeros)
25 | dev_sentences = load_sentences(config.dev_path, config.lower, config.zeros)
26 | test_sentences = load_sentences(config.test_path, config.lower, config.zeros)
27 |
28 | # tags dict
29 | if not os.path.isfile(config.map_file):
30 | # Create a dictionary and a mapping for tags
31 | _t, tag_to_id, id_to_tag = tag_mapping(train_sentences)
32 | with open(config.map_file, "wb") as f:
33 | pickle.dump([tag_to_id, id_to_tag], f)
34 | else:
35 | with open(config.map_file, "rb") as f:
36 | tag_to_id, id_to_tag = pickle.load(f)
37 |
38 | config.num_tags = len(tag_to_id)
39 |
40 | train_data = prepare_dataset(
41 | train_sentences, config.max_seq_len, tag_to_id, config.lower
42 | )
43 | dev_data = prepare_dataset(
44 | dev_sentences, config.max_seq_len, tag_to_id, config.lower
45 | )
46 | test_data = prepare_dataset(
47 | test_sentences, config.max_seq_len, tag_to_id, config.lower
48 | )
49 |     print("%i / %i / %i sentences in train / dev / test." % (
50 |         len(train_data), len(dev_data), len(test_data)))
51 |
52 | train_manager = BatchManager(train_data, config.batch_size)
53 | dev_manager = BatchManager(dev_data, config.batch_size)
54 | test_manager = BatchManager(test_data, config.batch_size)
55 |
56 | model = BertBiLSTMCrf(config)
57 | make_path(config)
58 |
59 | if mode == "train":
60 | train(model, config, train_manager, dev_manager, id_to_tag)
61 | elif mode == "test":
62 | test(model, config, test_manager, id_to_tag)
63 | else:
64 | demo(model, config, id_to_tag, tag_to_id)
65 |
--------------------------------------------------------------------------------
/train_val_test.py:
--------------------------------------------------------------------------------
1 | # Author: duguiming
2 | # Description: training, validation, and testing
3 | # Date:2020-4-25
4 | import tensorflow as tf
5 | import numpy as np
6 | from tensorflow.contrib.crf import viterbi_decode
7 |
8 | from utils import get_logger, test_ner, bio_to_json
9 | from data_helper import input_from_line
10 |
11 |
12 | def get_feed_dict(model, is_train, batch, config):
13 | """
14 | :param is_train: Flag, True for train batch
15 | :param batch: list train/evaluate data
16 | :return: structured data to feed
17 | """
18 | _, segment_ids, chars, mask, tags = batch
19 | feed_dict = {
20 | model.input_ids: np.asarray(chars),
21 | model.input_mask: np.asarray(mask),
22 | model.segment_ids: np.asarray(segment_ids),
23 | model.dropout: 1.0,
24 | }
25 | if is_train:
26 | feed_dict[model.targets] = np.asarray(tags)
27 | feed_dict[model.dropout] = config.dropout_keep_prob
28 | return feed_dict
29 |
30 |
31 | def decode(logits, lengths, matrix, config):
32 | """
33 | :param logits: [batch_size, num_steps, num_tags]float32, logits
34 | :param lengths: [batch_size]int32, real length of each sequence
35 |     :param matrix: transition matrix for inference
36 | :return:
37 | """
38 |     # infer the final label sequence with the Viterbi algorithm
39 | paths = []
40 | small = -1000.0
41 | start = np.asarray([[small]*config.num_tags +[0]])
42 | for score, length in zip(logits, lengths):
43 | score = score[:length]
44 | pad = small * np.ones([length, 1])
45 | logits = np.concatenate([score, pad], axis=1)
46 | logits = np.concatenate([start, logits], axis=0)
47 | path, _ = viterbi_decode(logits, matrix)
48 |
49 | paths.append(path[1:])
50 | return paths
51 |
52 |
53 | def evaluate_(sess, model, data_manager, id_to_tag, config):
54 | """
55 | :param sess: session to run the model
56 | :param data: list of data
57 | :param id_to_tag: index to tag name
58 | :return: evaluate result
59 | """
60 | results = []
61 | trans = model.trans.eval()
62 | for batch in data_manager.iter_batch():
63 | strings = batch[0]
64 | labels = batch[-1]
65 | feed_dict = get_feed_dict(model, False, batch, config)
66 | lengths, scores = sess.run([model.lengths, model.logits], feed_dict)
67 | batch_paths = decode(scores, lengths, trans, config)
68 | for i in range(len(strings)):
69 | result = []
70 | string = strings[i][:lengths[i]]
71 | gold = [id_to_tag[int(x)] for x in labels[i][1:lengths[i]]]
72 | pred = [id_to_tag[int(x)] for x in batch_paths[i][1:lengths[i]]]
73 | for char, gold, pred in zip(string, gold, pred):
74 | result.append(" ".join([char, gold, pred]))
75 | results.append(result)
76 | return results
77 |
78 |
79 | def evaluate(sess, model, name, data_manager, id_to_tag, logger, config):
80 | logger.info("evaluate:{}".format(name))
81 | ner_results = evaluate_(sess, model, data_manager, id_to_tag, config)
82 | eval_lines = test_ner(ner_results, config.result_path)
83 | for line in eval_lines:
84 | logger.info(line)
85 | f1 = float(eval_lines[1].strip().split()[-1])
86 |
87 | if name == "dev":
88 | best_test_f1 = model.best_dev_f1.eval()
89 | if f1 > best_test_f1:
90 | tf.assign(model.best_dev_f1, f1).eval()
91 | logger.info("new best dev f1 score:{:>.3f}".format(f1))
92 | return f1 > best_test_f1
93 |
94 |
95 | def train(model, config, train_manager, dev_manager, id_to_tag):
96 | logger = get_logger(config.log_file)
97 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
98 | # limit GPU memory
99 | tf_config = tf.ConfigProto()
100 | tf_config.gpu_options.allow_growth = True
101 | steps_per_epoch = train_manager.len_data
102 | with tf.Session(config=tf_config) as sess:
103 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path)
104 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
105 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
106 | model.saver.restore(sess, ckpt.model_checkpoint_path)
107 | else:
108 | logger.info("Created model with fresh parameters.")
109 | sess.run(tf.global_variables_initializer())
110 |
111 | logger.info("start training")
112 | loss = []
113 | for i in range(config.epoch):
114 | for batch in train_manager.iter_batch(shuffle=True):
115 | feed_dict = get_feed_dict(model, True, batch, config)
116 | global_step, batch_loss, _ = sess.run([model.global_step, model.loss, model.train_op], feed_dict)
117 |
118 | loss.append(batch_loss)
119 | if global_step % config.print_per_batch == 0:
120 | iteration = global_step // steps_per_epoch + 1
121 | logger.info("iteration:{} step:{}/{}, "
122 | "NER loss:{:>9.6f}".format(
123 | iteration, global_step % steps_per_epoch, steps_per_epoch, np.mean(loss)))
124 | loss = []
125 | best = evaluate(sess, model, "dev", dev_manager, id_to_tag, logger, config)
126 | if best:
127 | saver.save(sess, config.checkpoint_path, global_step=global_step)
128 |
129 |
130 | def test(model, config, test_manager, id_to_tag):
131 | logger = get_logger(config.log_file)
132 | tf_config = tf.ConfigProto()
133 | tf_config.gpu_options.allow_growth = True
134 | with tf.Session(config=tf_config) as sess:
135 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path)
136 | sess.run(tf.global_variables_initializer())
137 | sess.run(tf.local_variables_initializer())
138 | saver = tf.train.Saver()
139 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
140 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
141 | # saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta')
142 | # saver.restore(session, tf.train.latest_checkpoint("ckpt/"))
143 | saver.restore(sess, ckpt.model_checkpoint_path)
144 | evaluate(sess, model, 'test', test_manager, id_to_tag, logger, config)
145 |
146 |
147 | def demo(model, config, id_to_tag, tag_to_id):
148 | logger = get_logger(config.log_file)
149 | tf_config = tf.ConfigProto()
150 | tf_config.gpu_options.allow_growth = True
151 | with tf.Session(config=tf_config) as sess:
152 | ckpt = tf.train.get_checkpoint_state(config.ckpt_path)
153 | sess.run(tf.global_variables_initializer())
154 | sess.run(tf.local_variables_initializer())
155 | saver = tf.train.Saver()
156 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
157 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
158 | # saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta')
159 | # saver.restore(session, tf.train.latest_checkpoint("ckpt/"))
160 | saver.restore(sess, ckpt.model_checkpoint_path)
161 | while True:
162 | line = input("input sentence, please:")
163 | inputs = input_from_line(line, config.max_seq_len, tag_to_id)
164 | trans = model.trans.eval(sess)
165 | feed_dict = get_feed_dict(model, False, inputs, config)
166 | lengths, scores = sess.run([model.lengths, model.logits], feed_dict)
167 | batch_paths = decode(scores, lengths, trans, config)
168 | tags = [id_to_tag[idx] for idx in batch_paths[0]]
169 | result = bio_to_json(inputs[0], tags[1:-1])
170 | print(result['entities'])
171 |
172 |
173 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Author:duguiming
2 | # Description: utility functions
3 | # Date: 2020-4-15
4 | import os
5 | import re
6 | import codecs
7 | import logging
8 |
9 | from conlleval import return_report
10 |
11 |
12 | def create_dico(item_list):
13 | """
14 | Create a dictionary of items from a list of list of items.
15 | """
16 | assert type(item_list) is list
17 | dico = {}
18 | for items in item_list:
19 | for item in items:
20 | if item not in dico:
21 | dico[item] = 1
22 | else:
23 | dico[item] += 1
24 | return dico
25 |
26 |
27 | def test_ner(results, path):
28 | """
29 |     Run the conlleval script to evaluate the model's predictions
30 | """
31 | output_file = os.path.join(path, "ner_predict.utf8")
32 | with codecs.open(output_file, "w", 'utf8') as f:
33 | to_write = []
34 | for block in results:
35 | for line in block:
36 | to_write.append(line + "\n")
37 | to_write.append("\n")
38 | f.writelines(to_write)
39 | eval_lines = return_report(output_file)
40 | return eval_lines
41 |
42 |
43 | def create_mapping(dico):
44 | """
45 | Create a mapping (item to ID / ID to item) from a dictionary.
46 | Items are ordered by decreasing frequency.
47 | """
48 | sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
49 | id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
50 | item_to_id = {v: k for k, v in id_to_item.items()}
51 | return item_to_id, id_to_item
52 |
53 |
54 | def zero_digits(s):
55 | """
56 | Replace every digit in a string by a zero.
57 | """
58 |     return re.sub(r'\d', '0', s)
59 |
60 |
61 | def get_logger(log_file):
62 | logger = logging.getLogger(log_file)
63 | logger.setLevel(logging.DEBUG)
64 | fh = logging.FileHandler(log_file)
65 | fh.setLevel(logging.DEBUG)
66 | ch = logging.StreamHandler()
67 | ch.setLevel(logging.INFO)
68 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
69 | ch.setFormatter(formatter)
70 | fh.setFormatter(formatter)
71 | logger.addHandler(ch)
72 | logger.addHandler(fh)
73 | return logger
74 |
75 |
76 | def make_path(params):
77 | """
78 | Make folders for training and evaluation
79 | """
80 | if not os.path.isdir(params.result_path):
81 | os.makedirs(params.result_path)
82 | if not os.path.isdir(params.ckpt_path):
83 | os.makedirs(params.ckpt_path)
84 | if not os.path.isdir(params.log_path):
85 | os.makedirs(params.log_path)
86 |
87 |
88 | def bio_to_json(string, tags):
89 | item = {"string": string, "entities": []}
90 | entity_name = ""
91 | entity_start = 0
92 | iCount = 0
93 | entity_tag = ""
94 |
95 | for c_idx in range(len(tags)):
96 | c, tag = string[c_idx], tags[c_idx]
97 | if c_idx < len(tags)-1:
98 | tag_next = tags[c_idx+1]
99 | else:
100 | tag_next = ''
101 |
102 | if tag[0] == 'B':
103 | entity_tag = tag[2:]
104 | entity_name = c
105 | entity_start = iCount
106 | if tag_next[2:] != entity_tag:
107 | item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]})
108 | elif tag[0] == "I":
109 | if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O':
110 | tags[c_idx] = 'O'
111 | pass
112 | else:
113 | entity_name = entity_name + c
114 | if tag_next[2:] != entity_tag:
115 | item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag})
116 | entity_name = ''
117 | iCount += 1
118 | return item
--------------------------------------------------------------------------------
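A quick illustration of bio_to_json() above, which turns a character sequence and its BIO tags into the entity dictionary printed in demo mode. The sentence below is illustrative.

```python
# minimal sketch: BIO tags -> entity spans
from utils import bio_to_json

chars = ['张', '三', '在', '北', '京']
tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
print(bio_to_json(chars, tags)['entities'])
# [{'word': '张三', 'start': 0, 'end': 2, 'type': 'PER'},
#  {'word': '北京', 'start': 3, 'end': 5, 'type': 'LOC'}]
```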