├── .gitignore
├── BERT fine-tune实践.md
├── README.md
├── conlleval
├── conlleval.py
├── data
│   ├── example.dev
│   ├── example.test
│   └── example.train
├── data_utils.py
├── loader.py
├── model.py
├── pictures
│   └── results.png
├── predict.py
├── rnncell.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | chinese_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
3 | bert/__init__.py
4 | .DS_Store
5 | bert/CONTRIBUTING.md
6 | bert/create_pretraining_data.py
7 | bert/extract_features.py
8 | bert/LICENSE
9 | bert/modeling.py
10 | bert/modeling_test.py
11 | bert/multilingual.md
12 | bert/optimization.py
13 | bert/optimization_test.py
14 | bert/README.md
15 | bert/requirements.txt
16 | bert/run_classifier.py
17 | bert/run_pretraining.py
18 | bert/run_squad.py
19 | bert/sample_text.txt
20 | bert/tokenization.py
21 | bert/tokenization_test.py
22 | chinese_L-12_H-768_A-12/bert_config.json
23 | chinese_L-12_H-768_A-12/bert_model.ckpt.index
24 | chinese_L-12_H-768_A-12/bert_model.ckpt.meta
25 | chinese_L-12_H-768_A-12/vocab.txt
26 | data/.DS_Store
27 | .idea/BertNER.iml
28 | .idea/deployment.xml
29 | .idea/encodings.xml
30 | .idea/misc.xml
31 | .idea/modules.xml
32 | .idea/vcs.xml
33 | .idea/workspace.xml
34 | __pycache__/conlleval.cpython-36.pyc
35 | __pycache__/data_utils.cpython-36.pyc
36 | __pycache__/loader.cpython-36.pyc
37 | __pycache__/model.cpython-36.pyc
38 | __pycache__/rnncell.cpython-36.pyc
39 | __pycache__/utils.cpython-36.pyc
40 | bert/__pycache__/__init__.cpython-36.pyc
41 | bert/__pycache__/modeling.cpython-36.pyc
42 | bert/__pycache__/tokenization.cpython-36.pyc
43 | config_file
44 | maps.pkl
45 | train.log
46 |
--------------------------------------------------------------------------------
/BERT fine-tune实践.md:
--------------------------------------------------------------------------------
1 | # BERT Fine-tune in Practice
2 |
3 | 1. #### Download the model & code
4 |
5 | The BERT code and pretrained models can be downloaded from: https://github.com/google-research/bert
6 |
7 | 2. #### Save path
8 |
9 | 3. #### Load the model parameters
10 |
11 | 4. #### Tokenize the input
12 |
13 | ```python
14 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line):
15 |     """
16 |     Convert a single example: characters and labels are mapped to ids and
17 |     structured into the inputs the model expects (InputFeatures-style).
18 |     :param char_line: space-separated characters of one example
19 |     :param tag_to_id: mapping from tag name to tag id
20 |     :param max_seq_length: maximum sequence length, including [CLS] and [SEP]
21 |     :param tokenizer: BERT FullTokenizer
22 |     :param label_line: space-separated labels aligned with char_line
23 |     :return: input_ids, input_mask, segment_ids, label_ids
24 |     """
25 | text_list = char_line.split(' ')
26 | label_list = label_line.split(' ')
27 |
28 | tokens = []
29 | labels = []
30 | for i, word in enumerate(text_list):
31 |         # Tokenize the word: for Chinese this is a per-character split, but characters missing from BERT's vocab.txt (e.g. Chinese quotation marks) are WordPiece-processed; replacing the tokenization with list(input) avoids that
32 | token = tokenizer.tokenize(word)
33 | tokens.extend(token)
34 | label_1 = label_list[i]
35 | for m in range(len(token)):
36 | if m == 0:
37 | labels.append(label_1)
38 |             else:  # with per-character input this branch is rarely taken
39 | labels.append("X")
40 |     # truncate the sequence to fit max_seq_length
41 | if len(tokens) >= max_seq_length - 1:
42 |         tokens = tokens[0:(max_seq_length - 2)]  # -2 because [CLS] and [SEP] will be added
43 | labels = labels[0:(max_seq_length - 2)]
44 | ntokens = []
45 | segment_ids = []
46 | label_ids = []
47 |     ntokens.append("[CLS]")  # mark the sentence start with [CLS]
48 | segment_ids.append(0)
49 |     # append "O" or append "[CLS]"? either works
50 |     label_ids.append(tag_to_id["[CLS]"])  # O or [CLS] makes no difference here; O keeps the tag set smaller, but giving the sentence start and end their own [CLS]/[SEP] labels is also fine
51 | for i, token in enumerate(tokens):
52 | ntokens.append(token)
53 | segment_ids.append(0)
54 | label_ids.append(tag_to_id[labels[i]])
55 |     ntokens.append("[SEP]")  # mark the sentence end with [SEP]
56 | segment_ids.append(0)
57 |     # append "O" or append "[SEP]"? either works
58 | label_ids.append(tag_to_id["[SEP]"])
59 |     input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) into vocabulary ids
60 | input_mask = [1] * len(input_ids)
61 |
62 | # padding
63 | while len(input_ids) < max_seq_length:
64 | input_ids.append(0)
65 | input_mask.append(0)
66 | segment_ids.append(0)
67 |         # labels at padded positions are ignored
68 | label_ids.append(0)
69 | ntokens.append("**NULL**")
70 |
71 | return input_ids, input_mask, segment_ids, label_ids
72 | ```
73 |
74 | 5. #### Obtain the BERT embeddings
75 |
76 | The previous step turns the text into the three inputs used below: input_ids, input_mask, and segment_ids.
77 |
78 | ```python
79 | # load bert embedding
80 | bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")  # path to the BERT config file
81 | model = modeling.BertModel(
82 | config=bert_config,
83 | is_training=True,
84 | input_ids=self.input_ids,
85 | input_mask=self.input_mask,
86 | token_type_ids=self.segment_ids,
87 | use_one_hot_embeddings=False)
88 | embedding = model.get_sequence_output()
89 | ```
90 |
91 | 6. #### Freeze the BERT parameter layers
92 |
93 | ```python
94 | # checkpoint used to initialize the BERT parameters
95 | init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
96 | # collect all trainable variables in the graph
97 | tvars = tf.trainable_variables()
98 | # map the graph variables onto the BERT checkpoint
99 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint)
100 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
101 | print("**** Trainable Variables ****")
102 | # print which variables were initialized from the checkpoint
103 | train_vars = []
104 | for var in tvars:
105 | init_string = ""
106 | if var.name in initialized_variable_names:
107 | init_string = ", *INIT_FROM_CKPT*"
108 | else:
109 | train_vars.append(var)
110 |     print("  name = %s, shape = %s%s" % (var.name, var.shape,
111 |                                          init_string))
112 | grads = tf.gradients(self.loss, train_vars)
113 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
114 |
115 | self.train_op = self.opt.apply_gradients(
116 | zip(grads, train_vars), global_step=self.global_step)
117 | ```
118 |
119 | 7. #### Train the task layer
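
Once the graph is built, training the task layer just means repeatedly feeding batches and running the train op, exactly as `Model.run_step` in `model.py` does (the full loop with evaluation lives in `train.py`). A minimal sketch, assuming a `BatchManager` named `train_manager` and an open session `sess`:

```python
# Sketch of one training epoch for the task layer; the variable names are illustrative.
for batch in train_manager.iter_batch(shuffle=True):
    feed_dict = model.create_feed_dict(True, batch)  # batch = [strings, segment_ids, chars, mask, targets]
    step, loss, _ = sess.run([model.global_step, model.loss, model.train_op], feed_dict)
```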
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *Bert-ChineseNER*
2 |
3 | ### *Introduction*
4 |
5 | This project fine-tunes Google's open-source pretrained BERT model on a Chinese NER task.
6 |
7 | ### *Datasets & Model*
8 |
9 | The labeled training data comes mainly from zjy-ucas's [ChineseNER](https://github.com/zjy-ucas/ChineseNER) project. This project adds BERT as an embedding (feature-extraction) layer in front of the original BiLSTM+CRF architecture; the pretrained Chinese BERT model and code come from Google Research's [bert](https://github.com/google-research/bert).
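
Concretely, the forward pass in `model.py` stacks these pieces together; the lines below are a simplified excerpt of `Model.__init__`, shown here only as an outline:

```python
embedding = self.bert_embedding()                        # BERT sequence output replaces the embedding lookup
lstm_inputs = tf.nn.dropout(embedding, self.dropout)
lstm_outputs = self.biLSTM_layer(lstm_inputs, self.lstm_dim, self.lengths)
self.logits = self.project_layer(lstm_outputs)           # per-character tag scores
self.loss = self.loss_layer(self.logits, self.lengths)   # CRF negative log-likelihood
```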
10 |
11 | ### *Results*
12 |
13 | ![results](pictures/results.png)
14 |
15 | With BERT, the F1 score on the dev set reaches **94.87** after only 16 training epochs, and **93.68** on the test set, an improvement of more than two percentage points on this dataset.
16 |
17 | ### *Train*
18 |
19 | 1. Download the [BERT model code](https://github.com/google-research/bert) and place it in the root directory of this project
20 | 2. Download the [pretrained Chinese BERT model](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) and unzip it into the root directory of this project
21 | 3. Set up the environment: Python 3 + TensorFlow 1.12
22 | 4. Run `python3 train.py` to train the model
23 | 5. Run `python3 predict.py` to test single sentences
24 |
25 | After these steps, the project directory should look like this:
26 |
27 | ```
28 | ├── BERT fine-tune实践.md
29 | ├── README.md
30 | ├── bert
31 | ├── chinese_L-12_H-768_A-12
32 | ├── conlleval
33 | ├── conlleval.py
34 | ├── data
35 | ├── data_utils.py
36 | ├── loader.py
37 | ├── model.py
38 | ├── pictures
39 | ├── predict.py
40 | ├── rnncell.py
41 | ├── train.py
42 | └── utils.py
43 | ```
44 |
45 | ### *Conclusion*
46 |
47 | As the results show, using BERT improves the model's accuracy by more than two percentage points. Moreover, later testing showed that the BERT-based NER model generalizes much better: entities never seen in the training set, such as new company names, are still recognized well. In contrast, a model trained with the plain BiLSTM+CRF framework on only the training set provided in [ChineseNER](https://github.com/zjy-ucas/ChineseNER) can hardly handle such OOV cases.
48 |
49 | ### *Fine-tune*
50 | The current code performs feature-based transfer. Switching to full fine-tuning can gain roughly one more point: modify the code so that the BERT parameters inside the model are trained together with the task layer, and lower the learning rate to the 1e-5 range (see the sketch below).
51 | Also, whether or not a BiLSTM is added has little effect on the results; you can decode directly from the BERT output, but keeping a CRF layer is still recommended to enforce the transition rules between tags.
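
A minimal sketch of that change (illustrative only, mirroring the optimizer block in `model.py`, which currently excludes the checkpoint-initialized BERT variables from `train_vars`):

```python
# Fine-tuning sketch: train the BERT variables together with the task layer.
self.lr = 1e-5                                    # use a much smaller learning rate for fine-tuning
self.opt = tf.train.AdamOptimizer(self.lr)
train_vars = tf.trainable_variables()             # keep the BERT variables instead of filtering them out
grads = tf.gradients(self.loss, train_vars)
grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
self.train_op = self.opt.apply_gradients(zip(grads, train_vars),
                                          global_step=self.global_step)
```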
52 |
53 | ### *Reference*
54 |
55 | (1) https://github.com/zjy-ucas/ChineseNER
56 |
57 | (2) https://github.com/google-research/bert
58 |
59 | (3) [Neural Architectures for Named Entity Recognition](https://arxiv.org/abs/1603.01360)
60 |
--------------------------------------------------------------------------------
/conlleval:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter
33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem; # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0; # number of chunks in corpus
36 | my $foundGuessed = 0; # number of identified chunks
37 | my $guessed; # current guessed chunk tag
38 | my $guessedType; # type of current guessed chunk tag
39 | my $i; # miscellaneous counter
40 | my $inCorrect = $false; # currently processed chunk is correct until now
41 | my $lastCorrect = "O"; # previous chunk tag in corpus
42 | my $latex = 0; # generate LaTeX formatted output
43 | my $lastCorrectType = ""; # type of previously identified chunk tag
44 | my $lastGuessed = "O"; # previously identified chunk tag
45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus
46 | my $lastType; # temporary storage for detecting duplicates
47 | my $line; # line
48 | my $nbrOfFeatures = -1; # number of features per line
49 | my $precision = 0.0; # precision score
50 | my $oTag = "O"; # outside tag, default O
51 | my $raw = 0; # raw input: add B to every token
52 | my $recall = 0.0; # recall score
53 | my $tokenCounter = 0; # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = (); # number of correctly identified chunks per type
56 | my %foundCorrect = (); # number of chunks in corpus per type
57 | my %foundGuessed = (); # number of identified chunks per type
58 |
59 | my @features; # features on line
60 | my @sortedTypes; # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 | elsif ($ARGV[0] eq "-d") {
67 | shift(@ARGV);
68 | if (not defined $ARGV[0]) {
69 | die "conlleval: -d requires delimiter character";
70 | }
71 | $delimiter = shift(@ARGV);
72 | } elsif ($ARGV[0] eq "-o") {
73 | shift(@ARGV);
74 | if (not defined $ARGV[0]) {
75 | die "conlleval: -o requires delimiter character";
76 | }
77 | $oTag = shift(@ARGV);
78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 | chomp($line = $_);
84 | @features = split(/$delimiter/,$line);
85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86 | elsif ($nbrOfFeatures != $#features and @features != 0) {
87 | printf STDERR "unexpected number of features: %d (%d)\n",
88 | $#features+1,$nbrOfFeatures+1;
89 | exit(1);
90 | }
91 | if (@features == 0 or
92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93 | if (@features < 2) {
94 | die "conlleval: unexpected number of features in line $line\n";
95 | }
96 | if ($raw) {
97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99 | if ($features[$#features] ne "O") {
100 | $features[$#features] = "B-$features[$#features]";
101 | }
102 | if ($features[$#features-1] ne "O") {
103 | $features[$#features-1] = "B-$features[$#features-1]";
104 | }
105 | }
106 | # 20040126 ET code which allows hyphens in the types
107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108 | $guessed = $1;
109 | $guessedType = $2;
110 | } else {
111 | $guessed = $features[$#features];
112 | $guessedType = "";
113 | }
114 | pop(@features);
115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116 | $correct = $1;
117 | $correctType = $2;
118 | } else {
119 | $correct = $features[$#features];
120 | $correctType = "";
121 | }
122 | pop(@features);
123 | # ($guessed,$guessedType) = split(/-/,pop(@features));
124 | # ($correct,$correctType) = split(/-/,pop(@features));
125 | $guessedType = $guessedType ? $guessedType : "";
126 | $correctType = $correctType ? $correctType : "";
127 | $firstItem = shift(@features);
128 |
129 | # 1999-06-26 sentence breaks should always be counted as out of chunk
130 | if ( $firstItem eq $boundary ) { $guessed = "O"; }
131 |
132 | if ($inCorrect) {
133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135 | $lastGuessedType eq $lastCorrectType) {
136 | $inCorrect=$false;
137 | $correctChunk++;
138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139 | $correctChunk{$lastCorrectType}+1 : 1;
140 | } elsif (
141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143 | $guessedType ne $correctType ) {
144 | $inCorrect=$false;
145 | }
146 | }
147 |
148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150 | $guessedType eq $correctType) { $inCorrect = $true; }
151 |
152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153 | $foundCorrect++;
154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155 | $foundCorrect{$correctType}+1 : 1;
156 | }
157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158 | $foundGuessed++;
159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160 | $foundGuessed{$guessedType}+1 : 1;
161 | }
162 | if ( $firstItem ne $boundary ) {
163 | if ( $correct eq $guessed and $guessedType eq $correctType ) {
164 | $correctTags++;
165 | }
166 | $tokenCounter++;
167 | }
168 |
169 | $lastGuessed = $guessed;
170 | $lastCorrect = $correct;
171 | $lastGuessedType = $guessedType;
172 | $lastCorrectType = $correctType;
173 | }
174 | if ($inCorrect) {
175 | $correctChunk++;
176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177 | $correctChunk{$lastCorrectType}+1 : 1;
178 | }
179 |
180 | if (not $latex) {
181 | # compute overall precision, recall and FB1 (default values are 0.0)
182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184 | $FB1 = 2*$precision*$recall/($precision+$recall)
185 | if ($precision+$recall > 0);
186 |
187 | # print overall performance
188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190 | if ($tokenCounter>0) {
191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192 | printf "precision: %6.2f%%; ",$precision;
193 | printf "recall: %6.2f%%; ",$recall;
194 | printf "FB1: %6.2f\n",$FB1;
195 | }
196 | }
197 |
198 | # sort chunk type names
199 | undef($lastType);
200 | @sortedTypes = ();
201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202 | if (not($lastType) or $lastType ne $i) {
203 | push(@sortedTypes,($i));
204 | }
205 | $lastType = $i;
206 | }
207 | # print performance per chunk type
208 | if (not $latex) {
209 | for $i (@sortedTypes) {
210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213 | if (not($foundCorrect{$i})) { $recall = 0.0; }
214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217 | printf "%17s: ",$i;
218 | printf "precision: %6.2f%%; ",$precision;
219 | printf "recall: %6.2f%%; ",$recall;
220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221 | }
222 | } else {
223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224 | for $i (@sortedTypes) {
225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226 | if (not($foundGuessed{$i})) { $precision = 0.0; }
227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228 | if (not($foundCorrect{$i})) { $recall = 0.0; }
229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233 | $i,$precision,$recall,$FB1;
234 | }
235 | print "\\hline\n";
236 | $precision = 0.0;
237 | $recall = 0;
238 | $FB1 = 0.0;
239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241 | $FB1 = 2*$precision*$recall/($precision+$recall)
242 | if ($precision+$recall > 0);
243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244 | $precision,$recall,$FB1;
245 | }
246 |
247 | exit 0;
248 |
249 | # endOfChunk: checks if a chunk ended between the previous and current word
250 | # arguments: previous and current chunk tags, previous and current types
251 | # note: this code is capable of handling other chunk representations
252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254 |
255 | sub endOfChunk {
256 | my $prevTag = shift(@_);
257 | my $tag = shift(@_);
258 | my $prevType = shift(@_);
259 | my $type = shift(@_);
260 | my $chunkEnd = $false;
261 |
262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
266 |
267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
271 |
272 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
273 | $chunkEnd = $true;
274 | }
275 |
276 | # corrected 1998-12-22: these chunks are assumed to have length 1
277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; }
278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; }
279 |
280 | return($chunkEnd);
281 | }
282 |
283 | # startOfChunk: checks if a chunk started between the previous and current word
284 | # arguments: previous and current chunk tags, previous and current types
285 | # note: this code is capable of handling other chunk representations
286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
288 |
289 | sub startOfChunk {
290 | my $prevTag = shift(@_);
291 | my $tag = shift(@_);
292 | my $prevType = shift(@_);
293 | my $type = shift(@_);
294 | my $chunkStart = $false;
295 |
296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
300 |
301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
305 |
306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
307 | $chunkStart = $true;
308 | }
309 |
310 | # corrected 1998-12-22: these chunks are assumed to have length 1
311 | if ( $tag eq "[" ) { $chunkStart = $true; }
312 | if ( $tag eq "]" ) { $chunkStart = $true; }
313 |
314 | return($chunkStart);
315 | }
316 |
--------------------------------------------------------------------------------
/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | import sys
13 | import re
14 | import codecs
15 | from collections import defaultdict, namedtuple
16 |
17 | ANY_SPACE = ''
18 |
19 |
20 | class FormatError(Exception):
21 | pass
22 |
23 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
24 |
25 |
26 | class EvalCounts(object):
27 | def __init__(self):
28 | self.correct_chunk = 0 # number of correctly identified chunks
29 | self.correct_tags = 0 # number of correct chunk tags
30 | self.found_correct = 0 # number of chunks in corpus
31 | self.found_guessed = 0 # number of identified chunks
32 | self.token_counter = 0 # token counter (ignores sentence breaks)
33 |
34 | # counts by type
35 | self.t_correct_chunk = defaultdict(int)
36 | self.t_found_correct = defaultdict(int)
37 | self.t_found_guessed = defaultdict(int)
38 |
39 |
40 | def parse_args(argv):
41 | import argparse
42 | parser = argparse.ArgumentParser(
43 | description='evaluate tagging results using CoNLL criteria',
44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
45 | )
46 | arg = parser.add_argument
47 | arg('-b', '--boundary', metavar='STR', default='-X-',
48 | help='sentence boundary')
49 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
50 | help='character delimiting items in input')
51 | arg('-o', '--otag', metavar='CHAR', default='O',
52 | help='alternative outside tag')
53 | arg('file', nargs='?', default=None)
54 | return parser.parse_args(argv)
55 |
56 |
57 | def parse_tag(t):
58 | m = re.match(r'^([^-]*)-(.*)$', t)
59 | return m.groups() if m else (t, '')
60 |
61 |
62 | def evaluate(iterable, options=None):
63 | if options is None:
64 | options = parse_args([]) # use defaults
65 |
66 | counts = EvalCounts()
67 | num_features = None # number of features per line
68 |     in_correct = False        # currently processed chunk is correct until now
69 | last_correct = 'O' # previous chunk tag in corpus
70 | last_correct_type = '' # type of previously identified chunk tag
71 | last_guessed = 'O' # previously identified chunk tag
72 | last_guessed_type = '' # type of previous chunk tag in corpus
73 |
74 | for line in iterable:
75 | line = line.rstrip('\r\n')
76 |
77 | if options.delimiter == ANY_SPACE:
78 | features = line.split()
79 | else:
80 | features = line.split(options.delimiter)
81 |
82 | if num_features is None:
83 | num_features = len(features)
84 | elif num_features != len(features) and len(features) != 0:
85 | raise FormatError('unexpected number of features: %d (%d)' %
86 | (len(features), num_features))
87 |
88 | if len(features) == 0 or features[0] == options.boundary:
89 | features = [options.boundary, 'O', 'O']
90 | if len(features) < 3:
91 | raise FormatError('unexpected number of features in line %s' % line)
92 |
93 | guessed, guessed_type = parse_tag(features.pop())
94 | correct, correct_type = parse_tag(features.pop())
95 | first_item = features.pop(0)
96 |
97 | if first_item == options.boundary:
98 | guessed = 'O'
99 |
100 | end_correct = end_of_chunk(last_correct, correct,
101 | last_correct_type, correct_type)
102 | end_guessed = end_of_chunk(last_guessed, guessed,
103 | last_guessed_type, guessed_type)
104 | start_correct = start_of_chunk(last_correct, correct,
105 | last_correct_type, correct_type)
106 | start_guessed = start_of_chunk(last_guessed, guessed,
107 | last_guessed_type, guessed_type)
108 |
109 | if in_correct:
110 | if (end_correct and end_guessed and
111 | last_guessed_type == last_correct_type):
112 | in_correct = False
113 | counts.correct_chunk += 1
114 | counts.t_correct_chunk[last_correct_type] += 1
115 | elif (end_correct != end_guessed or guessed_type != correct_type):
116 | in_correct = False
117 |
118 | if start_correct and start_guessed and guessed_type == correct_type:
119 | in_correct = True
120 |
121 | if start_correct:
122 | counts.found_correct += 1
123 | counts.t_found_correct[correct_type] += 1
124 | if start_guessed:
125 | counts.found_guessed += 1
126 | counts.t_found_guessed[guessed_type] += 1
127 | if first_item != options.boundary:
128 | if correct == guessed and guessed_type == correct_type:
129 | counts.correct_tags += 1
130 | counts.token_counter += 1
131 |
132 | last_guessed = guessed
133 | last_correct = correct
134 | last_guessed_type = guessed_type
135 | last_correct_type = correct_type
136 |
137 | if in_correct:
138 | counts.correct_chunk += 1
139 | counts.t_correct_chunk[last_correct_type] += 1
140 |
141 | return counts
142 |
143 |
144 | def uniq(iterable):
145 | seen = set()
146 | return [i for i in iterable if not (i in seen or seen.add(i))]
147 |
148 |
149 | def calculate_metrics(correct, guessed, total):
150 | tp, fp, fn = correct, guessed-correct, total-correct
151 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
152 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
153 | f = 0 if p + r == 0 else 2 * p * r / (p + r)
154 | return Metrics(tp, fp, fn, p, r, f)
155 |
156 |
157 | def metrics(counts):
158 | c = counts
159 | overall = calculate_metrics(
160 | c.correct_chunk, c.found_guessed, c.found_correct
161 | )
162 | by_type = {}
163 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
164 | by_type[t] = calculate_metrics(
165 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
166 | )
167 | return overall, by_type
168 |
169 |
170 | def report(counts, out=None):
171 | if out is None:
172 | out = sys.stdout
173 |
174 | overall, by_type = metrics(counts)
175 |
176 | c = counts
177 | out.write('processed %d tokens with %d phrases; ' %
178 | (c.token_counter, c.found_correct))
179 | out.write('found: %d phrases; correct: %d.\n' %
180 | (c.found_guessed, c.correct_chunk))
181 |
182 | if c.token_counter > 0:
183 | out.write('accuracy: %6.2f%%; ' %
184 | (100.*c.correct_tags/c.token_counter))
185 | out.write('precision: %6.2f%%; ' % (100.*overall.prec))
186 | out.write('recall: %6.2f%%; ' % (100.*overall.rec))
187 | out.write('FB1: %6.2f\n' % (100.*overall.fscore))
188 |
189 | for i, m in sorted(by_type.items()):
190 | out.write('%17s: ' % i)
191 | out.write('precision: %6.2f%%; ' % (100.*m.prec))
192 | out.write('recall: %6.2f%%; ' % (100.*m.rec))
193 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
194 |
195 |
196 | def report_notprint(counts, out=None):
197 | if out is None:
198 | out = sys.stdout
199 |
200 | overall, by_type = metrics(counts)
201 |
202 | c = counts
203 | final_report = []
204 | line = []
205 | line.append('processed %d tokens with %d phrases; ' %
206 | (c.token_counter, c.found_correct))
207 | line.append('found: %d phrases; correct: %d.\n' %
208 | (c.found_guessed, c.correct_chunk))
209 | final_report.append("".join(line))
210 |
211 | if c.token_counter > 0:
212 | line = []
213 | line.append('accuracy: %6.2f%%; ' %
214 | (100.*c.correct_tags/c.token_counter))
215 | line.append('precision: %6.2f%%; ' % (100.*overall.prec))
216 | line.append('recall: %6.2f%%; ' % (100.*overall.rec))
217 | line.append('FB1: %6.2f\n' % (100.*overall.fscore))
218 | final_report.append("".join(line))
219 |
220 | for i, m in sorted(by_type.items()):
221 | line = []
222 | line.append('%17s: ' % i)
223 | line.append('precision: %6.2f%%; ' % (100.*m.prec))
224 | line.append('recall: %6.2f%%; ' % (100.*m.rec))
225 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
226 | final_report.append("".join(line))
227 | return final_report
228 |
229 |
230 | def end_of_chunk(prev_tag, tag, prev_type, type_):
231 | # check if a chunk ended between the previous and current word
232 | # arguments: previous and current chunk tags, previous and current types
233 | chunk_end = False
234 |
235 | if prev_tag == 'E': chunk_end = True
236 | if prev_tag == 'S': chunk_end = True
237 |
238 | if prev_tag == 'B' and tag == 'B': chunk_end = True
239 | if prev_tag == 'B' and tag == 'S': chunk_end = True
240 | if prev_tag == 'B' and tag == 'O': chunk_end = True
241 | if prev_tag == 'I' and tag == 'B': chunk_end = True
242 | if prev_tag == 'I' and tag == 'S': chunk_end = True
243 | if prev_tag == 'I' and tag == 'O': chunk_end = True
244 |
245 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
246 | chunk_end = True
247 |
248 | # these chunks are assumed to have length 1
249 | if prev_tag == ']': chunk_end = True
250 | if prev_tag == '[': chunk_end = True
251 |
252 | return chunk_end
253 |
254 |
255 | def start_of_chunk(prev_tag, tag, prev_type, type_):
256 | # check if a chunk started between the previous and current word
257 | # arguments: previous and current chunk tags, previous and current types
258 | chunk_start = False
259 |
260 | if tag == 'B': chunk_start = True
261 | if tag == 'S': chunk_start = True
262 |
263 | if prev_tag == 'E' and tag == 'E': chunk_start = True
264 | if prev_tag == 'E' and tag == 'I': chunk_start = True
265 | if prev_tag == 'S' and tag == 'E': chunk_start = True
266 | if prev_tag == 'S' and tag == 'I': chunk_start = True
267 | if prev_tag == 'O' and tag == 'E': chunk_start = True
268 | if prev_tag == 'O' and tag == 'I': chunk_start = True
269 |
270 | if tag != 'O' and tag != '.' and prev_type != type_:
271 | chunk_start = True
272 |
273 | # these chunks are assumed to have length 1
274 | if tag == '[': chunk_start = True
275 | if tag == ']': chunk_start = True
276 |
277 | return chunk_start
278 |
279 |
280 | def return_report(input_file):
281 | with codecs.open(input_file, "r", "utf8") as f:
282 | counts = evaluate(f)
283 | return report_notprint(counts)
284 |
285 |
286 | def main(argv):
287 | args = parse_args(argv[1:])
288 |
289 | if args.file is None:
290 | counts = evaluate(sys.stdin, args)
291 | else:
292 | with open(args.file) as f:
293 | counts = evaluate(f, args)
294 | report(counts)
295 |
296 | if __name__ == '__main__':
297 | sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | # encoding = utf8
2 | import re
3 | import math
4 | import codecs
5 | import random
6 |
7 | import numpy as np
8 | #import jieba
9 | #jieba.initialize()
10 |
11 |
12 | def create_dico(item_list):
13 | """
14 | Create a dictionary of items from a list of list of items.
15 | """
16 | assert type(item_list) is list
17 | dico = {}
18 | for items in item_list:
19 | for item in items:
20 | if item not in dico:
21 | dico[item] = 1
22 | else:
23 | dico[item] += 1
24 |
25 | return dico
26 |
27 |
28 | def create_mapping(dico):
29 | """
30 | Create a mapping (item to ID / ID to item) from a dictionary.
31 | Items are ordered by decreasing frequency.
32 | """
33 | sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
34 | id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
35 | item_to_id = {v: k for k, v in id_to_item.items()}
36 | return item_to_id, id_to_item
37 |
38 |
39 | def zero_digits(s):
40 | """
41 | Replace every digit in a string by a zero.
42 | """
43 | return re.sub('\d', '0', s)
44 |
45 |
46 | def iob2(tags):
47 | """
48 | Check that tags have a valid IOB format.
49 | Tags in IOB1 format are converted to IOB2.
50 | """
51 | for i, tag in enumerate(tags):
52 | if tag == 'O':
53 | continue
54 | split = tag.split('-')
55 | if len(split) != 2 or split[0] not in ['I', 'B']:
56 | return False
57 | if split[0] == 'B':
58 | continue
59 | elif i == 0 or tags[i - 1] == 'O': # conversion IOB1 to IOB2
60 | tags[i] = 'B' + tag[1:]
61 | elif tags[i - 1][1:] == tag[1:]:
62 | continue
63 | else: # conversion IOB1 to IOB2
64 | tags[i] = 'B' + tag[1:]
65 | return True
66 |
67 |
68 | def iob_iobes(tags):
69 | """
70 | IOB -> IOBES
71 | """
72 | new_tags = []
73 | for i, tag in enumerate(tags):
74 | if tag == 'O':
75 | new_tags.append(tag)
76 | elif tag.split('-')[0] == 'B':
77 | if i + 1 != len(tags) and \
78 | tags[i + 1].split('-')[0] == 'I':
79 | new_tags.append(tag)
80 | else:
81 | new_tags.append(tag.replace('B-', 'S-'))
82 | elif tag.split('-')[0] == 'I':
83 | if i + 1 < len(tags) and \
84 | tags[i + 1].split('-')[0] == 'I':
85 | new_tags.append(tag)
86 | else:
87 | new_tags.append(tag.replace('I-', 'E-'))
88 | else:
89 | raise Exception('Invalid IOB format!')
90 | return new_tags
91 |
92 |
93 | def iobes_iob(tags):
94 | """
95 | IOBES -> IOB
96 | """
97 | new_tags = []
98 | for i, tag in enumerate(tags):
99 | if tag.split('-')[0] == 'B':
100 | new_tags.append(tag)
101 | elif tag.split('-')[0] == 'I':
102 | new_tags.append(tag)
103 | elif tag.split('-')[0] == 'S':
104 | new_tags.append(tag.replace('S-', 'B-'))
105 | elif tag.split('-')[0] == 'E':
106 | new_tags.append(tag.replace('E-', 'I-'))
107 | elif tag.split('-')[0] == 'O':
108 | new_tags.append(tag)
109 | else:
110 | raise Exception('Invalid format!')
111 | return new_tags
112 |
113 |
114 | def insert_singletons(words, singletons, p=0.5):
115 | """
116 | Replace singletons by the unknown word with a probability p.
117 | """
118 | new_words = []
119 | for word in words:
120 | if word in singletons and np.random.uniform() < p:
121 | new_words.append(0)
122 | else:
123 | new_words.append(word)
124 | return new_words
125 |
126 |
127 | def get_seg_features(string):
128 | """
129 |     Segment text with jieba (requires the jieba import at the top of this file to be uncommented)
130 |     Features are represented in BIES format,
131 |     where "s" denotes a single-character word
132 | """
133 | seg_feature = []
134 |
135 | for word in jieba.cut(string):
136 | if len(word) == 1:
137 | seg_feature.append(0) #o
138 | else:
139 | tmp = [2] * len(word) #i
140 | tmp[0] = 1 #b
141 | tmp[-1] = 3 #e
142 | seg_feature.extend(tmp)
143 | return seg_feature
144 |
145 |
146 | def create_input(data):
147 | """
148 | Take sentence data and return an input for
149 | the training or the evaluation function.
150 | """
151 | inputs = list()
152 | inputs.append(data['chars'])
153 | inputs.append(data["segs"])
154 | inputs.append(data['tags'])
155 | return inputs
156 |
157 |
158 | def load_word2vec(emb_path, id_to_word, word_dim, old_weights):
159 | """
160 | Load word embedding from pre-trained file
161 | embedding size must match
162 | """
163 | new_weights = old_weights
164 | print('Loading pretrained embeddings from {}...'.format(emb_path))
165 | pre_trained = {}
166 | emb_invalid = 0
167 | for i, line in enumerate(codecs.open(emb_path, 'r', 'utf-8')):
168 | line = line.rstrip().split()
169 | if len(line) == word_dim + 1:
170 | pre_trained[line[0]] = np.array(
171 | [float(x) for x in line[1:]]
172 | ).astype(np.float32)
173 | else:
174 | emb_invalid += 1
175 | if emb_invalid > 0:
176 | print('WARNING: %i invalid lines' % emb_invalid)
177 | c_found = 0
178 | c_lower = 0
179 | c_zeros = 0
180 | n_words = len(id_to_word)
181 | # Lookup table initialization
182 | for i in range(n_words):
183 | word = id_to_word[i]
184 | if word in pre_trained:
185 | new_weights[i] = pre_trained[word]
186 | c_found += 1
187 | elif word.lower() in pre_trained:
188 | new_weights[i] = pre_trained[word.lower()]
189 | c_lower += 1
190 | elif re.sub('\d', '0', word.lower()) in pre_trained: #replace numbers to zero
191 | new_weights[i] = pre_trained[
192 | re.sub('\d', '0', word.lower())
193 | ]
194 | c_zeros += 1
195 | print('Loaded %i pretrained embeddings.' % len(pre_trained))
196 | print('%i / %i (%.4f%%) words have been initialized with '
197 | 'pretrained embeddings.' % (
198 | c_found + c_lower + c_zeros, n_words,
199 | 100. * (c_found + c_lower + c_zeros) / n_words)
200 | )
201 | print('%i found directly, %i after lowercasing, '
202 | '%i after lowercasing + zero.' % (
203 | c_found, c_lower, c_zeros
204 | ))
205 | return new_weights
206 |
207 |
208 | def full_to_half(s):
209 | """
210 | Convert full-width character to half-width one
211 | """
212 | n = []
213 | for char in s:
214 | num = ord(char)
215 | if num == 0x3000:
216 | num = 32
217 | elif 0xFF01 <= num <= 0xFF5E:
218 | num -= 0xfee0
219 | char = chr(num)
220 | n.append(char)
221 | return ''.join(n)
222 |
223 |
224 | def cut_to_sentence(text):
225 | """
226 | Cut text to sentences
227 | """
228 | sentence = []
229 | sentences = []
230 | len_p = len(text)
231 | pre_cut = False
232 | for idx, word in enumerate(text):
233 | sentence.append(word)
234 | cut = False
235 | if pre_cut:
236 | cut=True
237 | pre_cut=False
238 | if word in u"。;!?\n":
239 | cut = True
240 | if len_p > idx+1:
241 | if text[idx+1] in ".。”\"\'“”‘’?!":
242 | cut = False
243 | pre_cut=True
244 |
245 | if cut:
246 | sentences.append(sentence)
247 | sentence = []
248 | if sentence:
249 | sentences.append("".join(list(sentence)))
250 | return sentences
251 |
252 |
253 | def replace_html(s):
254 | s = s.replace('"','"')
255 | s = s.replace('&','&')
256 | s = s.replace('<','<')
257 | s = s.replace('>','>')
258 | s = s.replace(' ',' ')
259 | s = s.replace("“", "“")
260 | s = s.replace("”", "”")
261 | s = s.replace("—","")
262 | s = s.replace("\xa0", " ")
263 | return(s)
264 |
265 | class BatchManager(object):
266 |
267 | def __init__(self, data, batch_size):
268 | self.batch_data = self.sort_and_pad(data, batch_size)
269 | self.len_data = len(self.batch_data)
270 |
271 | def sort_and_pad(self, data, batch_size):
272 | num_batch = int(math.ceil(len(data) /batch_size))
273 | sorted_data = sorted(data, key=lambda x: len(x[0]))
274 | batch_data = list()
275 | for i in range(num_batch):
276 | batch_data.append(self.arrange_batch(sorted_data[int(i*batch_size) : int((i+1)*batch_size)]))
277 | return batch_data
278 |
279 | @staticmethod
280 | def arrange_batch(batch):
281 | '''
282 |         Arrange a batch into five parallel lists: [strings, segment_ids, chars, mask, targets]
283 | :param batch:
284 | :return:
285 | '''
286 | strings = []
287 | segment_ids = []
288 | chars = []
289 | mask = []
290 | targets = []
291 | for string, seg_ids, char, msk, target in batch:
292 | strings.append(string)
293 | segment_ids.append(seg_ids)
294 | chars.append(char)
295 | mask.append(msk)
296 | targets.append(target)
297 | return [strings, segment_ids, chars, mask, targets]
298 |
299 | @staticmethod
300 | def pad_data(data):
301 | strings = []
302 | chars = []
303 | segs = []
304 | targets = []
305 | max_length = max([len(sentence[0]) for sentence in data])
306 | for line in data:
307 | string, segment_ids, char, seg, target = line
308 | padding = [0] * (max_length - len(string))
309 | strings.append(string + padding)
310 | chars.append(char + padding)
311 | segs.append(seg + padding)
312 | targets.append(target + padding)
313 | return [strings, chars, segs, targets]
314 |
315 | def iter_batch(self, shuffle=False):
316 | if shuffle:
317 | random.shuffle(self.batch_data)
318 | for idx in range(self.len_data):
319 | yield self.batch_data[idx]
320 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import codecs
4 |
5 | from bert import tokenization
6 | from utils import convert_single_example
7 |
8 | tokenizer = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt',
9 | do_lower_case=True)
10 |
11 | from data_utils import create_dico, create_mapping, zero_digits
12 | from data_utils import iob2, iob_iobes, get_seg_features
13 |
14 |
15 | def load_sentences(path, lower, zeros):
16 | """
17 | Load sentences. A line must contain at least a word and its tag.
18 | Sentences are separated by empty lines.
19 | """
20 | sentences = []
21 | sentence = []
22 | num = 0
23 | for line in codecs.open(path, 'r', 'utf8'):
24 | num+=1
25 | line = zero_digits(line.rstrip()) if zeros else line.rstrip()
26 | # print(list(line))
27 | if not line:
28 | if len(sentence) > 0:
29 | if 'DOCSTART' not in sentence[0][0]:
30 | sentences.append(sentence)
31 | sentence = []
32 | else:
33 | if line[0] == " ":
34 | line = "$" + line[1:]
35 | word = line.split()
36 | # word[0] = " "
37 | else:
38 | word= line.split()
39 | assert len(word) >= 2, print([word[0]])
40 | sentence.append(word)
41 | if len(sentence) > 0:
42 | if 'DOCSTART' not in sentence[0][0]:
43 | sentences.append(sentence)
44 | return sentences
45 |
46 |
47 | def update_tag_scheme(sentences, tag_scheme):
48 | """
49 | Check and update sentences tagging scheme to IOB2.
50 | Only IOB1 and IOB2 schemes are accepted.
51 | """
52 | for i, s in enumerate(sentences):
53 | tags = [w[-1] for w in s]
54 | # Check that tags are given in the IOB format
55 | if not iob2(tags):
56 | s_str = '\n'.join(' '.join(w) for w in s)
57 | raise Exception('Sentences should be given in IOB format! ' +
58 | 'Please check sentence %i:\n%s' % (i, s_str))
59 | if tag_scheme == 'iob':
60 | # If format was IOB1, we convert to IOB2
61 | for word, new_tag in zip(s, tags):
62 | word[-1] = new_tag
63 | elif tag_scheme == 'iobes':
64 | new_tags = iob_iobes(tags)
65 | for word, new_tag in zip(s, new_tags):
66 | word[-1] = new_tag
67 | else:
68 | raise Exception('Unknown tagging scheme!')
69 |
70 |
71 | def char_mapping(sentences, lower):
72 | """
73 | Create a dictionary and a mapping of words, sorted by frequency.
74 | """
75 | chars = [[x[0].lower() if lower else x[0] for x in s] for s in sentences]
76 | dico = create_dico(chars)
77 | dico[""] = 10000001
78 | dico[''] = 10000000
79 |
80 | char_to_id, id_to_char = create_mapping(dico)
81 | print("Found %i unique words (%i in total)" % (
82 | len(dico), sum(len(x) for x in chars)
83 | ))
84 | return dico, char_to_id, id_to_char
85 |
86 |
87 | def tag_mapping(sentences):
88 | """
89 | Create a dictionary and a mapping of tags, sorted by frequency.
90 | """
91 | tags = [[char[-1] for char in s] for s in sentences]
92 |
93 | dico = create_dico(tags)
94 | dico['[SEP]'] = len(dico) + 1
95 | dico['[CLS]'] = len(dico) + 2
96 |
97 | tag_to_id, id_to_tag = create_mapping(dico)
98 | print("Found %i unique named entity tags" % len(dico))
99 | return dico, tag_to_id, id_to_tag
100 |
101 |
102 | def prepare_dataset(sentences, max_seq_length, tag_to_id, lower=False, train=True):
103 | """
104 | Prepare the dataset. Return a list of lists of dictionaries containing:
105 | - word indexes
106 | - word char indexes
107 | - tag indexes
108 | """
109 | def f(x):
110 | return x.lower() if lower else x
111 | data = []
112 | for s in sentences:
113 | string = [w[0].strip() for w in s]
114 | #chars = [char_to_id[f(w) if f(w) in char_to_id else '']
115 | # for w in string]
116 |         char_line = ' '.join(string)  # join the characters with spaces
117 | text = tokenization.convert_to_unicode(char_line)
118 |
119 | if train:
120 | tags = [w[-1] for w in s]
121 | else:
122 | tags = ['O' for _ in string]
123 |
124 |         labels = ' '.join(tags)  # join the labels with spaces
125 | labels = tokenization.convert_to_unicode(labels)
126 |
127 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text,
128 | tag_to_id=tag_to_id,
129 | max_seq_length=max_seq_length,
130 | tokenizer=tokenizer,
131 | label_line=labels)
132 | data.append([string, segment_ids, ids, mask, label_ids])
133 |
134 | return data
135 |
136 |
137 | def input_from_line(line, max_seq_length, tag_to_id):
138 | """
139 | Take sentence data and return an input for
140 | the training or the evaluation function.
141 | """
142 | string = [w[0].strip() for w in line]
143 | # chars = [char_to_id[f(w) if f(w) in char_to_id else '']
144 | # for w in string]
145 |     char_line = ' '.join(string)  # join the characters with spaces
146 | text = tokenization.convert_to_unicode(char_line)
147 |
148 | tags = ['O' for _ in string]
149 |
150 |     labels = ' '.join(tags)  # join the labels with spaces
151 | labels = tokenization.convert_to_unicode(labels)
152 |
153 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text,
154 | tag_to_id=tag_to_id,
155 | max_seq_length=max_seq_length,
156 | tokenizer=tokenizer,
157 | label_line=labels)
158 | import numpy as np
159 | segment_ids = np.reshape(segment_ids,(1, max_seq_length))
160 | ids = np.reshape(ids, (1, max_seq_length))
161 | mask = np.reshape(mask, (1, max_seq_length))
162 | label_ids = np.reshape(label_ids, (1, max_seq_length))
163 | return [string, segment_ids, ids, mask, label_ids]
164 |
165 | def augment_with_pretrained(dictionary, ext_emb_path, chars):
166 | """
167 | Augment the dictionary with words that have a pretrained embedding.
168 | If `words` is None, we add every word that has a pretrained embedding
169 | to the dictionary, otherwise, we only add the words that are given by
170 | `words` (typically the words in the development and test sets.)
171 | """
172 | print('Loading pretrained embeddings from %s...' % ext_emb_path)
173 | assert os.path.isfile(ext_emb_path)
174 |
175 | # Load pretrained embeddings from file
176 | pretrained = set([
177 | line.rstrip().split()[0].strip()
178 | for line in codecs.open(ext_emb_path, 'r', 'utf-8')
179 | if len(ext_emb_path) > 0
180 | ])
181 |
182 | # We either add every word in the pretrained file,
183 | # or only words given in the `words` list to which
184 | # we can assign a pretrained embedding
185 | if chars is None:
186 | for char in pretrained:
187 | if char not in dictionary:
188 | dictionary[char] = 0
189 | else:
190 | for char in chars:
191 | if any(x in pretrained for x in [
192 | char,
193 | char.lower(),
194 | re.sub('\d', '0', char.lower())
195 | ]) and char not in dictionary:
196 | dictionary[char] = 0
197 |
198 | word_to_id, id_to_word = create_mapping(dictionary)
199 | return dictionary, word_to_id, id_to_word
200 |
201 |
202 | def save_maps(save_path, *params):
203 | """
204 | Save mappings and invert mappings
205 | """
206 | pass
207 | # with codecs.open(save_path, "w", encoding="utf8") as f:
208 | # pickle.dump(params, f)
209 |
210 |
211 | def load_maps(save_path):
212 | """
213 | Load mappings from the file
214 | """
215 | pass
216 | # with codecs.open(save_path, "r", encoding="utf8") as f:
217 | # pickle.load(save_path, f)
218 |
219 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # encoding = utf8
2 | import numpy as np
3 | import tensorflow as tf
4 | from tensorflow.contrib.crf import crf_log_likelihood
5 | from tensorflow.contrib.crf import viterbi_decode
6 | from tensorflow.contrib.layers.python.layers import initializers
7 |
8 | import rnncell as rnn
9 | from utils import bio_to_json
10 | from bert import modeling
11 |
12 | class Model(object):
13 | def __init__(self, config):
14 |
15 | self.config = config
16 | self.lr = config["lr"]
17 | self.lstm_dim = config["lstm_dim"]
18 | self.num_tags = config["num_tags"]
19 |
20 | self.global_step = tf.Variable(0, trainable=False)
21 | self.best_dev_f1 = tf.Variable(0.0, trainable=False)
22 | self.best_test_f1 = tf.Variable(0.0, trainable=False)
23 | self.initializer = initializers.xavier_initializer()
24 |
25 | # add placeholders for the model
26 | self.input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_ids")
27 | self.input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_mask")
28 | self.segment_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="segment_ids")
29 | self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Targets")
30 | # dropout keep prob
31 | self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout")
32 |
33 | used = tf.sign(tf.abs(self.input_ids))
34 | length = tf.reduce_sum(used, reduction_indices=1)
35 | self.lengths = tf.cast(length, tf.int32)
36 | self.batch_size = tf.shape(self.input_ids)[0]
37 | self.num_steps = tf.shape(self.input_ids)[-1]
38 |
39 | # embeddings for chinese character and segmentation representation
40 | embedding = self.bert_embedding()
41 |
42 | # apply dropout before feed to lstm layer
43 | lstm_inputs = tf.nn.dropout(embedding, self.dropout)
44 |
45 | # bi-directional lstm layer
46 | lstm_outputs = self.biLSTM_layer(lstm_inputs, self.lstm_dim, self.lengths)
47 |
48 | # logits for tags
49 | self.logits = self.project_layer(lstm_outputs)
50 |
51 | # loss of the model
52 | self.loss = self.loss_layer(self.logits, self.lengths)
53 |
54 |         # checkpoint used to initialize the BERT parameters
55 |         init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
56 |         # collect all trainable variables in the graph
57 |         tvars = tf.trainable_variables()
58 |         # map the graph variables onto the BERT checkpoint
59 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,
60 | init_checkpoint)
61 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
62 | print("**** Trainable Variables ****")
63 |         # print which variables were initialized from the checkpoint
64 | train_vars = []
65 | for var in tvars:
66 | init_string = ""
67 | if var.name in initialized_variable_names:
68 | init_string = ", *INIT_FROM_CKPT*"
69 | else:
70 | train_vars.append(var)
71 |             print("  name = %s, shape = %s%s" % (var.name, var.shape,
72 |                                                  init_string))
73 | with tf.variable_scope("optimizer"):
74 | optimizer = self.config["optimizer"]
75 | if optimizer == "adam":
76 | self.opt = tf.train.AdamOptimizer(self.lr)
77 | else:
78 | raise KeyError
79 |
80 | grads = tf.gradients(self.loss, train_vars)
81 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
82 |
83 | self.train_op = self.opt.apply_gradients(
84 | zip(grads, train_vars), global_step=self.global_step)
85 | #capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v]
86 | # for g, v in grads_vars if g is not None]
87 | #self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step, )
88 |
89 | # saver of the model
90 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
91 |
92 | def bert_embedding(self):
93 | # load bert embedding
94 |         bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")  # path to the BERT config file
95 | model = modeling.BertModel(
96 | config=bert_config,
97 | is_training=True,
98 | input_ids=self.input_ids,
99 | input_mask=self.input_mask,
100 | token_type_ids=self.segment_ids,
101 | use_one_hot_embeddings=False)
102 | embedding = model.get_sequence_output()
103 | return embedding
104 |
105 | def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths, name=None):
106 | """
107 | :param lstm_inputs: [batch_size, num_steps, emb_size]
108 | :return: [batch_size, num_steps, 2*lstm_dim]
109 | """
110 | with tf.variable_scope("char_BiLSTM" if not name else name):
111 | lstm_cell = {}
112 | for direction in ["forward", "backward"]:
113 | with tf.variable_scope(direction):
114 | lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell(
115 | lstm_dim,
116 | use_peepholes=True,
117 | initializer=self.initializer,
118 | state_is_tuple=True)
119 | outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
120 | lstm_cell["forward"],
121 | lstm_cell["backward"],
122 | lstm_inputs,
123 | dtype=tf.float32,
124 | sequence_length=lengths)
125 | return tf.concat(outputs, axis=2)
126 |
127 | def project_layer(self, lstm_outputs, name=None):
128 | """
129 | hidden layer between lstm layer and logits
130 | :param lstm_outputs: [batch_size, num_steps, emb_size]
131 | :return: [batch_size, num_steps, num_tags]
132 | """
133 | with tf.variable_scope("project" if not name else name):
134 | with tf.variable_scope("hidden"):
135 | W = tf.get_variable("W", shape=[self.lstm_dim*2, self.lstm_dim],
136 | dtype=tf.float32, initializer=self.initializer)
137 |
138 | b = tf.get_variable("b", shape=[self.lstm_dim], dtype=tf.float32,
139 | initializer=tf.zeros_initializer())
140 | output = tf.reshape(lstm_outputs, shape=[-1, self.lstm_dim*2])
141 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b))
142 |
143 | # project to score of tags
144 | with tf.variable_scope("logits"):
145 | W = tf.get_variable("W", shape=[self.lstm_dim, self.num_tags],
146 | dtype=tf.float32, initializer=self.initializer)
147 |
148 | b = tf.get_variable("b", shape=[self.num_tags], dtype=tf.float32,
149 | initializer=tf.zeros_initializer())
150 |
151 | pred = tf.nn.xw_plus_b(hidden, W, b)
152 |
153 | return tf.reshape(pred, [-1, self.num_steps, self.num_tags])
154 |
155 | def loss_layer(self, project_logits, lengths, name=None):
156 | """
157 | calculate crf loss
158 | :param project_logits: [1, num_steps, num_tags]
159 | :return: scalar loss
160 | """
161 | with tf.variable_scope("crf_loss" if not name else name):
162 | small = -1000.0
163 | # pad logits for crf loss
164 | start_logits = tf.concat(
165 | [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])], axis=-1)
166 | pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
167 | logits = tf.concat([project_logits, pad_logits], axis=-1)
168 | logits = tf.concat([start_logits, logits], axis=1)
169 | targets = tf.concat(
170 | [tf.cast(self.num_tags*tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1)
171 |
172 | self.trans = tf.get_variable(
173 | "transitions",
174 | shape=[self.num_tags + 1, self.num_tags + 1],
175 | initializer=self.initializer)
176 | log_likelihood, self.trans = crf_log_likelihood(
177 | inputs=logits,
178 | tag_indices=targets,
179 | transition_params=self.trans,
180 | sequence_lengths=lengths+1)
181 | return tf.reduce_mean(-log_likelihood)
182 |
183 | def create_feed_dict(self, is_train, batch):
184 | """
185 | :param is_train: Flag, True for train batch
186 | :param batch: list train/evaluate data
187 | :return: structured data to feed
188 | """
189 | _, segment_ids, chars, mask, tags = batch
190 | feed_dict = {
191 | self.input_ids: np.asarray(chars),
192 | self.input_mask: np.asarray(mask),
193 | self.segment_ids: np.asarray(segment_ids),
194 | self.dropout: 1.0,
195 | }
196 | if is_train:
197 | feed_dict[self.targets] = np.asarray(tags)
198 | feed_dict[self.dropout] = self.config["dropout_keep"]
199 | return feed_dict
200 |
201 | def run_step(self, sess, is_train, batch):
202 | """
203 | :param sess: session to run the batch
204 | :param is_train: a flag indicate if it is a train batch
205 | :param batch: a dict containing batch data
206 | :return: batch result, loss of the batch or logits
207 | """
208 | feed_dict = self.create_feed_dict(is_train, batch)
209 | if is_train:
210 | global_step, loss, _ = sess.run(
211 | [self.global_step, self.loss, self.train_op],
212 | feed_dict)
213 | return global_step, loss
214 | else:
215 | lengths, logits = sess.run([self.lengths, self.logits], feed_dict)
216 | return lengths, logits
217 |
218 | def decode(self, logits, lengths, matrix):
219 | """
220 | :param logits: [batch_size, num_steps, num_tags]float32, logits
221 | :param lengths: [batch_size]int32, real length of each sequence
222 |         :param matrix: transition matrix for inference
223 | :return:
224 | """
225 |         # infer the final label sequence with the Viterbi algorithm
226 | paths = []
227 | small = -1000.0
228 | start = np.asarray([[small]*self.num_tags +[0]])
229 | for score, length in zip(logits, lengths):
230 | score = score[:length]
231 | pad = small * np.ones([length, 1])
232 | logits = np.concatenate([score, pad], axis=1)
233 | logits = np.concatenate([start, logits], axis=0)
234 | path, _ = viterbi_decode(logits, matrix)
235 |
236 | paths.append(path[1:])
237 | return paths
238 |
239 | def evaluate(self, sess, data_manager, id_to_tag):
240 | """
241 | :param sess: session to run the model
242 | :param data: list of data
243 | :param id_to_tag: index to tag name
244 | :return: evaluate result
245 | """
246 | results = []
247 | trans = self.trans.eval()
248 | for batch in data_manager.iter_batch():
249 | strings = batch[0]
250 | labels = batch[-1]
251 | lengths, scores = self.run_step(sess, False, batch)
252 | batch_paths = self.decode(scores, lengths, trans)
253 | for i in range(len(strings)):
254 | result = []
255 | string = strings[i][:lengths[i]]
256 | gold = [id_to_tag[int(x)] for x in labels[i][1:lengths[i]]]
257 | pred = [id_to_tag[int(x)] for x in batch_paths[i][1:lengths[i]]]
258 | for char, gold, pred in zip(string, gold, pred):
259 | result.append(" ".join([char, gold, pred]))
260 | results.append(result)
261 | return results
262 |
263 | def evaluate_line(self, sess, inputs, id_to_tag):
264 | trans = self.trans.eval(sess)
265 | lengths, scores = self.run_step(sess, False, inputs)
266 | batch_paths = self.decode(scores, lengths, trans)
267 | tags = [id_to_tag[idx] for idx in batch_paths[0]]
268 | return bio_to_json(inputs[0], tags[1:-1])
--------------------------------------------------------------------------------
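The CRF code above pads the emission scores with one extra tag (a synthetic start state) so the loss and the Viterbi search run over `num_tags + 1` labels. A minimal sketch of the same augmentation at decode time, assuming TF 1.x with `tf.contrib.crf` available; the shapes and the random transition matrix are illustrative only, not values from the trained model.

```python
# Minimal sketch of the padded-CRF decoding used in Model.decode (TF 1.x assumed).
# The extra column/row represents the synthetic "start" tag with index num_tags.
import numpy as np
from tensorflow.contrib.crf import viterbi_decode

num_tags, length = 5, 4
small = -1000.0

score = np.random.rand(length, num_tags).astype(np.float32)            # per-token tag scores
trans = np.random.rand(num_tags + 1, num_tags + 1).astype(np.float32)  # transition matrix

pad = small * np.ones([length, 1], dtype=np.float32)                   # forbid the start tag mid-sequence
logits = np.concatenate([score, pad], axis=1)
start = np.asarray([[small] * num_tags + [0]], dtype=np.float32)       # force the first step onto "start"
logits = np.concatenate([start, logits], axis=0)

path, _ = viterbi_decode(logits, trans)
print(path[1:])   # drop the synthetic start step, leaving one tag id per token
```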
/pictures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumath/bertNER/ee044eb1333b532bed2147610abb4d600b1cd8cf/pictures/results.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import tensorflow as tf
4 | from utils import create_model, get_logger
5 | from model import Model
6 | from loader import input_from_line
7 | from train import FLAGS, load_config
8 |
9 | def main(_):
10 | config = load_config(FLAGS.config_file)
11 | logger = get_logger(FLAGS.log_file)
12 | # limit GPU memory
13 | tf_config = tf.ConfigProto()
14 | tf_config.gpu_options.allow_growth = True
15 | with open(FLAGS.map_file, "rb") as f:
16 | tag_to_id, id_to_tag = pickle.load(f)
17 | with tf.Session(config=tf_config) as sess:
18 | model = create_model(sess, Model, FLAGS.ckpt_path, config, logger)
19 | while True:
20 |                 line = input("Please input a sentence: ")
21 | result = model.evaluate_line(sess, input_from_line(line, FLAGS.max_seq_len, tag_to_id), id_to_tag)
22 | print(result['entities'])
23 |
24 | if __name__ == '__main__':
25 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
26 | tf.app.run(main)
--------------------------------------------------------------------------------
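predict.py prints `result['entities']`, which is produced by `bio_to_json` in utils.py, so each entity carries its surface form, character offsets, and type. Purely for orientation, the output might look like the sketch below; the example sentence, tag set, and offsets are assumptions, not output from a trained checkpoint.

```python
# Illustrative only: for an input such as "张三在北京工作" and a model that emits
# B-PER/I-PER and B-LOC/I-LOC tags, the printed entities would be shaped like this.
entities = [
    {"word": "张三", "start": 0, "end": 2, "type": "PER"},
    {"word": "北京", "start": 3, "end": 5, "type": "LOC"},
]
```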
/rnncell.py:
--------------------------------------------------------------------------------
1 | """Module for constructing RNN Cells."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import collections
7 | import math
8 | import tensorflow as tf
9 | from tensorflow.contrib.compiler import jit
10 | from tensorflow.contrib.layers.python.layers import layers
11 | from tensorflow.python.framework import dtypes
12 | from tensorflow.python.framework import op_def_registry
13 | from tensorflow.python.framework import ops
14 | from tensorflow.python.ops import array_ops
15 | from tensorflow.python.ops import clip_ops
16 | from tensorflow.python.ops import init_ops
17 | from tensorflow.python.ops import math_ops
18 | from tensorflow.python.ops import nn_ops
19 | from tensorflow.python.ops import random_ops
20 | from tensorflow.python.ops import rnn_cell_impl
21 | from tensorflow.python.ops import variable_scope as vs
22 | from tensorflow.python.platform import tf_logging as logging
23 | from tensorflow.python.util import nest
24 |
25 |
26 | def _get_concat_variable(name, shape, dtype, num_shards):
27 | """Get a sharded variable concatenated into one tensor."""
28 | sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
29 | if len(sharded_variable) == 1:
30 | return sharded_variable[0]
31 |
32 | concat_name = name + "/concat"
33 | concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
34 | for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
35 | if value.name == concat_full_name:
36 | return value
37 |
38 | concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
39 | ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
40 | concat_variable)
41 | return concat_variable
42 |
43 |
44 | def _get_sharded_variable(name, shape, dtype, num_shards):
45 | """Get a list of sharded variables with the given dtype."""
46 | if num_shards > shape[0]:
47 | raise ValueError("Too many shards: shape=%s, num_shards=%d" %
48 | (shape, num_shards))
49 | unit_shard_size = int(math.floor(shape[0] / num_shards))
50 | remaining_rows = shape[0] - unit_shard_size * num_shards
51 |
52 | shards = []
53 | for i in range(num_shards):
54 | current_size = unit_shard_size
55 | if i < remaining_rows:
56 | current_size += 1
57 | shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
58 | dtype=dtype))
59 | return shards
60 |
61 |
62 | class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
63 | """Long short-term memory unit (LSTM) recurrent network cell.
64 |
65 | The default non-peephole implementation is based on:
66 |
67 | http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
68 |
69 | S. Hochreiter and J. Schmidhuber.
70 | "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
71 |
72 | The peephole implementation is based on:
73 |
74 | https://research.google.com/pubs/archive/43905.pdf
75 |
76 | Hasim Sak, Andrew Senior, and Francoise Beaufays.
77 | "Long short-term memory recurrent neural network architectures for
78 | large scale acoustic modeling." INTERSPEECH, 2014.
79 |
80 | The coupling of input and forget gate is based on:
81 |
82 | http://arxiv.org/pdf/1503.04069.pdf
83 |
84 | Greff et al. "LSTM: A Search Space Odyssey"
85 |
86 | The class uses optional peep-hole connections, and an optional projection
87 | layer.
88 | """
89 |
90 | def __init__(self, num_units, use_peepholes=False,
91 | initializer=None, num_proj=None, proj_clip=None,
92 | num_unit_shards=1, num_proj_shards=1,
93 | forget_bias=1.0, state_is_tuple=True,
94 | activation=math_ops.tanh, reuse=None):
95 | """Initialize the parameters for an LSTM cell.
96 |
97 | Args:
98 | num_units: int, The number of units in the LSTM cell
99 | use_peepholes: bool, set True to enable diagonal/peephole connections.
100 | initializer: (optional) The initializer to use for the weight and
101 | projection matrices.
102 | num_proj: (optional) int, The output dimensionality for the projection
103 | matrices. If None, no projection is performed.
104 | proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
105 | provided, then the projected values are clipped elementwise to within
106 | `[-proj_clip, proj_clip]`.
107 | num_unit_shards: How to split the weight matrix. If >1, the weight
108 | matrix is stored across num_unit_shards.
109 | num_proj_shards: How to split the projection matrix. If >1, the
110 | projection matrix is stored across num_proj_shards.
111 | forget_bias: Biases of the forget gate are initialized by default to 1
112 | in order to reduce the scale of forgetting at the beginning of
113 | the training.
114 |       state_is_tuple: If True (the default here), accepted and returned states are
115 |         2-tuples of the `c_state` and `m_state`. If False, they are concatenated
116 |         along the column axis; that behavior will soon be deprecated.
117 | activation: Activation function of the inner states.
118 | reuse: (optional) Python boolean describing whether to reuse variables
119 | in an existing scope. If not `True`, and the existing scope already has
120 | the given variables, an error is raised.
121 | """
122 | super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
123 | if not state_is_tuple:
124 | logging.warn(
125 | "%s: Using a concatenated state is slower and will soon be "
126 | "deprecated. Use state_is_tuple=True.", self)
127 | self._num_units = num_units
128 | self._use_peepholes = use_peepholes
129 | self._initializer = initializer
130 | self._num_proj = num_proj
131 | self._proj_clip = proj_clip
132 | self._num_unit_shards = num_unit_shards
133 | self._num_proj_shards = num_proj_shards
134 | self._forget_bias = forget_bias
135 | self._state_is_tuple = state_is_tuple
136 | self._activation = activation
137 | self._reuse = reuse
138 |
139 | if num_proj:
140 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
141 | if state_is_tuple else num_units + num_proj)
142 | self._output_size = num_proj
143 | else:
144 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
145 | if state_is_tuple else 2 * num_units)
146 | self._output_size = num_units
147 |
148 | @property
149 | def state_size(self):
150 | return self._state_size
151 |
152 | @property
153 | def output_size(self):
154 | return self._output_size
155 |
156 | def call(self, inputs, state):
157 | """Run one step of LSTM.
158 |
159 | Args:
160 |       inputs: input Tensor, 2D, batch x input_size.
161 | state: if `state_is_tuple` is False, this must be a state Tensor,
162 | `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
163 | tuple of state Tensors, both `2-D`, with column sizes `c_state` and
164 | `m_state`.
165 | scope: VariableScope for the created subgraph; defaults to "LSTMCell".
166 |
167 | Returns:
168 | A tuple containing:
169 | - A `2-D, [batch x output_dim]`, Tensor representing the output of the
170 | LSTM after reading `inputs` when previous state was `state`.
171 | Here output_dim is:
172 | num_proj if num_proj was set,
173 | num_units otherwise.
174 | - Tensor(s) representing the new state of LSTM after reading `inputs` when
175 | the previous state was `state`. Same type and shape(s) as `state`.
176 |
177 | Raises:
178 | ValueError: If input size cannot be inferred from inputs via
179 | static shape inference.
180 | """
181 | sigmoid = math_ops.sigmoid
182 |
183 | num_proj = self._num_units if self._num_proj is None else self._num_proj
184 |
185 | if self._state_is_tuple:
186 | (c_prev, m_prev) = state
187 | else:
188 | c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
189 | m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
190 |
191 | dtype = inputs.dtype
192 | input_size = inputs.get_shape().with_rank(2)[1]
193 |
194 | if input_size.value is None:
195 | raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
196 |
197 | # Input gate weights
198 | self.w_xi = tf.get_variable("_w_xi", [input_size.value, self._num_units])
199 | self.w_hi = tf.get_variable("_w_hi", [self._num_units, self._num_units])
200 | self.w_ci = tf.get_variable("_w_ci", [self._num_units, self._num_units])
201 | # Output gate weights
202 | self.w_xo = tf.get_variable("_w_xo", [input_size.value, self._num_units])
203 | self.w_ho = tf.get_variable("_w_ho", [self._num_units, self._num_units])
204 | self.w_co = tf.get_variable("_w_co", [self._num_units, self._num_units])
205 |
206 | # Cell weights
207 | self.w_xc = tf.get_variable("_w_xc", [input_size.value, self._num_units])
208 | self.w_hc = tf.get_variable("_w_hc", [self._num_units, self._num_units])
209 |
210 | # Initialize the bias vectors
211 | self.b_i = tf.get_variable("_b_i", [self._num_units], initializer=init_ops.zeros_initializer())
212 | self.b_c = tf.get_variable("_b_c", [self._num_units], initializer=init_ops.zeros_initializer())
213 | self.b_o = tf.get_variable("_b_o", [self._num_units], initializer=init_ops.zeros_initializer())
214 |
215 | i_t = sigmoid(math_ops.matmul(inputs, self.w_xi) +
216 | math_ops.matmul(m_prev, self.w_hi) +
217 | math_ops.matmul(c_prev, self.w_ci) +
218 | self.b_i)
219 | c_t = ((1 - i_t) * c_prev + i_t * self._activation(math_ops.matmul(inputs, self.w_xc) +
220 | math_ops.matmul(m_prev, self.w_hc) + self.b_c))
221 |
222 | o_t = sigmoid(math_ops.matmul(inputs, self.w_xo) +
223 | math_ops.matmul(m_prev, self.w_ho) +
224 | math_ops.matmul(c_t, self.w_co) +
225 | self.b_o)
226 |
227 | h_t = o_t * self._activation(c_t)
228 |
229 | new_state = (rnn_cell_impl.LSTMStateTuple(c_t, h_t) if self._state_is_tuple else
230 | array_ops.concat([c_t, h_t], 1))
231 | return h_t, new_state
--------------------------------------------------------------------------------
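`CoupledInputForgetGateLSTMCell` behaves like any other `RNNCell`, so it can be dropped into `tf.nn.dynamic_rnn`. A minimal sketch under TF 1.x; the batch size, sequence length, and dimensions below are illustrative and not the values used elsewhere in this repo.

```python
# Minimal sketch (TF 1.x assumed): using the coupled input-forget gate cell as a
# drop-in RNNCell. Shapes and hyperparameters here are illustrative only.
import tensorflow as tf
from rnncell import CoupledInputForgetGateLSTMCell

batch_size, num_steps, input_dim, lstm_dim = 8, 20, 768, 200
inputs = tf.placeholder(tf.float32, [batch_size, num_steps, input_dim])
lengths = tf.placeholder(tf.int32, [batch_size])

cell = CoupledInputForgetGateLSTMCell(
    lstm_dim, use_peepholes=True,
    initializer=tf.contrib.layers.xavier_initializer(),
    state_is_tuple=True)
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=lengths, dtype=tf.float32)
# outputs: [batch_size, num_steps, lstm_dim]; steps beyond each sequence length are zeroed.
```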
/train.py:
--------------------------------------------------------------------------------
1 | # encoding=utf8
2 | import os
3 | import codecs
4 | import pickle
5 | import itertools
6 | from collections import OrderedDict
7 |
8 | import tensorflow as tf
9 | import numpy as np
10 | from model import Model
11 | from loader import load_sentences, update_tag_scheme
12 | from loader import char_mapping, tag_mapping
13 | from loader import augment_with_pretrained, prepare_dataset
14 | from utils import get_logger, make_path, clean, create_model, save_model
15 | from utils import print_config, save_config, load_config, test_ner
16 | from data_utils import create_input, BatchManager
17 |
18 | flags = tf.app.flags
19 | flags.DEFINE_boolean("clean", False, "clean train folder")
20 | flags.DEFINE_boolean("train", False, "Whether to train the model")
21 | # configurations for the model
22 | flags.DEFINE_integer("batch_size", 128, "batch size")
23 | flags.DEFINE_integer("seg_dim", 20, "Embedding size for segmentation, 0 if not used")
24 | flags.DEFINE_integer("char_dim", 100, "Embedding size for characters")
25 | flags.DEFINE_integer("lstm_dim", 200, "Num of hidden units in LSTM")
26 | flags.DEFINE_string("tag_schema", "iob", "tagging schema iobes or iob")
27 |
28 | # configurations for training
29 | flags.DEFINE_float("clip", 5, "Gradient clip")
30 | flags.DEFINE_float("dropout", 0.5, "Dropout rate")
31 | flags.DEFINE_float("lr", 0.001, "Initial learning rate")
32 | flags.DEFINE_string("optimizer", "adam", "Optimizer for training")
33 | flags.DEFINE_boolean("zeros", False, "Whether to replace digits with zeros")
34 | flags.DEFINE_boolean("lower", True, "Whether to lowercase characters")
35 |
36 | flags.DEFINE_integer("max_seq_len", 128, "max sequence length for bert")
37 | flags.DEFINE_integer("max_epoch", 100, "maximum training epochs")
38 | flags.DEFINE_integer("steps_check", 100, "steps per checkpoint")
39 | flags.DEFINE_string("ckpt_path", "ckpt", "Path to save model")
40 | flags.DEFINE_string("summary_path", "summary", "Path to store summaries")
41 | flags.DEFINE_string("log_file", "train.log", "File for log")
42 | flags.DEFINE_string("map_file", "maps.pkl", "file for maps")
43 | flags.DEFINE_string("vocab_file", "vocab.json", "File for vocab")
44 | flags.DEFINE_string("config_file", "config_file", "File for config")
45 | flags.DEFINE_string("script", "conlleval", "evaluation script")
46 | flags.DEFINE_string("result_path", "result", "Path for results")
47 | flags.DEFINE_string("train_file", os.path.join("data", "example.train"), "Path for train data")
48 | flags.DEFINE_string("dev_file", os.path.join("data", "example.dev"), "Path for dev data")
49 | flags.DEFINE_string("test_file", os.path.join("data", "example.test"), "Path for test data")
50 |
51 |
52 | FLAGS = tf.app.flags.FLAGS
53 | assert FLAGS.clip < 5.1, "gradient clip shouldn't be too large"
54 | assert 0 <= FLAGS.dropout < 1, "dropout rate between 0 and 1"
55 | assert FLAGS.lr > 0, "learning rate must be larger than zero"
56 | assert FLAGS.optimizer in ["adam", "sgd", "adagrad"]
57 |
58 |
59 | # config for the model
60 | def config_model(tag_to_id):
61 | config = OrderedDict()
62 | config["num_tags"] = len(tag_to_id)
63 | config["lstm_dim"] = FLAGS.lstm_dim
64 | config["batch_size"] = FLAGS.batch_size
65 | config['max_seq_len'] = FLAGS.max_seq_len
66 |
67 | config["clip"] = FLAGS.clip
68 | config["dropout_keep"] = 1.0 - FLAGS.dropout
69 | config["optimizer"] = FLAGS.optimizer
70 | config["lr"] = FLAGS.lr
71 | config["tag_schema"] = FLAGS.tag_schema
72 | config["zeros"] = FLAGS.zeros
73 | config["lower"] = FLAGS.lower
74 | return config
75 |
76 | def evaluate(sess, model, name, data, id_to_tag, logger):
77 | logger.info("evaluate:{}".format(name))
78 | ner_results = model.evaluate(sess, data, id_to_tag)
79 | eval_lines = test_ner(ner_results, FLAGS.result_path)
80 | for line in eval_lines:
81 | logger.info(line)
82 | f1 = float(eval_lines[1].strip().split()[-1])
83 |
84 | if name == "dev":
85 |         best_dev_f1 = model.best_dev_f1.eval()
86 |         if f1 > best_dev_f1:
87 |             tf.assign(model.best_dev_f1, f1).eval()
88 |             logger.info("new best dev f1 score:{:>.3f}".format(f1))
89 |         return f1 > best_dev_f1
90 | elif name == "test":
91 | best_test_f1 = model.best_test_f1.eval()
92 | if f1 > best_test_f1:
93 | tf.assign(model.best_test_f1, f1).eval()
94 | logger.info("new best test f1 score:{:>.3f}".format(f1))
95 | return f1 > best_test_f1
96 |
97 | def train():
98 | # load data sets
99 | train_sentences = load_sentences(FLAGS.train_file, FLAGS.lower, FLAGS.zeros)
100 | dev_sentences = load_sentences(FLAGS.dev_file, FLAGS.lower, FLAGS.zeros)
101 | test_sentences = load_sentences(FLAGS.test_file, FLAGS.lower, FLAGS.zeros)
102 |
103 | # Use selected tagging scheme (IOB / IOBES)
104 | #update_tag_scheme(train_sentences, FLAGS.tag_schema)
105 | #update_tag_scheme(test_sentences, FLAGS.tag_schema)
106 |
107 | # create maps if not exist
108 | if not os.path.isfile(FLAGS.map_file):
109 | # Create a dictionary and a mapping for tags
110 | _t, tag_to_id, id_to_tag = tag_mapping(train_sentences)
111 | with open(FLAGS.map_file, "wb") as f:
112 | pickle.dump([tag_to_id, id_to_tag], f)
113 | else:
114 | with open(FLAGS.map_file, "rb") as f:
115 | tag_to_id, id_to_tag = pickle.load(f)
116 |
117 | # prepare data, get a collection of list containing index
118 | train_data = prepare_dataset(
119 | train_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower
120 | )
121 | dev_data = prepare_dataset(
122 | dev_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower
123 | )
124 | test_data = prepare_dataset(
125 | test_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower
126 | )
127 | print("%i / %i / %i sentences in train / dev / test." % (
128 |         len(train_data), len(dev_data), len(test_data)))
129 |
130 | train_manager = BatchManager(train_data, FLAGS.batch_size)
131 | dev_manager = BatchManager(dev_data, FLAGS.batch_size)
132 | test_manager = BatchManager(test_data, FLAGS.batch_size)
133 | # make path for store log and model if not exist
134 | make_path(FLAGS)
135 | if os.path.isfile(FLAGS.config_file):
136 | config = load_config(FLAGS.config_file)
137 | else:
138 | config = config_model(tag_to_id)
139 | save_config(config, FLAGS.config_file)
140 | make_path(FLAGS)
141 |
142 | log_path = os.path.join("log", FLAGS.log_file)
143 | logger = get_logger(log_path)
144 | print_config(config, logger)
145 |
146 | # limit GPU memory
147 | tf_config = tf.ConfigProto()
148 | tf_config.gpu_options.allow_growth = True
149 | steps_per_epoch = train_manager.len_data
150 | with tf.Session(config=tf_config) as sess:
151 | model = create_model(sess, Model, FLAGS.ckpt_path, config, logger)
152 |
153 | logger.info("start training")
154 | loss = []
155 |         for i in range(FLAGS.max_epoch):
156 | for batch in train_manager.iter_batch(shuffle=True):
157 | step, batch_loss = model.run_step(sess, True, batch)
158 |
159 | loss.append(batch_loss)
160 | if step % FLAGS.steps_check == 0:
161 | iteration = step // steps_per_epoch + 1
162 | logger.info("iteration:{} step:{}/{}, "
163 | "NER loss:{:>9.6f}".format(
164 | iteration, step%steps_per_epoch, steps_per_epoch, np.mean(loss)))
165 | loss = []
166 |
167 | best = evaluate(sess, model, "dev", dev_manager, id_to_tag, logger)
168 | if best:
169 | save_model(sess, model, FLAGS.ckpt_path, logger, global_steps=step)
170 | evaluate(sess, model, "test", test_manager, id_to_tag, logger)
171 |
172 | def main(_):
173 | FLAGS.train = True
174 | FLAGS.clean = True
175 | clean(FLAGS)
176 | train()
177 |
178 | if __name__ == "__main__":
179 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
180 | tf.app.run(main)
--------------------------------------------------------------------------------
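`evaluate()` in train.py reads the overall F1 from the second line of the conlleval report via `eval_lines[1].strip().split()[-1]`. A small sketch of that parsing; the report lines below are illustrative of conlleval's summary format, not real results from this repo.

```python
# Hedged sketch: how evaluate() extracts F1 from the conlleval report.
# The report lines are illustrative of conlleval's summary format, not real output.
eval_lines = [
    "processed 2318 tokens with 112 phrases; found: 109 phrases; correct: 98.",
    "accuracy:  97.84%; precision:  89.91%; recall:  87.50%; FB1:  88.69",
]
f1 = float(eval_lines[1].strip().split()[-1])   # 88.69, the overall FB1 score
print(f1)
```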
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import shutil
5 | import logging
6 | import codecs
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from conlleval import return_report
11 |
12 | from bert import modeling
13 |
14 | models_path = "./models"
15 | eval_path = "./evaluation"
16 | eval_temp = os.path.join(eval_path, "temp")
17 | eval_script = os.path.join(eval_path, "conlleval")
18 |
19 |
20 | def get_logger(log_file):
21 | logger = logging.getLogger(log_file)
22 | logger.setLevel(logging.DEBUG)
23 | fh = logging.FileHandler(log_file)
24 | fh.setLevel(logging.DEBUG)
25 | ch = logging.StreamHandler()
26 | ch.setLevel(logging.INFO)
27 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
28 | ch.setFormatter(formatter)
29 | fh.setFormatter(formatter)
30 | logger.addHandler(ch)
31 | logger.addHandler(fh)
32 | return logger
33 |
34 |
35 | # def test_ner(results, path):
36 | # """
37 | # Run perl script to evaluate model
38 | # """
39 | # script_file = "conlleval"
40 | # output_file = os.path.join(path, "ner_predict.utf8")
41 | # result_file = os.path.join(path, "ner_result.utf8")
42 | # with open(output_file, "w") as f:
43 | # to_write = []
44 | # for block in results:
45 | # for line in block:
46 | # to_write.append(line + "\n")
47 | # to_write.append("\n")
48 | #
49 | # f.writelines(to_write)
50 | # os.system("perl {} < {} > {}".format(script_file, output_file, result_file))
51 | # eval_lines = []
52 | # with open(result_file) as f:
53 | # for line in f:
54 | # eval_lines.append(line.strip())
55 | # return eval_lines
56 |
57 |
58 | def test_ner(results, path):
59 | """
60 |     Evaluate model predictions with the conlleval script (via conlleval.return_report)
61 | """
62 | output_file = os.path.join(path, "ner_predict.utf8")
63 | with codecs.open(output_file, "w", 'utf8') as f:
64 | to_write = []
65 | for block in results:
66 | for line in block:
67 | to_write.append(line + "\n")
68 | to_write.append("\n")
69 |
70 | f.writelines(to_write)
71 | eval_lines = return_report(output_file)
72 | return eval_lines
73 |
74 |
75 | def print_config(config, logger):
76 | """
77 | Print configuration of the model
78 | """
79 | for k, v in config.items():
80 | logger.info("{}:\t{}".format(k.ljust(15), v))
81 |
82 |
83 | def make_path(params):
84 | """
85 | Make folders for training and evaluation
86 | """
87 | if not os.path.isdir(params.result_path):
88 | os.makedirs(params.result_path)
89 | if not os.path.isdir(params.ckpt_path):
90 | os.makedirs(params.ckpt_path)
91 | if not os.path.isdir("log"):
92 | os.makedirs("log")
93 |
94 |
95 | def clean(params):
96 | """
97 | Clean current folder
98 | remove saved model and training log
99 | """
100 | if os.path.isfile(params.vocab_file):
101 | os.remove(params.vocab_file)
102 |
103 | if os.path.isfile(params.map_file):
104 | os.remove(params.map_file)
105 |
106 | if os.path.isdir(params.ckpt_path):
107 | shutil.rmtree(params.ckpt_path)
108 |
109 | if os.path.isdir(params.summary_path):
110 | shutil.rmtree(params.summary_path)
111 |
112 | if os.path.isdir(params.result_path):
113 | shutil.rmtree(params.result_path)
114 |
115 | if os.path.isdir("log"):
116 | shutil.rmtree("log")
117 |
118 | if os.path.isdir("__pycache__"):
119 | shutil.rmtree("__pycache__")
120 |
121 | if os.path.isfile(params.config_file):
122 | os.remove(params.config_file)
123 |
124 | if os.path.isfile(params.vocab_file):
125 | os.remove(params.vocab_file)
126 |
127 |
128 | def save_config(config, config_file):
129 | """
130 | Save configuration of the model
131 | parameters are stored in json format
132 | """
133 | with open(config_file, "w", encoding="utf8") as f:
134 | json.dump(config, f, ensure_ascii=False, indent=4)
135 |
136 |
137 | def load_config(config_file):
138 | """
139 | Load configuration of the model
140 | parameters are stored in json format
141 | """
142 | with open(config_file, encoding="utf8") as f:
143 | return json.load(f)
144 |
145 |
146 | def convert_to_text(line):
147 | """
148 | Convert conll data to text
149 | """
150 | to_print = []
151 | for item in line:
152 |
153 | try:
154 | if item[0] == " ":
155 | to_print.append(" ")
156 | continue
157 | word, gold, tag = item.split(" ")
158 | if tag[0] in "SB":
159 | to_print.append("[")
160 | to_print.append(word)
161 | if tag[0] in "SE":
162 | to_print.append("@" + tag.split("-")[-1])
163 | to_print.append("]")
164 | except:
165 | print(list(item))
166 | return "".join(to_print)
167 |
168 |
169 | def save_model(sess, model, path, logger, global_steps):
170 | checkpoint_path = os.path.join(path, "ner.ckpt")
171 | model.saver.save(sess, checkpoint_path, global_step = global_steps)
172 | logger.info("model saved")
173 |
174 |
175 | def create_model(session, Model_class, path, config, logger):
176 | # create model, reuse parameters if exists
177 | model = Model_class(config)
178 |
179 | ckpt = tf.train.get_checkpoint_state(path)
180 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
181 | logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
182 | #saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta')
183 | #saver.restore(session, tf.train.latest_checkpoint("ckpt/"))
184 | model.saver.restore(session, ckpt.model_checkpoint_path)
185 | else:
186 | logger.info("Created model with fresh parameters.")
187 | session.run(tf.global_variables_initializer())
188 | return model
189 |
190 |
191 | def result_to_json(string, tags):
192 | item = {"string": string, "entities": []}
193 | entity_name = ""
194 | entity_start = 0
195 | idx = 0
196 | for char, tag in zip(string, tags):
197 | if tag[0] == "S":
198 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]})
199 | elif tag[0] == "B":
200 | entity_name += char
201 | entity_start = idx
202 | elif tag[0] == "I":
203 | entity_name += char
204 | elif tag[0] == "E":
205 | entity_name += char
206 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx + 1, "type": tag[2:]})
207 | entity_name = ""
208 | else:
209 | entity_name = ""
210 | entity_start = idx
211 | idx += 1
212 | return item
213 |
214 | def bio_to_json(string, tags):
215 | item = {"string": string, "entities": []}
216 | entity_name = ""
217 | entity_start = 0
218 | iCount = 0
219 | entity_tag = ""
220 | #assert len(string)==len(tags), "string length is: {}, tags length is: {}".format(len(string), len(tags))
221 |
222 | for c_idx in range(len(tags)):
223 | c, tag = string[c_idx], tags[c_idx]
224 | if c_idx < len(tags)-1:
225 | tag_next = tags[c_idx+1]
226 | else:
227 | tag_next = ''
228 |
229 | if tag[0] == 'B':
230 | entity_tag = tag[2:]
231 | entity_name = c
232 | entity_start = iCount
233 | if tag_next[2:] != entity_tag:
234 | item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]})
235 | elif tag[0] == "I":
236 | if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O':
237 | tags[c_idx] = 'O'
238 | pass
239 | else:
240 | entity_name = entity_name + c
241 | if tag_next[2:] != entity_tag:
242 | item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag})
243 | entity_name = ''
244 | iCount += 1
245 | return item
246 |
247 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line):
248 | """
249 |     Process a single example: convert characters to ids and labels to ids
250 | """
251 | text_list = char_line.split(' ')
252 | label_list = label_line.split(' ')
253 |
254 | tokens = []
255 | labels = []
256 | for i, word in enumerate(text_list):
257 | token = tokenizer.tokenize(word)
258 | tokens.extend(token)
259 | label_1 = label_list[i]
260 | for m in range(len(token)):
261 | if m == 0:
262 | labels.append(label_1)
263 | else:
264 | labels.append("X")
265 |     # truncate the sequence to fit max_seq_length
266 | if len(tokens) >= max_seq_length - 1:
267 | tokens = tokens[0:(max_seq_length - 2)]
268 | labels = labels[0:(max_seq_length - 2)]
269 | ntokens = []
270 | segment_ids = []
271 | label_ids = []
272 | ntokens.append("[CLS]")
273 | segment_ids.append(0)
274 |     # either "O" or "[CLS]" works as the label for the sentence-start token
275 | label_ids.append(tag_to_id["[CLS]"])
276 | for i, token in enumerate(tokens):
277 | ntokens.append(token)
278 | segment_ids.append(0)
279 | label_ids.append(tag_to_id[labels[i]])
280 | ntokens.append("[SEP]")
281 | segment_ids.append(0)
282 |     # likewise, either "O" or "[SEP]" works as the label for the sentence-end token
283 | label_ids.append(tag_to_id["[SEP]"])
284 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
285 | input_mask = [1] * len(input_ids)
286 |
287 | # padding
288 | while len(input_ids) < max_seq_length:
289 | input_ids.append(0)
290 | input_mask.append(0)
291 | segment_ids.append(0)
292 |         # labels on padding positions are never used, so 0 is fine here
293 | label_ids.append(0)
294 | ntokens.append("**NULL**")
295 |
296 | return input_ids, input_mask, segment_ids, label_ids
--------------------------------------------------------------------------------
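A hedged, worked example of `convert_single_example`'s output layout. The `WhitespaceTokenizer` stub and its tiny vocabulary are hypothetical stand-ins for the BERT tokenizer (which reads chinese_L-12_H-768_A-12/vocab.txt); run this inside the repo so the `utils` import resolves.

```python
# Hedged sketch of convert_single_example's output layout. WhitespaceTokenizer is a
# hypothetical stand-in for the BERT tokenizer; real ids come from the BERT vocab file.
from utils import convert_single_example

class WhitespaceTokenizer:                      # hypothetical stub, not part of the repo
    vocab = {"[CLS]": 101, "[SEP]": 102, "我": 2769, "爱": 4263, "北": 1266, "京": 776}
    def tokenize(self, word):
        return [word]
    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(t, 100) for t in tokens]   # 100 as a stand-in [UNK] id

tag_to_id = {"O": 0, "B-LOC": 1, "I-LOC": 2, "[CLS]": 3, "[SEP]": 4}
input_ids, input_mask, segment_ids, label_ids = convert_single_example(
    "我 爱 北 京", tag_to_id, max_seq_length=8, tokenizer=WhitespaceTokenizer(),
    label_line="O O B-LOC I-LOC")
# len(input_ids) == len(input_mask) == len(segment_ids) == len(label_ids) == 8
# input_mask marks the 6 real positions ([CLS] 我 爱 北 京 [SEP]) with 1, padding with 0.
```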