├── .gitignore
├── BERT fine-tune实践.md
├── README.md
├── conlleval
├── conlleval.py
├── data
│   ├── example.dev
│   ├── example.test
│   └── example.train
├── data_utils.py
├── loader.py
├── model.py
├── pictures
│   └── results.png
├── predict.py
├── rnncell.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | chinese_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
3 | bert/__init__.py
4 | .DS_Store
5 | bert/CONTRIBUTING.md
6 | bert/create_pretraining_data.py
7 | bert/extract_features.py
8 | bert/LICENSE
9 | bert/modeling.py
10 | bert/modeling_test.py
11 | bert/multilingual.md
12 | bert/optimization.py
13 | bert/optimization_test.py
14 | bert/README.md
15 | bert/requirements.txt
16 | bert/run_classifier.py
17 | bert/run_pretraining.py
18 | bert/run_squad.py
19 | bert/sample_text.txt
20 | bert/tokenization.py
21 | bert/tokenization_test.py
22 | chinese_L-12_H-768_A-12/bert_config.json
23 | chinese_L-12_H-768_A-12/bert_model.ckpt.index
24 | chinese_L-12_H-768_A-12/bert_model.ckpt.meta
25 | chinese_L-12_H-768_A-12/vocab.txt
26 | data/.DS_Store
27 | .idea/BertNER.iml
28 | .idea/deployment.xml
29 | .idea/encodings.xml
30 | .idea/misc.xml
31 | .idea/modules.xml
32 | .idea/vcs.xml
33 | .idea/workspace.xml
34 | __pycache__/conlleval.cpython-36.pyc
35 | __pycache__/data_utils.cpython-36.pyc
36 | __pycache__/loader.cpython-36.pyc
37 | __pycache__/model.cpython-36.pyc
38 | __pycache__/rnncell.cpython-36.pyc
39 | __pycache__/utils.cpython-36.pyc
40 | bert/__pycache__/__init__.cpython-36.pyc
41 | bert/__pycache__/modeling.cpython-36.pyc
42 | bert/__pycache__/tokenization.cpython-36.pyc
43 | config_file
44 | maps.pkl
45 | train.log
46 | 
--------------------------------------------------------------------------------
/BERT fine-tune实践.md:
--------------------------------------------------------------------------------
1 | # BERT Fine-tuning in Practice
2 | 
3 | 1. #### Download the model & code
4 | 
5 | The BERT code and pre-trained models can be downloaded from: https://github.com/google-research/bert
6 | 
7 | 2. #### Save path
8 | 
9 | 3. #### Load the model parameters
10 | 
11 | 4. #### Tokenize the input
12 | 
13 | ```python
14 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line):
15 |     """
16 |     Process one example: convert each character to an id and each label to an id,
17 |     then structure them into the inputs expected by the model.
18 |     :param char_line: characters of one example, separated by spaces
19 |     :param tag_to_id: mapping from label to label id
20 |     :param max_seq_length: maximum sequence length, including [CLS] and [SEP]
21 |     :param tokenizer: BERT FullTokenizer
22 |     :param label_line: labels of one example, separated by spaces
23 |     :return: input_ids, input_mask, segment_ids, label_ids
24 |     """
25 |     text_list = char_line.split(' ')
26 |     label_list = label_line.split(' ')
27 | 
28 |     tokens = []
29 |     labels = []
30 |     for i, word in enumerate(text_list):
31 |         # tokenize: for Chinese this splits into single characters, but characters missing from BERT's vocab.txt (e.g. Chinese quotation marks) go through WordPiece; all of this per-character splitting could be replaced with list(input)
32 |         token = tokenizer.tokenize(word)
33 |         tokens.extend(token)
34 |         label_1 = label_list[i]
35 |         for m in range(len(token)):
36 |             if m == 0:
37 |                 labels.append(label_1)
38 |             else:  # this branch is rarely taken
39 |                 labels.append("X")
40 |     # truncate the sequence
41 |     if len(tokens) >= max_seq_length - 1:
42 |         tokens = tokens[0:(max_seq_length - 2)]  # -2 because the sequence also needs a sentence-start and a sentence-end marker
43 |         labels = labels[0:(max_seq_length - 2)]
44 |     ntokens = []
45 |     segment_ids = []
46 |     label_ids = []
47 |     ntokens.append("[CLS]")  # mark the start of the sentence with [CLS]
48 |     segment_ids.append(0)
49 |     # append("O") or append("[CLS]") not sure!
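    # note: tag_mapping() in loader.py adds "[CLS]" and "[SEP]" to tag_to_id, which is why they can be looked up as label ids here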
50 |     label_ids.append(tag_to_id["[CLS]"])  # using "O" or "[CLS]" here makes no real difference; "O" would keep the tag set smaller, but marking the sentence start and end with their own [CLS]/[SEP] tags also works fine
51 |     for i, token in enumerate(tokens):
52 |         ntokens.append(token)
53 |         segment_ids.append(0)
54 |         label_ids.append(tag_to_id[labels[i]])
55 |     ntokens.append("[SEP]")  # mark the end of the sentence with [SEP]
56 |     segment_ids.append(0)
57 |     # append("O") or append("[SEP]") not sure!
58 |     label_ids.append(tag_to_id["[SEP]"])
59 |     input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens in the sequence (ntokens) to ids
60 |     input_mask = [1] * len(input_ids)
61 | 
62 |     # padding
63 |     while len(input_ids) < max_seq_length:
64 |         input_ids.append(0)
65 |         input_mask.append(0)
66 |         segment_ids.append(0)
67 |         # we don't care about the labels of padded positions
68 |         label_ids.append(0)
69 |         ntokens.append("**NULL**")
70 | 
71 |     return input_ids, input_mask, segment_ids, label_ids
72 | ```
73 | 
74 | 5. #### Get the embeddings
75 | 
76 | The previous step turned the text data into the three input variables (input_ids, input_mask, segment_ids) consumed by the code below.
77 | 
78 | ```python
79 | # load bert embedding
80 | bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")  # path to the BERT config file
81 | model = modeling.BertModel(
82 |     config=bert_config,
83 |     is_training=True,
84 |     input_ids=self.input_ids,
85 |     input_mask=self.input_mask,
86 |     token_type_ids=self.segment_ids,
87 |     use_one_hot_embeddings=False)
88 | embedding = model.get_sequence_output()
89 | ```
90 | 
91 | 6. #### Freeze the BERT parameter layers
92 | 
93 | ```python
94 | # where the BERT parameters are initialized
95 | init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
96 | # get all trainable variables in the model
97 | tvars = tf.trainable_variables()
98 | # load the BERT checkpoint
99 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint)
100 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
101 | print("**** Trainable Variables ****")
102 | # print the variables loaded from the checkpoint
103 | train_vars = []
104 | for var in tvars:
105 |     init_string = ""
106 |     if var.name in initialized_variable_names:
107 |         init_string = ", *INIT_FROM_CKPT*"
108 |     else:
109 |         train_vars.append(var)
110 |     print(" name = %s, shape = %s%s", var.name, var.shape,
111 |           init_string)
112 | grads = tf.gradients(self.loss, train_vars)
113 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
114 | 
115 | self.train_op = self.opt.apply_gradients(
116 |     zip(grads, train_vars), global_step=self.global_step)
117 | ```
118 | 
119 | 7. #### Start training the task layer
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *Bert-ChineseNER*
2 | 
3 | ### *Introduction*
4 | 
5 | This project fine-tunes Google's open-source pre-trained BERT model on a Chinese NER task.
6 | 
7 | ### *Datasets & Model*
8 | 
9 | The labeled training data mainly comes from zjy-ucas's [ChineseNER](https://github.com/zjy-ucas/ChineseNER) project. This project adds BERT in front of the original BiLSTM+CRF architecture as the embedding/feature-extraction layer; the pre-trained Chinese BERT model and code come from Google Research's [bert](https://github.com/google-research/bert).
10 | 
11 | ### *Results*
12 | 
13 | ![results](pictures/results.png)
14 | 
15 | After introducing BERT, the F1 score on the dev set already reaches **94.87** after 16 epochs of training, and reaches **93.68** on the test set, an improvement of more than two F1 points on this dataset.
16 | 
17 | ### *Train*
18 | 
19 | 1. Download the [bert model code](https://github.com/google-research/bert) and put it in the project root directory
20 | 2. Download the [pre-trained Chinese bert model](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) and unzip it into the project root directory
21 | 3. Set up the environment: python3 + tensorflow1.12
22 | 4. Run `python3 train.py` to train the model
23 | 5. Run `python3 predict.py` to test single sentences
24 | 
25 | Once everything is in place, the project directory should look like this:
26 | 
27 | ```
.
28 | ├── BERT fine-tune实践.md
29 | ├── README.md
30 | ├── bert
31 | ├── chinese_L-12_H-768_A-12
32 | ├── conlleval
33 | ├── conlleval.py
34 | ├── data
35 | ├── data_utils.py
36 | ├── loader.py
37 | ├── model.py
38 | ├── pictures
39 | ├── predict.py
40 | ├── rnncell.py
41 | ├── train.py
42 | └── utils.py
43 | ```
44 | 
45 | ### *Conclusion*
46 | 
47 | As the results show, using BERT improves the model's F1 by more than two percentage points. Moreover, later testing showed that the NER model trained with BERT generalizes much better: company names never seen in the training set, for example, are still recognized well. In contrast, a model trained with the BiLSTM+CRF architecture on only the training set provided by [ChineseNER](https://github.com/zjy-ucas/ChineseNER) essentially cannot handle such OOV cases.
48 | 
49 | ### *Fine-tune*
50 | The current code does feature-based transfer; switching to full fine-tuning can add roughly one more point. To fine-tune, modify the code yourself so that the BERT parameters inside the model are trained together with the task layer, and lower the learning rate to the order of 1e-5.
51 | Also, whether a BiLSTM is added has little effect on the results, so the BERT output can be decoded directly; adding a CRF layer is still recommended, since it strengthens the transition rules between tags.
52 | 
53 | ### *Reference*
54 | 
55 | (1) https://github.com/zjy-ucas/ChineseNER
56 | 
57 | (2) https://github.com/google-research/bert
58 | 
59 | (3) [Neural Architectures for Named Entity Recognition](https://arxiv.org/abs/1603.01360)
--------------------------------------------------------------------------------
/conlleval:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang 
21 | 
22 | use strict;
23 | 
24 | my $false = 0;
25 | my $true = 42;
26 | 
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 86 | elsif ($nbrOfFeatures != $#features and @features != 0) { 87 | printf STDERR "unexpected number of features: %d (%d)\n", 88 | $#features+1,$nbrOfFeatures+1; 89 | exit(1); 90 | } 91 | if (@features == 0 or 92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 93 | if (@features < 2) { 94 | die "conlleval: unexpected number of features in line $line\n"; 95 | } 96 | if ($raw) { 97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 99 | if ($features[$#features] ne "O") { 100 | $features[$#features] = "B-$features[$#features]"; 101 | } 102 | if ($features[$#features-1] ne "O") { 103 | $features[$#features-1] = "B-$features[$#features-1]"; 104 | } 105 | } 106 | # 20040126 ET code which allows hyphens in the types 107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 108 | $guessed = $1; 109 | $guessedType = $2; 110 | } else { 111 | $guessed = $features[$#features]; 112 | $guessedType = ""; 113 | } 114 | pop(@features); 115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 116 | $correct = $1; 117 | $correctType = $2; 118 | } else { 119 | $correct = 
$features[$#features]; 120 | $correctType = ""; 121 | } 122 | pop(@features); 123 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 124 | # ($correct,$correctType) = split(/-/,pop(@features)); 125 | $guessedType = $guessedType ? $guessedType : ""; 126 | $correctType = $correctType ? $correctType : ""; 127 | $firstItem = shift(@features); 128 | 129 | # 1999-06-26 sentence breaks should always be counted as out of chunk 130 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 131 | 132 | if ($inCorrect) { 133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 135 | $lastGuessedType eq $lastCorrectType) { 136 | $inCorrect=$false; 137 | $correctChunk++; 138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 139 | $correctChunk{$lastCorrectType}+1 : 1; 140 | } elsif ( 141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 143 | $guessedType ne $correctType ) { 144 | $inCorrect=$false; 145 | } 146 | } 147 | 148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 150 | $guessedType eq $correctType) { $inCorrect = $true; } 151 | 152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 153 | $foundCorrect++; 154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 155 | $foundCorrect{$correctType}+1 : 1; 156 | } 157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 158 | $foundGuessed++; 159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 160 | $foundGuessed{$guessedType}+1 : 1; 161 | } 162 | if ( $firstItem ne $boundary ) { 163 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 164 | $correctTags++; 165 | } 166 | $tokenCounter++; 167 | } 168 | 169 | $lastGuessed = $guessed; 170 | $lastCorrect = $correct; 171 | $lastGuessedType = $guessedType; 172 | $lastCorrectType = $correctType; 173 | } 174 | if ($inCorrect) { 175 | $correctChunk++; 176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 177 | $correctChunk{$lastCorrectType}+1 : 1; 178 | } 179 | 180 | if (not $latex) { 181 | # compute overall precision, recall and FB1 (default values are 0.0) 182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 184 | $FB1 = 2*$precision*$recall/($precision+$recall) 185 | if ($precision+$recall > 0); 186 | 187 | # print overall performance 188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 190 | if ($tokenCounter>0) { 191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 192 | printf "precision: %6.2f%%; ",$precision; 193 | printf "recall: %6.2f%%; ",$recall; 194 | printf "FB1: %6.2f\n",$FB1; 195 | } 196 | } 197 | 198 | # sort chunk type names 199 | undef($lastType); 200 | @sortedTypes = (); 201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 202 | if (not($lastType) or $lastType ne $i) { 203 | push(@sortedTypes,($i)); 204 | } 205 | $lastType = $i; 206 | } 207 | # print performance per chunk type 208 | if (not $latex) { 209 | for $i (@sortedTypes) { 210 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 213 | if (not($foundCorrect{$i})) { $recall = 0.0; } 214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 217 | printf "%17s: ",$i; 218 | printf "precision: %6.2f%%; ",$precision; 219 | printf "recall: %6.2f%%; ",$recall; 220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 221 | } 222 | } else { 223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 224 | for $i (@sortedTypes) { 225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 226 | if (not($foundGuessed{$i})) { $precision = 0.0; } 227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 228 | if (not($foundCorrect{$i})) { $recall = 0.0; } 229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 233 | $i,$precision,$recall,$FB1; 234 | } 235 | print "\\hline\n"; 236 | $precision = 0.0; 237 | $recall = 0; 238 | $FB1 = 0.0; 239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 241 | $FB1 = 2*$precision*$recall/($precision+$recall) 242 | if ($precision+$recall > 0); 243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 244 | $precision,$recall,$FB1; 245 | } 246 | 247 | exit 0; 248 | 249 | # endOfChunk: checks if a chunk ended between the previous and current word 250 | # arguments: previous and current chunk tags, previous and current types 251 | # note: this code is capable of handling other chunk representations 252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 254 | 255 | sub endOfChunk { 256 | my $prevTag = shift(@_); 257 | my $tag = shift(@_); 258 | my $prevType = shift(@_); 259 | my $type = shift(@_); 260 | my $chunkEnd = $false; 261 | 262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 266 | 267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 271 | 272 | if ($prevTag ne "O" and $prevTag ne "." 
and $prevType ne $type) { 273 | $chunkEnd = $true; 274 | } 275 | 276 | # corrected 1998-12-22: these chunks are assumed to have length 1 277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 279 | 280 | return($chunkEnd); 281 | } 282 | 283 | # startOfChunk: checks if a chunk started between the previous and current word 284 | # arguments: previous and current chunk tags, previous and current types 285 | # note: this code is capable of handling other chunk representations 286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 288 | 289 | sub startOfChunk { 290 | my $prevTag = shift(@_); 291 | my $tag = shift(@_); 292 | my $prevType = shift(@_); 293 | my $type = shift(@_); 294 | my $chunkStart = $false; 295 | 296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 300 | 301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 305 | 306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 307 | $chunkStart = $true; 308 | } 309 | 310 | # corrected 1998-12-22: these chunks are assumed to have length 1 311 | if ( $tag eq "[" ) { $chunkStart = $true; } 312 | if ( $tag eq "]" ) { $chunkStart = $true; } 313 | 314 | return($chunkStart); 315 | } 316 | -------------------------------------------------------------------------------- /conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | import sys 13 | import re 14 | import codecs 15 | from collections import defaultdict, namedtuple 16 | 17 | ANY_SPACE = '' 18 | 19 | 20 | class FormatError(Exception): 21 | pass 22 | 23 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 24 | 25 | 26 | class EvalCounts(object): 27 | def __init__(self): 28 | self.correct_chunk = 0 # number of correctly identified chunks 29 | self.correct_tags = 0 # number of correct chunk tags 30 | self.found_correct = 0 # number of chunks in corpus 31 | self.found_guessed = 0 # number of identified chunks 32 | self.token_counter = 0 # token counter (ignores sentence breaks) 33 | 34 | # counts by type 35 | self.t_correct_chunk = defaultdict(int) 36 | self.t_found_correct = defaultdict(int) 37 | self.t_found_guessed = defaultdict(int) 38 | 39 | 40 | def parse_args(argv): 41 | import argparse 42 | parser = argparse.ArgumentParser( 43 | description='evaluate tagging results using CoNLL criteria', 44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 45 | ) 46 | arg = parser.add_argument 47 | arg('-b', '--boundary', metavar='STR', default='-X-', 48 | help='sentence boundary') 49 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 50 | 
help='character delimiting items in input') 51 | arg('-o', '--otag', metavar='CHAR', default='O', 52 | help='alternative outside tag') 53 | arg('file', nargs='?', default=None) 54 | return parser.parse_args(argv) 55 | 56 | 57 | def parse_tag(t): 58 | m = re.match(r'^([^-]*)-(.*)$', t) 59 | return m.groups() if m else (t, '') 60 | 61 | 62 | def evaluate(iterable, options=None): 63 | if options is None: 64 | options = parse_args([]) # use defaults 65 | 66 | counts = EvalCounts() 67 | num_features = None # number of features per line 68 | in_correct = False # currently processed chunks is correct until now 69 | last_correct = 'O' # previous chunk tag in corpus 70 | last_correct_type = '' # type of previously identified chunk tag 71 | last_guessed = 'O' # previously identified chunk tag 72 | last_guessed_type = '' # type of previous chunk tag in corpus 73 | 74 | for line in iterable: 75 | line = line.rstrip('\r\n') 76 | 77 | if options.delimiter == ANY_SPACE: 78 | features = line.split() 79 | else: 80 | features = line.split(options.delimiter) 81 | 82 | if num_features is None: 83 | num_features = len(features) 84 | elif num_features != len(features) and len(features) != 0: 85 | raise FormatError('unexpected number of features: %d (%d)' % 86 | (len(features), num_features)) 87 | 88 | if len(features) == 0 or features[0] == options.boundary: 89 | features = [options.boundary, 'O', 'O'] 90 | if len(features) < 3: 91 | raise FormatError('unexpected number of features in line %s' % line) 92 | 93 | guessed, guessed_type = parse_tag(features.pop()) 94 | correct, correct_type = parse_tag(features.pop()) 95 | first_item = features.pop(0) 96 | 97 | if first_item == options.boundary: 98 | guessed = 'O' 99 | 100 | end_correct = end_of_chunk(last_correct, correct, 101 | last_correct_type, correct_type) 102 | end_guessed = end_of_chunk(last_guessed, guessed, 103 | last_guessed_type, guessed_type) 104 | start_correct = start_of_chunk(last_correct, correct, 105 | last_correct_type, correct_type) 106 | start_guessed = start_of_chunk(last_guessed, guessed, 107 | last_guessed_type, guessed_type) 108 | 109 | if in_correct: 110 | if (end_correct and end_guessed and 111 | last_guessed_type == last_correct_type): 112 | in_correct = False 113 | counts.correct_chunk += 1 114 | counts.t_correct_chunk[last_correct_type] += 1 115 | elif (end_correct != end_guessed or guessed_type != correct_type): 116 | in_correct = False 117 | 118 | if start_correct and start_guessed and guessed_type == correct_type: 119 | in_correct = True 120 | 121 | if start_correct: 122 | counts.found_correct += 1 123 | counts.t_found_correct[correct_type] += 1 124 | if start_guessed: 125 | counts.found_guessed += 1 126 | counts.t_found_guessed[guessed_type] += 1 127 | if first_item != options.boundary: 128 | if correct == guessed and guessed_type == correct_type: 129 | counts.correct_tags += 1 130 | counts.token_counter += 1 131 | 132 | last_guessed = guessed 133 | last_correct = correct 134 | last_guessed_type = guessed_type 135 | last_correct_type = correct_type 136 | 137 | if in_correct: 138 | counts.correct_chunk += 1 139 | counts.t_correct_chunk[last_correct_type] += 1 140 | 141 | return counts 142 | 143 | 144 | def uniq(iterable): 145 | seen = set() 146 | return [i for i in iterable if not (i in seen or seen.add(i))] 147 | 148 | 149 | def calculate_metrics(correct, guessed, total): 150 | tp, fp, fn = correct, guessed-correct, total-correct 151 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 152 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 
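    # the F-score below is the harmonic mean of precision and recall: 2*p*r / (p+r)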
153 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 154 | return Metrics(tp, fp, fn, p, r, f) 155 | 156 | 157 | def metrics(counts): 158 | c = counts 159 | overall = calculate_metrics( 160 | c.correct_chunk, c.found_guessed, c.found_correct 161 | ) 162 | by_type = {} 163 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 164 | by_type[t] = calculate_metrics( 165 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 166 | ) 167 | return overall, by_type 168 | 169 | 170 | def report(counts, out=None): 171 | if out is None: 172 | out = sys.stdout 173 | 174 | overall, by_type = metrics(counts) 175 | 176 | c = counts 177 | out.write('processed %d tokens with %d phrases; ' % 178 | (c.token_counter, c.found_correct)) 179 | out.write('found: %d phrases; correct: %d.\n' % 180 | (c.found_guessed, c.correct_chunk)) 181 | 182 | if c.token_counter > 0: 183 | out.write('accuracy: %6.2f%%; ' % 184 | (100.*c.correct_tags/c.token_counter)) 185 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 186 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 187 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 188 | 189 | for i, m in sorted(by_type.items()): 190 | out.write('%17s: ' % i) 191 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 192 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 193 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 194 | 195 | 196 | def report_notprint(counts, out=None): 197 | if out is None: 198 | out = sys.stdout 199 | 200 | overall, by_type = metrics(counts) 201 | 202 | c = counts 203 | final_report = [] 204 | line = [] 205 | line.append('processed %d tokens with %d phrases; ' % 206 | (c.token_counter, c.found_correct)) 207 | line.append('found: %d phrases; correct: %d.\n' % 208 | (c.found_guessed, c.correct_chunk)) 209 | final_report.append("".join(line)) 210 | 211 | if c.token_counter > 0: 212 | line = [] 213 | line.append('accuracy: %6.2f%%; ' % 214 | (100.*c.correct_tags/c.token_counter)) 215 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 216 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 217 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 218 | final_report.append("".join(line)) 219 | 220 | for i, m in sorted(by_type.items()): 221 | line = [] 222 | line.append('%17s: ' % i) 223 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 224 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 225 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 226 | final_report.append("".join(line)) 227 | return final_report 228 | 229 | 230 | def end_of_chunk(prev_tag, tag, prev_type, type_): 231 | # check if a chunk ended between the previous and current word 232 | # arguments: previous and current chunk tags, previous and current types 233 | chunk_end = False 234 | 235 | if prev_tag == 'E': chunk_end = True 236 | if prev_tag == 'S': chunk_end = True 237 | 238 | if prev_tag == 'B' and tag == 'B': chunk_end = True 239 | if prev_tag == 'B' and tag == 'S': chunk_end = True 240 | if prev_tag == 'B' and tag == 'O': chunk_end = True 241 | if prev_tag == 'I' and tag == 'B': chunk_end = True 242 | if prev_tag == 'I' and tag == 'S': chunk_end = True 243 | if prev_tag == 'I' and tag == 'O': chunk_end = True 244 | 245 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_: 246 | chunk_end = True 247 | 248 | # these chunks are assumed to have length 1 249 | if prev_tag == ']': chunk_end = True 250 | if prev_tag == '[': chunk_end = True 251 | 252 | return chunk_end 253 | 254 | 255 | def start_of_chunk(prev_tag, tag, prev_type, type_): 256 | # check if a chunk started between the previous and current word 257 | # arguments: previous and current chunk tags, previous and current types 258 | chunk_start = False 259 | 260 | if tag == 'B': chunk_start = True 261 | if tag == 'S': chunk_start = True 262 | 263 | if prev_tag == 'E' and tag == 'E': chunk_start = True 264 | if prev_tag == 'E' and tag == 'I': chunk_start = True 265 | if prev_tag == 'S' and tag == 'E': chunk_start = True 266 | if prev_tag == 'S' and tag == 'I': chunk_start = True 267 | if prev_tag == 'O' and tag == 'E': chunk_start = True 268 | if prev_tag == 'O' and tag == 'I': chunk_start = True 269 | 270 | if tag != 'O' and tag != '.' and prev_type != type_: 271 | chunk_start = True 272 | 273 | # these chunks are assumed to have length 1 274 | if tag == '[': chunk_start = True 275 | if tag == ']': chunk_start = True 276 | 277 | return chunk_start 278 | 279 | 280 | def return_report(input_file): 281 | with codecs.open(input_file, "r", "utf8") as f: 282 | counts = evaluate(f) 283 | return report_notprint(counts) 284 | 285 | 286 | def main(argv): 287 | args = parse_args(argv[1:]) 288 | 289 | if args.file is None: 290 | counts = evaluate(sys.stdin, args) 291 | else: 292 | with open(args.file) as f: 293 | counts = evaluate(f, args) 294 | report(counts) 295 | 296 | if __name__ == '__main__': 297 | sys.exit(main(sys.argv)) -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # encoding = utf8 2 | import re 3 | import math 4 | import codecs 5 | import random 6 | 7 | import numpy as np 8 | #import jieba 9 | #jieba.initialize() 10 | 11 | 12 | def create_dico(item_list): 13 | """ 14 | Create a dictionary of items from a list of list of items. 15 | """ 16 | assert type(item_list) is list 17 | dico = {} 18 | for items in item_list: 19 | for item in items: 20 | if item not in dico: 21 | dico[item] = 1 22 | else: 23 | dico[item] += 1 24 | 25 | return dico 26 | 27 | 28 | def create_mapping(dico): 29 | """ 30 | Create a mapping (item to ID / ID to item) from a dictionary. 31 | Items are ordered by decreasing frequency. 32 | """ 33 | sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0])) 34 | id_to_item = {i: v[0] for i, v in enumerate(sorted_items)} 35 | item_to_id = {v: k for k, v in id_to_item.items()} 36 | return item_to_id, id_to_item 37 | 38 | 39 | def zero_digits(s): 40 | """ 41 | Replace every digit in a string by a zero. 42 | """ 43 | return re.sub('\d', '0', s) 44 | 45 | 46 | def iob2(tags): 47 | """ 48 | Check that tags have a valid IOB format. 49 | Tags in IOB1 format are converted to IOB2. 
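    Returns False if the tag sequence is not valid IOB; otherwise IOB1 tags are rewritten to IOB2 in place and True is returned.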
50 | """ 51 | for i, tag in enumerate(tags): 52 | if tag == 'O': 53 | continue 54 | split = tag.split('-') 55 | if len(split) != 2 or split[0] not in ['I', 'B']: 56 | return False 57 | if split[0] == 'B': 58 | continue 59 | elif i == 0 or tags[i - 1] == 'O': # conversion IOB1 to IOB2 60 | tags[i] = 'B' + tag[1:] 61 | elif tags[i - 1][1:] == tag[1:]: 62 | continue 63 | else: # conversion IOB1 to IOB2 64 | tags[i] = 'B' + tag[1:] 65 | return True 66 | 67 | 68 | def iob_iobes(tags): 69 | """ 70 | IOB -> IOBES 71 | """ 72 | new_tags = [] 73 | for i, tag in enumerate(tags): 74 | if tag == 'O': 75 | new_tags.append(tag) 76 | elif tag.split('-')[0] == 'B': 77 | if i + 1 != len(tags) and \ 78 | tags[i + 1].split('-')[0] == 'I': 79 | new_tags.append(tag) 80 | else: 81 | new_tags.append(tag.replace('B-', 'S-')) 82 | elif tag.split('-')[0] == 'I': 83 | if i + 1 < len(tags) and \ 84 | tags[i + 1].split('-')[0] == 'I': 85 | new_tags.append(tag) 86 | else: 87 | new_tags.append(tag.replace('I-', 'E-')) 88 | else: 89 | raise Exception('Invalid IOB format!') 90 | return new_tags 91 | 92 | 93 | def iobes_iob(tags): 94 | """ 95 | IOBES -> IOB 96 | """ 97 | new_tags = [] 98 | for i, tag in enumerate(tags): 99 | if tag.split('-')[0] == 'B': 100 | new_tags.append(tag) 101 | elif tag.split('-')[0] == 'I': 102 | new_tags.append(tag) 103 | elif tag.split('-')[0] == 'S': 104 | new_tags.append(tag.replace('S-', 'B-')) 105 | elif tag.split('-')[0] == 'E': 106 | new_tags.append(tag.replace('E-', 'I-')) 107 | elif tag.split('-')[0] == 'O': 108 | new_tags.append(tag) 109 | else: 110 | raise Exception('Invalid format!') 111 | return new_tags 112 | 113 | 114 | def insert_singletons(words, singletons, p=0.5): 115 | """ 116 | Replace singletons by the unknown word with a probability p. 117 | """ 118 | new_words = [] 119 | for word in words: 120 | if word in singletons and np.random.uniform() < p: 121 | new_words.append(0) 122 | else: 123 | new_words.append(word) 124 | return new_words 125 | 126 | 127 | def get_seg_features(string): 128 | """ 129 | Segment text with jieba 130 | features are represented in bies format 131 | s donates single word 132 | """ 133 | seg_feature = [] 134 | 135 | for word in jieba.cut(string): 136 | if len(word) == 1: 137 | seg_feature.append(0) #o 138 | else: 139 | tmp = [2] * len(word) #i 140 | tmp[0] = 1 #b 141 | tmp[-1] = 3 #e 142 | seg_feature.extend(tmp) 143 | return seg_feature 144 | 145 | 146 | def create_input(data): 147 | """ 148 | Take sentence data and return an input for 149 | the training or the evaluation function. 
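    Returns [chars, segs, tags] taken from the sentence's data dictionary.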
150 | """ 151 | inputs = list() 152 | inputs.append(data['chars']) 153 | inputs.append(data["segs"]) 154 | inputs.append(data['tags']) 155 | return inputs 156 | 157 | 158 | def load_word2vec(emb_path, id_to_word, word_dim, old_weights): 159 | """ 160 | Load word embedding from pre-trained file 161 | embedding size must match 162 | """ 163 | new_weights = old_weights 164 | print('Loading pretrained embeddings from {}...'.format(emb_path)) 165 | pre_trained = {} 166 | emb_invalid = 0 167 | for i, line in enumerate(codecs.open(emb_path, 'r', 'utf-8')): 168 | line = line.rstrip().split() 169 | if len(line) == word_dim + 1: 170 | pre_trained[line[0]] = np.array( 171 | [float(x) for x in line[1:]] 172 | ).astype(np.float32) 173 | else: 174 | emb_invalid += 1 175 | if emb_invalid > 0: 176 | print('WARNING: %i invalid lines' % emb_invalid) 177 | c_found = 0 178 | c_lower = 0 179 | c_zeros = 0 180 | n_words = len(id_to_word) 181 | # Lookup table initialization 182 | for i in range(n_words): 183 | word = id_to_word[i] 184 | if word in pre_trained: 185 | new_weights[i] = pre_trained[word] 186 | c_found += 1 187 | elif word.lower() in pre_trained: 188 | new_weights[i] = pre_trained[word.lower()] 189 | c_lower += 1 190 | elif re.sub('\d', '0', word.lower()) in pre_trained: #replace numbers to zero 191 | new_weights[i] = pre_trained[ 192 | re.sub('\d', '0', word.lower()) 193 | ] 194 | c_zeros += 1 195 | print('Loaded %i pretrained embeddings.' % len(pre_trained)) 196 | print('%i / %i (%.4f%%) words have been initialized with ' 197 | 'pretrained embeddings.' % ( 198 | c_found + c_lower + c_zeros, n_words, 199 | 100. * (c_found + c_lower + c_zeros) / n_words) 200 | ) 201 | print('%i found directly, %i after lowercasing, ' 202 | '%i after lowercasing + zero.' % ( 203 | c_found, c_lower, c_zeros 204 | )) 205 | return new_weights 206 | 207 | 208 | def full_to_half(s): 209 | """ 210 | Convert full-width character to half-width one 211 | """ 212 | n = [] 213 | for char in s: 214 | num = ord(char) 215 | if num == 0x3000: 216 | num = 32 217 | elif 0xFF01 <= num <= 0xFF5E: 218 | num -= 0xfee0 219 | char = chr(num) 220 | n.append(char) 221 | return ''.join(n) 222 | 223 | 224 | def cut_to_sentence(text): 225 | """ 226 | Cut text to sentences 227 | """ 228 | sentence = [] 229 | sentences = [] 230 | len_p = len(text) 231 | pre_cut = False 232 | for idx, word in enumerate(text): 233 | sentence.append(word) 234 | cut = False 235 | if pre_cut: 236 | cut=True 237 | pre_cut=False 238 | if word in u"。;!?\n": 239 | cut = True 240 | if len_p > idx+1: 241 | if text[idx+1] in ".。”\"\'“”‘’?!": 242 | cut = False 243 | pre_cut=True 244 | 245 | if cut: 246 | sentences.append(sentence) 247 | sentence = [] 248 | if sentence: 249 | sentences.append("".join(list(sentence))) 250 | return sentences 251 | 252 | 253 | def replace_html(s): 254 | s = s.replace('"','"') 255 | s = s.replace('&','&') 256 | s = s.replace('<','<') 257 | s = s.replace('>','>') 258 | s = s.replace(' ',' ') 259 | s = s.replace("“", "“") 260 | s = s.replace("”", "”") 261 | s = s.replace("—","") 262 | s = s.replace("\xa0", " ") 263 | return(s) 264 | 265 | class BatchManager(object): 266 | 267 | def __init__(self, data, batch_size): 268 | self.batch_data = self.sort_and_pad(data, batch_size) 269 | self.len_data = len(self.batch_data) 270 | 271 | def sort_and_pad(self, data, batch_size): 272 | num_batch = int(math.ceil(len(data) /batch_size)) 273 | sorted_data = sorted(data, key=lambda x: len(x[0])) 274 | batch_data = list() 275 | for i in range(num_batch): 276 | 
batch_data.append(self.arrange_batch(sorted_data[int(i*batch_size) : int((i+1)*batch_size)])) 277 | return batch_data 278 | 279 | @staticmethod 280 | def arrange_batch(batch): 281 | ''' 282 | 把batch整理为一个[5, ]的数组 283 | :param batch: 284 | :return: 285 | ''' 286 | strings = [] 287 | segment_ids = [] 288 | chars = [] 289 | mask = [] 290 | targets = [] 291 | for string, seg_ids, char, msk, target in batch: 292 | strings.append(string) 293 | segment_ids.append(seg_ids) 294 | chars.append(char) 295 | mask.append(msk) 296 | targets.append(target) 297 | return [strings, segment_ids, chars, mask, targets] 298 | 299 | @staticmethod 300 | def pad_data(data): 301 | strings = [] 302 | chars = [] 303 | segs = [] 304 | targets = [] 305 | max_length = max([len(sentence[0]) for sentence in data]) 306 | for line in data: 307 | string, segment_ids, char, seg, target = line 308 | padding = [0] * (max_length - len(string)) 309 | strings.append(string + padding) 310 | chars.append(char + padding) 311 | segs.append(seg + padding) 312 | targets.append(target + padding) 313 | return [strings, chars, segs, targets] 314 | 315 | def iter_batch(self, shuffle=False): 316 | if shuffle: 317 | random.shuffle(self.batch_data) 318 | for idx in range(self.len_data): 319 | yield self.batch_data[idx] 320 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | 5 | from bert import tokenization 6 | from utils import convert_single_example 7 | 8 | tokenizer = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt', 9 | do_lower_case=True) 10 | 11 | from data_utils import create_dico, create_mapping, zero_digits 12 | from data_utils import iob2, iob_iobes, get_seg_features 13 | 14 | 15 | def load_sentences(path, lower, zeros): 16 | """ 17 | Load sentences. A line must contain at least a word and its tag. 18 | Sentences are separated by empty lines. 19 | """ 20 | sentences = [] 21 | sentence = [] 22 | num = 0 23 | for line in codecs.open(path, 'r', 'utf8'): 24 | num+=1 25 | line = zero_digits(line.rstrip()) if zeros else line.rstrip() 26 | # print(list(line)) 27 | if not line: 28 | if len(sentence) > 0: 29 | if 'DOCSTART' not in sentence[0][0]: 30 | sentences.append(sentence) 31 | sentence = [] 32 | else: 33 | if line[0] == " ": 34 | line = "$" + line[1:] 35 | word = line.split() 36 | # word[0] = " " 37 | else: 38 | word= line.split() 39 | assert len(word) >= 2, print([word[0]]) 40 | sentence.append(word) 41 | if len(sentence) > 0: 42 | if 'DOCSTART' not in sentence[0][0]: 43 | sentences.append(sentence) 44 | return sentences 45 | 46 | 47 | def update_tag_scheme(sentences, tag_scheme): 48 | """ 49 | Check and update sentences tagging scheme to IOB2. 50 | Only IOB1 and IOB2 schemes are accepted. 51 | """ 52 | for i, s in enumerate(sentences): 53 | tags = [w[-1] for w in s] 54 | # Check that tags are given in the IOB format 55 | if not iob2(tags): 56 | s_str = '\n'.join(' '.join(w) for w in s) 57 | raise Exception('Sentences should be given in IOB format! 
' + 58 | 'Please check sentence %i:\n%s' % (i, s_str)) 59 | if tag_scheme == 'iob': 60 | # If format was IOB1, we convert to IOB2 61 | for word, new_tag in zip(s, tags): 62 | word[-1] = new_tag 63 | elif tag_scheme == 'iobes': 64 | new_tags = iob_iobes(tags) 65 | for word, new_tag in zip(s, new_tags): 66 | word[-1] = new_tag 67 | else: 68 | raise Exception('Unknown tagging scheme!') 69 | 70 | 71 | def char_mapping(sentences, lower): 72 | """ 73 | Create a dictionary and a mapping of words, sorted by frequency. 74 | """ 75 | chars = [[x[0].lower() if lower else x[0] for x in s] for s in sentences] 76 | dico = create_dico(chars) 77 | dico[""] = 10000001 78 | dico[''] = 10000000 79 | 80 | char_to_id, id_to_char = create_mapping(dico) 81 | print("Found %i unique words (%i in total)" % ( 82 | len(dico), sum(len(x) for x in chars) 83 | )) 84 | return dico, char_to_id, id_to_char 85 | 86 | 87 | def tag_mapping(sentences): 88 | """ 89 | Create a dictionary and a mapping of tags, sorted by frequency. 90 | """ 91 | tags = [[char[-1] for char in s] for s in sentences] 92 | 93 | dico = create_dico(tags) 94 | dico['[SEP]'] = len(dico) + 1 95 | dico['[CLS]'] = len(dico) + 2 96 | 97 | tag_to_id, id_to_tag = create_mapping(dico) 98 | print("Found %i unique named entity tags" % len(dico)) 99 | return dico, tag_to_id, id_to_tag 100 | 101 | 102 | def prepare_dataset(sentences, max_seq_length, tag_to_id, lower=False, train=True): 103 | """ 104 | Prepare the dataset. Return a list of lists of dictionaries containing: 105 | - word indexes 106 | - word char indexes 107 | - tag indexes 108 | """ 109 | def f(x): 110 | return x.lower() if lower else x 111 | data = [] 112 | for s in sentences: 113 | string = [w[0].strip() for w in s] 114 | #chars = [char_to_id[f(w) if f(w) in char_to_id else ''] 115 | # for w in string] 116 | char_line = ' '.join(string) #使用空格把汉字拼起来 117 | text = tokenization.convert_to_unicode(char_line) 118 | 119 | if train: 120 | tags = [w[-1] for w in s] 121 | else: 122 | tags = ['O' for _ in string] 123 | 124 | labels = ' '.join(tags) #使用空格把标签拼起来 125 | labels = tokenization.convert_to_unicode(labels) 126 | 127 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text, 128 | tag_to_id=tag_to_id, 129 | max_seq_length=max_seq_length, 130 | tokenizer=tokenizer, 131 | label_line=labels) 132 | data.append([string, segment_ids, ids, mask, label_ids]) 133 | 134 | return data 135 | 136 | 137 | def input_from_line(line, max_seq_length, tag_to_id): 138 | """ 139 | Take sentence data and return an input for 140 | the training or the evaluation function. 
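    Returns [string, segment_ids, ids, mask, label_ids], with each id array reshaped to (1, max_seq_length).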
141 | """ 142 | string = [w[0].strip() for w in line] 143 | # chars = [char_to_id[f(w) if f(w) in char_to_id else ''] 144 | # for w in string] 145 | char_line = ' '.join(string) # 使用空格把汉字拼起来 146 | text = tokenization.convert_to_unicode(char_line) 147 | 148 | tags = ['O' for _ in string] 149 | 150 | labels = ' '.join(tags) # 使用空格把标签拼起来 151 | labels = tokenization.convert_to_unicode(labels) 152 | 153 | ids, mask, segment_ids, label_ids = convert_single_example(char_line=text, 154 | tag_to_id=tag_to_id, 155 | max_seq_length=max_seq_length, 156 | tokenizer=tokenizer, 157 | label_line=labels) 158 | import numpy as np 159 | segment_ids = np.reshape(segment_ids,(1, max_seq_length)) 160 | ids = np.reshape(ids, (1, max_seq_length)) 161 | mask = np.reshape(mask, (1, max_seq_length)) 162 | label_ids = np.reshape(label_ids, (1, max_seq_length)) 163 | return [string, segment_ids, ids, mask, label_ids] 164 | 165 | def augment_with_pretrained(dictionary, ext_emb_path, chars): 166 | """ 167 | Augment the dictionary with words that have a pretrained embedding. 168 | If `words` is None, we add every word that has a pretrained embedding 169 | to the dictionary, otherwise, we only add the words that are given by 170 | `words` (typically the words in the development and test sets.) 171 | """ 172 | print('Loading pretrained embeddings from %s...' % ext_emb_path) 173 | assert os.path.isfile(ext_emb_path) 174 | 175 | # Load pretrained embeddings from file 176 | pretrained = set([ 177 | line.rstrip().split()[0].strip() 178 | for line in codecs.open(ext_emb_path, 'r', 'utf-8') 179 | if len(ext_emb_path) > 0 180 | ]) 181 | 182 | # We either add every word in the pretrained file, 183 | # or only words given in the `words` list to which 184 | # we can assign a pretrained embedding 185 | if chars is None: 186 | for char in pretrained: 187 | if char not in dictionary: 188 | dictionary[char] = 0 189 | else: 190 | for char in chars: 191 | if any(x in pretrained for x in [ 192 | char, 193 | char.lower(), 194 | re.sub('\d', '0', char.lower()) 195 | ]) and char not in dictionary: 196 | dictionary[char] = 0 197 | 198 | word_to_id, id_to_word = create_mapping(dictionary) 199 | return dictionary, word_to_id, id_to_word 200 | 201 | 202 | def save_maps(save_path, *params): 203 | """ 204 | Save mappings and invert mappings 205 | """ 206 | pass 207 | # with codecs.open(save_path, "w", encoding="utf8") as f: 208 | # pickle.dump(params, f) 209 | 210 | 211 | def load_maps(save_path): 212 | """ 213 | Load mappings from the file 214 | """ 215 | pass 216 | # with codecs.open(save_path, "r", encoding="utf8") as f: 217 | # pickle.load(save_path, f) 218 | 219 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # encoding = utf8 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.contrib.crf import crf_log_likelihood 5 | from tensorflow.contrib.crf import viterbi_decode 6 | from tensorflow.contrib.layers.python.layers import initializers 7 | 8 | import rnncell as rnn 9 | from utils import bio_to_json 10 | from bert import modeling 11 | 12 | class Model(object): 13 | def __init__(self, config): 14 | 15 | self.config = config 16 | self.lr = config["lr"] 17 | self.lstm_dim = config["lstm_dim"] 18 | self.num_tags = config["num_tags"] 19 | 20 | self.global_step = tf.Variable(0, trainable=False) 21 | self.best_dev_f1 = tf.Variable(0.0, trainable=False) 22 | self.best_test_f1 = tf.Variable(0.0, 
trainable=False) 23 | self.initializer = initializers.xavier_initializer() 24 | 25 | # add placeholders for the model 26 | self.input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_ids") 27 | self.input_mask = tf.placeholder(dtype=tf.int32, shape=[None, None], name="input_mask") 28 | self.segment_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="segment_ids") 29 | self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Targets") 30 | # dropout keep prob 31 | self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout") 32 | 33 | used = tf.sign(tf.abs(self.input_ids)) 34 | length = tf.reduce_sum(used, reduction_indices=1) 35 | self.lengths = tf.cast(length, tf.int32) 36 | self.batch_size = tf.shape(self.input_ids)[0] 37 | self.num_steps = tf.shape(self.input_ids)[-1] 38 | 39 | # embeddings for chinese character and segmentation representation 40 | embedding = self.bert_embedding() 41 | 42 | # apply dropout before feed to lstm layer 43 | lstm_inputs = tf.nn.dropout(embedding, self.dropout) 44 | 45 | # bi-directional lstm layer 46 | lstm_outputs = self.biLSTM_layer(lstm_inputs, self.lstm_dim, self.lengths) 47 | 48 | # logits for tags 49 | self.logits = self.project_layer(lstm_outputs) 50 | 51 | # loss of the model 52 | self.loss = self.loss_layer(self.logits, self.lengths) 53 | 54 | # bert模型参数初始化的地方 55 | init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt" 56 | # 获取模型中所有的训练参数。 57 | tvars = tf.trainable_variables() 58 | # 加载BERT模型 59 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, 60 | init_checkpoint) 61 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 62 | print("**** Trainable Variables ****") 63 | # 打印加载模型的参数 64 | train_vars = [] 65 | for var in tvars: 66 | init_string = "" 67 | if var.name in initialized_variable_names: 68 | init_string = ", *INIT_FROM_CKPT*" 69 | else: 70 | train_vars.append(var) 71 | print(" name = %s, shape = %s%s", var.name, var.shape, 72 | init_string) 73 | with tf.variable_scope("optimizer"): 74 | optimizer = self.config["optimizer"] 75 | if optimizer == "adam": 76 | self.opt = tf.train.AdamOptimizer(self.lr) 77 | else: 78 | raise KeyError 79 | 80 | grads = tf.gradients(self.loss, train_vars) 81 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 82 | 83 | self.train_op = self.opt.apply_gradients( 84 | zip(grads, train_vars), global_step=self.global_step) 85 | #capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v] 86 | # for g, v in grads_vars if g is not None] 87 | #self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step, ) 88 | 89 | # saver of the model 90 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 91 | 92 | def bert_embedding(self): 93 | # load bert embedding 94 | bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json") # 配置文件地址。 95 | model = modeling.BertModel( 96 | config=bert_config, 97 | is_training=True, 98 | input_ids=self.input_ids, 99 | input_mask=self.input_mask, 100 | token_type_ids=self.segment_ids, 101 | use_one_hot_embeddings=False) 102 | embedding = model.get_sequence_output() 103 | return embedding 104 | 105 | def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths, name=None): 106 | """ 107 | :param lstm_inputs: [batch_size, num_steps, emb_size] 108 | :return: [batch_size, num_steps, 2*lstm_dim] 109 | """ 110 | with tf.variable_scope("char_BiLSTM" if not name else name): 111 | lstm_cell = {} 112 | for 
direction in ["forward", "backward"]: 113 | with tf.variable_scope(direction): 114 | lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell( 115 | lstm_dim, 116 | use_peepholes=True, 117 | initializer=self.initializer, 118 | state_is_tuple=True) 119 | outputs, final_states = tf.nn.bidirectional_dynamic_rnn( 120 | lstm_cell["forward"], 121 | lstm_cell["backward"], 122 | lstm_inputs, 123 | dtype=tf.float32, 124 | sequence_length=lengths) 125 | return tf.concat(outputs, axis=2) 126 | 127 | def project_layer(self, lstm_outputs, name=None): 128 | """ 129 | hidden layer between lstm layer and logits 130 | :param lstm_outputs: [batch_size, num_steps, emb_size] 131 | :return: [batch_size, num_steps, num_tags] 132 | """ 133 | with tf.variable_scope("project" if not name else name): 134 | with tf.variable_scope("hidden"): 135 | W = tf.get_variable("W", shape=[self.lstm_dim*2, self.lstm_dim], 136 | dtype=tf.float32, initializer=self.initializer) 137 | 138 | b = tf.get_variable("b", shape=[self.lstm_dim], dtype=tf.float32, 139 | initializer=tf.zeros_initializer()) 140 | output = tf.reshape(lstm_outputs, shape=[-1, self.lstm_dim*2]) 141 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 142 | 143 | # project to score of tags 144 | with tf.variable_scope("logits"): 145 | W = tf.get_variable("W", shape=[self.lstm_dim, self.num_tags], 146 | dtype=tf.float32, initializer=self.initializer) 147 | 148 | b = tf.get_variable("b", shape=[self.num_tags], dtype=tf.float32, 149 | initializer=tf.zeros_initializer()) 150 | 151 | pred = tf.nn.xw_plus_b(hidden, W, b) 152 | 153 | return tf.reshape(pred, [-1, self.num_steps, self.num_tags]) 154 | 155 | def loss_layer(self, project_logits, lengths, name=None): 156 | """ 157 | calculate crf loss 158 | :param project_logits: [1, num_steps, num_tags] 159 | :return: scalar loss 160 | """ 161 | with tf.variable_scope("crf_loss" if not name else name): 162 | small = -1000.0 163 | # pad logits for crf loss 164 | start_logits = tf.concat( 165 | [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])], axis=-1) 166 | pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32) 167 | logits = tf.concat([project_logits, pad_logits], axis=-1) 168 | logits = tf.concat([start_logits, logits], axis=1) 169 | targets = tf.concat( 170 | [tf.cast(self.num_tags*tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1) 171 | 172 | self.trans = tf.get_variable( 173 | "transitions", 174 | shape=[self.num_tags + 1, self.num_tags + 1], 175 | initializer=self.initializer) 176 | log_likelihood, self.trans = crf_log_likelihood( 177 | inputs=logits, 178 | tag_indices=targets, 179 | transition_params=self.trans, 180 | sequence_lengths=lengths+1) 181 | return tf.reduce_mean(-log_likelihood) 182 | 183 | def create_feed_dict(self, is_train, batch): 184 | """ 185 | :param is_train: Flag, True for train batch 186 | :param batch: list train/evaluate data 187 | :return: structured data to feed 188 | """ 189 | _, segment_ids, chars, mask, tags = batch 190 | feed_dict = { 191 | self.input_ids: np.asarray(chars), 192 | self.input_mask: np.asarray(mask), 193 | self.segment_ids: np.asarray(segment_ids), 194 | self.dropout: 1.0, 195 | } 196 | if is_train: 197 | feed_dict[self.targets] = np.asarray(tags) 198 | feed_dict[self.dropout] = self.config["dropout_keep"] 199 | return feed_dict 200 | 201 | def run_step(self, sess, is_train, batch): 202 | """ 203 | :param sess: session to run the batch 204 | :param is_train: a flag indicate if 
it is a train batch 205 | :param batch: a dict containing batch data 206 | :return: batch result, loss of the batch or logits 207 | """ 208 | feed_dict = self.create_feed_dict(is_train, batch) 209 | if is_train: 210 | global_step, loss, _ = sess.run( 211 | [self.global_step, self.loss, self.train_op], 212 | feed_dict) 213 | return global_step, loss 214 | else: 215 | lengths, logits = sess.run([self.lengths, self.logits], feed_dict) 216 | return lengths, logits 217 | 218 | def decode(self, logits, lengths, matrix): 219 | """ 220 | :param logits: [batch_size, num_steps, num_tags]float32, logits 221 | :param lengths: [batch_size]int32, real length of each sequence 222 | :param matrix: transaction matrix for inference 223 | :return: 224 | """ 225 | # inference final labels usa viterbi Algorithm 226 | paths = [] 227 | small = -1000.0 228 | start = np.asarray([[small]*self.num_tags +[0]]) 229 | for score, length in zip(logits, lengths): 230 | score = score[:length] 231 | pad = small * np.ones([length, 1]) 232 | logits = np.concatenate([score, pad], axis=1) 233 | logits = np.concatenate([start, logits], axis=0) 234 | path, _ = viterbi_decode(logits, matrix) 235 | 236 | paths.append(path[1:]) 237 | return paths 238 | 239 | def evaluate(self, sess, data_manager, id_to_tag): 240 | """ 241 | :param sess: session to run the model 242 | :param data: list of data 243 | :param id_to_tag: index to tag name 244 | :return: evaluate result 245 | """ 246 | results = [] 247 | trans = self.trans.eval() 248 | for batch in data_manager.iter_batch(): 249 | strings = batch[0] 250 | labels = batch[-1] 251 | lengths, scores = self.run_step(sess, False, batch) 252 | batch_paths = self.decode(scores, lengths, trans) 253 | for i in range(len(strings)): 254 | result = [] 255 | string = strings[i][:lengths[i]] 256 | gold = [id_to_tag[int(x)] for x in labels[i][1:lengths[i]]] 257 | pred = [id_to_tag[int(x)] for x in batch_paths[i][1:lengths[i]]] 258 | for char, gold, pred in zip(string, gold, pred): 259 | result.append(" ".join([char, gold, pred])) 260 | results.append(result) 261 | return results 262 | 263 | def evaluate_line(self, sess, inputs, id_to_tag): 264 | trans = self.trans.eval(sess) 265 | lengths, scores = self.run_step(sess, False, inputs) 266 | batch_paths = self.decode(scores, lengths, trans) 267 | tags = [id_to_tag[idx] for idx in batch_paths[0]] 268 | return bio_to_json(inputs[0], tags[1:-1]) -------------------------------------------------------------------------------- /pictures/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumath/bertNER/ee044eb1333b532bed2147610abb4d600b1cd8cf/pictures/results.png -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import tensorflow as tf 4 | from utils import create_model, get_logger 5 | from model import Model 6 | from loader import input_from_line 7 | from train import FLAGS, load_config 8 | 9 | def main(_): 10 | config = load_config(FLAGS.config_file) 11 | logger = get_logger(FLAGS.log_file) 12 | # limit GPU memory 13 | tf_config = tf.ConfigProto() 14 | tf_config.gpu_options.allow_growth = True 15 | with open(FLAGS.map_file, "rb") as f: 16 | tag_to_id, id_to_tag = pickle.load(f) 17 | with tf.Session(config=tf_config) as sess: 18 | model = create_model(sess, Model, FLAGS.ckpt_path, config, logger) 19 | while True: 20 | line = 
input("input sentence, please:") 21 | result = model.evaluate_line(sess, input_from_line(line, FLAGS.max_seq_len, tag_to_id), id_to_tag) 22 | print(result['entities']) 23 | 24 | if __name__ == '__main__': 25 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 26 | tf.app.run(main) -------------------------------------------------------------------------------- /rnncell.py: -------------------------------------------------------------------------------- 1 | """Module for constructing RNN Cells.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import collections 7 | import math 8 | import tensorflow as tf 9 | from tensorflow.contrib.compiler import jit 10 | from tensorflow.contrib.layers.python.layers import layers 11 | from tensorflow.python.framework import dtypes 12 | from tensorflow.python.framework import op_def_registry 13 | from tensorflow.python.framework import ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.ops import clip_ops 16 | from tensorflow.python.ops import init_ops 17 | from tensorflow.python.ops import math_ops 18 | from tensorflow.python.ops import nn_ops 19 | from tensorflow.python.ops import random_ops 20 | from tensorflow.python.ops import rnn_cell_impl 21 | from tensorflow.python.ops import variable_scope as vs 22 | from tensorflow.python.platform import tf_logging as logging 23 | from tensorflow.python.util import nest 24 | 25 | 26 | def _get_concat_variable(name, shape, dtype, num_shards): 27 | """Get a sharded variable concatenated into one tensor.""" 28 | sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) 29 | if len(sharded_variable) == 1: 30 | return sharded_variable[0] 31 | 32 | concat_name = name + "/concat" 33 | concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" 34 | for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): 35 | if value.name == concat_full_name: 36 | return value 37 | 38 | concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) 39 | ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, 40 | concat_variable) 41 | return concat_variable 42 | 43 | 44 | def _get_sharded_variable(name, shape, dtype, num_shards): 45 | """Get a list of sharded variables with the given dtype.""" 46 | if num_shards > shape[0]: 47 | raise ValueError("Too many shards: shape=%s, num_shards=%d" % 48 | (shape, num_shards)) 49 | unit_shard_size = int(math.floor(shape[0] / num_shards)) 50 | remaining_rows = shape[0] - unit_shard_size * num_shards 51 | 52 | shards = [] 53 | for i in range(num_shards): 54 | current_size = unit_shard_size 55 | if i < remaining_rows: 56 | current_size += 1 57 | shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:], 58 | dtype=dtype)) 59 | return shards 60 | 61 | 62 | class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): 63 | """Long short-term memory unit (LSTM) recurrent network cell. 64 | 65 | The default non-peephole implementation is based on: 66 | 67 | http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf 68 | 69 | S. Hochreiter and J. Schmidhuber. 70 | "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 71 | 72 | The peephole implementation is based on: 73 | 74 | https://research.google.com/pubs/archive/43905.pdf 75 | 76 | Hasim Sak, Andrew Senior, and Francoise Beaufays. 77 | "Long short-term memory recurrent neural network architectures for 78 | large scale acoustic modeling." INTERSPEECH, 2014. 
79 | 80 | The coupling of input and forget gate is based on: 81 | 82 | http://arxiv.org/pdf/1503.04069.pdf 83 | 84 | Greff et al. "LSTM: A Search Space Odyssey" 85 | 86 | The class uses optional peep-hole connections, and an optional projection 87 | layer. 88 | """ 89 | 90 | def __init__(self, num_units, use_peepholes=False, 91 | initializer=None, num_proj=None, proj_clip=None, 92 | num_unit_shards=1, num_proj_shards=1, 93 | forget_bias=1.0, state_is_tuple=True, 94 | activation=math_ops.tanh, reuse=None): 95 | """Initialize the parameters for an LSTM cell. 96 | 97 | Args: 98 | num_units: int, The number of units in the LSTM cell 99 | use_peepholes: bool, set True to enable diagonal/peephole connections. 100 | initializer: (optional) The initializer to use for the weight and 101 | projection matrices. 102 | num_proj: (optional) int, The output dimensionality for the projection 103 | matrices. If None, no projection is performed. 104 | proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 105 | provided, then the projected values are clipped elementwise to within 106 | `[-proj_clip, proj_clip]`. 107 | num_unit_shards: How to split the weight matrix. If >1, the weight 108 | matrix is stored across num_unit_shards. 109 | num_proj_shards: How to split the projection matrix. If >1, the 110 | projection matrix is stored across num_proj_shards. 111 | forget_bias: Biases of the forget gate are initialized by default to 1 112 | in order to reduce the scale of forgetting at the beginning of 113 | the training. 114 | state_is_tuple: If True, accepted and returned states are 2-tuples of 115 | the `c_state` and `m_state`. By default (False), they are concatenated 116 | along the column axis. This default behavior will soon be deprecated. 117 | activation: Activation function of the inner states. 118 | reuse: (optional) Python boolean describing whether to reuse variables 119 | in an existing scope. If not `True`, and the existing scope already has 120 | the given variables, an error is raised. 121 | """ 122 | super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) 123 | if not state_is_tuple: 124 | logging.warn( 125 | "%s: Using a concatenated state is slower and will soon be " 126 | "deprecated. Use state_is_tuple=True.", self) 127 | self._num_units = num_units 128 | self._use_peepholes = use_peepholes 129 | self._initializer = initializer 130 | self._num_proj = num_proj 131 | self._proj_clip = proj_clip 132 | self._num_unit_shards = num_unit_shards 133 | self._num_proj_shards = num_proj_shards 134 | self._forget_bias = forget_bias 135 | self._state_is_tuple = state_is_tuple 136 | self._activation = activation 137 | self._reuse = reuse 138 | 139 | if num_proj: 140 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 141 | if state_is_tuple else num_units + num_proj) 142 | self._output_size = num_proj 143 | else: 144 | self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units) 145 | if state_is_tuple else 2 * num_units) 146 | self._output_size = num_units 147 | 148 | @property 149 | def state_size(self): 150 | return self._state_size 151 | 152 | @property 153 | def output_size(self): 154 | return self._output_size 155 | 156 | def call(self, inputs, state): 157 | """Run one step of LSTM. 158 | 159 | Args: 160 | inputs: input Tensor, 2D, batch x num_units. 161 | state: if `state_is_tuple` is False, this must be a state Tensor, 162 | `2-D, batch x state_size`. 
If `state_is_tuple` is True, this must be a 163 | tuple of state Tensors, both `2-D`, with column sizes `c_state` and 164 | `m_state`. 165 | scope: VariableScope for the created subgraph; defaults to "LSTMCell". 166 | 167 | Returns: 168 | A tuple containing: 169 | - A `2-D, [batch x output_dim]`, Tensor representing the output of the 170 | LSTM after reading `inputs` when previous state was `state`. 171 | Here output_dim is: 172 | num_proj if num_proj was set, 173 | num_units otherwise. 174 | - Tensor(s) representing the new state of LSTM after reading `inputs` when 175 | the previous state was `state`. Same type and shape(s) as `state`. 176 | 177 | Raises: 178 | ValueError: If input size cannot be inferred from inputs via 179 | static shape inference. 180 | """ 181 | sigmoid = math_ops.sigmoid 182 | 183 | num_proj = self._num_units if self._num_proj is None else self._num_proj 184 | 185 | if self._state_is_tuple: 186 | (c_prev, m_prev) = state 187 | else: 188 | c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 189 | m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 190 | 191 | dtype = inputs.dtype 192 | input_size = inputs.get_shape().with_rank(2)[1] 193 | 194 | if input_size.value is None: 195 | raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 196 | 197 | # Input gate weights 198 | self.w_xi = tf.get_variable("_w_xi", [input_size.value, self._num_units]) 199 | self.w_hi = tf.get_variable("_w_hi", [self._num_units, self._num_units]) 200 | self.w_ci = tf.get_variable("_w_ci", [self._num_units, self._num_units]) 201 | # Output gate weights 202 | self.w_xo = tf.get_variable("_w_xo", [input_size.value, self._num_units]) 203 | self.w_ho = tf.get_variable("_w_ho", [self._num_units, self._num_units]) 204 | self.w_co = tf.get_variable("_w_co", [self._num_units, self._num_units]) 205 | 206 | # Cell weights 207 | self.w_xc = tf.get_variable("_w_xc", [input_size.value, self._num_units]) 208 | self.w_hc = tf.get_variable("_w_hc", [self._num_units, self._num_units]) 209 | 210 | # Initialize the bias vectors 211 | self.b_i = tf.get_variable("_b_i", [self._num_units], initializer=init_ops.zeros_initializer()) 212 | self.b_c = tf.get_variable("_b_c", [self._num_units], initializer=init_ops.zeros_initializer()) 213 | self.b_o = tf.get_variable("_b_o", [self._num_units], initializer=init_ops.zeros_initializer()) 214 | 215 | i_t = sigmoid(math_ops.matmul(inputs, self.w_xi) + 216 | math_ops.matmul(m_prev, self.w_hi) + 217 | math_ops.matmul(c_prev, self.w_ci) + 218 | self.b_i) 219 | c_t = ((1 - i_t) * c_prev + i_t * self._activation(math_ops.matmul(inputs, self.w_xc) + 220 | math_ops.matmul(m_prev, self.w_hc) + self.b_c)) 221 | 222 | o_t = sigmoid(math_ops.matmul(inputs, self.w_xo) + 223 | math_ops.matmul(m_prev, self.w_ho) + 224 | math_ops.matmul(c_t, self.w_co) + 225 | self.b_o) 226 | 227 | h_t = o_t * self._activation(c_t) 228 | 229 | new_state = (rnn_cell_impl.LSTMStateTuple(c_t, h_t) if self._state_is_tuple else 230 | array_ops.concat([c_t, h_t], 1)) 231 | return h_t, new_state -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # encoding=utf8 2 | import os 3 | import codecs 4 | import pickle 5 | import itertools 6 | from collections import OrderedDict 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from model import Model 11 | from loader import load_sentences, update_tag_scheme 12 | from loader import 
char_mapping, tag_mapping 13 | from loader import augment_with_pretrained, prepare_dataset 14 | from utils import get_logger, make_path, clean, create_model, save_model 15 | from utils import print_config, save_config, load_config, test_ner 16 | from data_utils import create_input, BatchManager 17 | 18 | flags = tf.app.flags 19 | flags.DEFINE_boolean("clean", False, "clean train folder") 20 | flags.DEFINE_boolean("train", False, "Whether to train the model") 21 | # configurations for the model 22 | flags.DEFINE_integer("batch_size", 128, "batch size") 23 | flags.DEFINE_integer("seg_dim", 20, "Embedding size for segmentation, 0 if not used") 24 | flags.DEFINE_integer("char_dim", 100, "Embedding size for characters") 25 | flags.DEFINE_integer("lstm_dim", 200, "Num of hidden units in LSTM") 26 | flags.DEFINE_string("tag_schema", "iob", "tagging schema iobes or iob") 27 | 28 | # configurations for training 29 | flags.DEFINE_float("clip", 5, "Gradient clip") 30 | flags.DEFINE_float("dropout", 0.5, "Dropout rate") 31 | flags.DEFINE_float("lr", 0.001, "Initial learning rate") 32 | flags.DEFINE_string("optimizer", "adam", "Optimizer for training") 33 | flags.DEFINE_boolean("zeros", False, "Whether to replace digits with zeros") 34 | flags.DEFINE_boolean("lower", True, "Whether to lowercase") 35 | 36 | flags.DEFINE_integer("max_seq_len", 128, "max sequence length for bert") 37 | flags.DEFINE_integer("max_epoch", 100, "maximum training epochs") 38 | flags.DEFINE_integer("steps_check", 100, "steps per checkpoint") 39 | flags.DEFINE_string("ckpt_path", "ckpt", "Path to save model") 40 | flags.DEFINE_string("summary_path", "summary", "Path to store summaries") 41 | flags.DEFINE_string("log_file", "train.log", "File for log") 42 | flags.DEFINE_string("map_file", "maps.pkl", "File for maps") 43 | flags.DEFINE_string("vocab_file", "vocab.json", "File for vocab") 44 | flags.DEFINE_string("config_file", "config_file", "File for config") 45 | flags.DEFINE_string("script", "conlleval", "evaluation script") 46 | flags.DEFINE_string("result_path", "result", "Path for results") 47 | flags.DEFINE_string("train_file", os.path.join("data", "example.train"), "Path for train data") 48 | flags.DEFINE_string("dev_file", os.path.join("data", "example.dev"), "Path for dev data") 49 | flags.DEFINE_string("test_file", os.path.join("data", "example.test"), "Path for test data") 50 | 51 | 52 | FLAGS = tf.app.flags.FLAGS 53 | assert FLAGS.clip < 5.1, "gradient clip shouldn't be too large" 54 | assert 0 <= FLAGS.dropout < 1, "dropout rate must be between 0 and 1" 55 | assert FLAGS.lr > 0, "learning rate must be larger than zero" 56 | assert FLAGS.optimizer in ["adam", "sgd", "adagrad"] 57 | 58 | 59 | # config for the model 60 | def config_model(tag_to_id): 61 | config = OrderedDict() 62 | config["num_tags"] = len(tag_to_id) 63 | config["lstm_dim"] = FLAGS.lstm_dim 64 | config["batch_size"] = FLAGS.batch_size 65 | config['max_seq_len'] = FLAGS.max_seq_len 66 | 67 | config["clip"] = FLAGS.clip 68 | config["dropout_keep"] = 1.0 - FLAGS.dropout 69 | config["optimizer"] = FLAGS.optimizer 70 | config["lr"] = FLAGS.lr 71 | config["tag_schema"] = FLAGS.tag_schema 72 | config["zeros"] = FLAGS.zeros 73 | config["lower"] = FLAGS.lower 74 | return config 75 | 76 | def evaluate(sess, model, name, data, id_to_tag, logger): 77 | logger.info("evaluate:{}".format(name)) 78 | ner_results = model.evaluate(sess, data, id_to_tag) 79 | eval_lines = test_ner(ner_results, FLAGS.result_path) 80 | for line in eval_lines: 81 | logger.info(line) 82 | f1 = 
float(eval_lines[1].strip().split()[-1]) 83 | 84 | if name == "dev": 85 | best_test_f1 = model.best_dev_f1.eval() 86 | if f1 > best_test_f1: 87 | tf.assign(model.best_dev_f1, f1).eval() 88 | logger.info("new best dev f1 score:{:>.3f}".format(f1)) 89 | return f1 > best_test_f1 90 | elif name == "test": 91 | best_test_f1 = model.best_test_f1.eval() 92 | if f1 > best_test_f1: 93 | tf.assign(model.best_test_f1, f1).eval() 94 | logger.info("new best test f1 score:{:>.3f}".format(f1)) 95 | return f1 > best_test_f1 96 | 97 | def train(): 98 | # load data sets 99 | train_sentences = load_sentences(FLAGS.train_file, FLAGS.lower, FLAGS.zeros) 100 | dev_sentences = load_sentences(FLAGS.dev_file, FLAGS.lower, FLAGS.zeros) 101 | test_sentences = load_sentences(FLAGS.test_file, FLAGS.lower, FLAGS.zeros) 102 | 103 | # Use selected tagging scheme (IOB / IOBES) 104 | #update_tag_scheme(train_sentences, FLAGS.tag_schema) 105 | #update_tag_scheme(test_sentences, FLAGS.tag_schema) 106 | 107 | # create maps if not exist 108 | if not os.path.isfile(FLAGS.map_file): 109 | # Create a dictionary and a mapping for tags 110 | _t, tag_to_id, id_to_tag = tag_mapping(train_sentences) 111 | with open(FLAGS.map_file, "wb") as f: 112 | pickle.dump([tag_to_id, id_to_tag], f) 113 | else: 114 | with open(FLAGS.map_file, "rb") as f: 115 | tag_to_id, id_to_tag = pickle.load(f) 116 | 117 | # prepare data, get a collection of list containing index 118 | train_data = prepare_dataset( 119 | train_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower 120 | ) 121 | dev_data = prepare_dataset( 122 | dev_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower 123 | ) 124 | test_data = prepare_dataset( 125 | test_sentences, FLAGS.max_seq_len, tag_to_id, FLAGS.lower 126 | ) 127 | print("%i / %i / %i sentences in train / dev / test." 
% ( 128 | len(train_data), 0, len(test_data))) 129 | 130 | train_manager = BatchManager(train_data, FLAGS.batch_size) 131 | dev_manager = BatchManager(dev_data, FLAGS.batch_size) 132 | test_manager = BatchManager(test_data, FLAGS.batch_size) 133 | # make path for store log and model if not exist 134 | make_path(FLAGS) 135 | if os.path.isfile(FLAGS.config_file): 136 | config = load_config(FLAGS.config_file) 137 | else: 138 | config = config_model(tag_to_id) 139 | save_config(config, FLAGS.config_file) 140 | make_path(FLAGS) 141 | 142 | log_path = os.path.join("log", FLAGS.log_file) 143 | logger = get_logger(log_path) 144 | print_config(config, logger) 145 | 146 | # limit GPU memory 147 | tf_config = tf.ConfigProto() 148 | tf_config.gpu_options.allow_growth = True 149 | steps_per_epoch = train_manager.len_data 150 | with tf.Session(config=tf_config) as sess: 151 | model = create_model(sess, Model, FLAGS.ckpt_path, config, logger) 152 | 153 | logger.info("start training") 154 | loss = [] 155 | for i in range(100): 156 | for batch in train_manager.iter_batch(shuffle=True): 157 | step, batch_loss = model.run_step(sess, True, batch) 158 | 159 | loss.append(batch_loss) 160 | if step % FLAGS.steps_check == 0: 161 | iteration = step // steps_per_epoch + 1 162 | logger.info("iteration:{} step:{}/{}, " 163 | "NER loss:{:>9.6f}".format( 164 | iteration, step%steps_per_epoch, steps_per_epoch, np.mean(loss))) 165 | loss = [] 166 | 167 | best = evaluate(sess, model, "dev", dev_manager, id_to_tag, logger) 168 | if best: 169 | save_model(sess, model, FLAGS.ckpt_path, logger, global_steps=step) 170 | evaluate(sess, model, "test", test_manager, id_to_tag, logger) 171 | 172 | def main(_): 173 | FLAGS.train = True 174 | FLAGS.clean = True 175 | clean(FLAGS) 176 | train() 177 | 178 | if __name__ == "__main__": 179 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 180 | tf.app.run(main) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import shutil 5 | import logging 6 | import codecs 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from conlleval import return_report 11 | 12 | from bert import modeling 13 | 14 | models_path = "./models" 15 | eval_path = "./evaluation" 16 | eval_temp = os.path.join(eval_path, "temp") 17 | eval_script = os.path.join(eval_path, "conlleval") 18 | 19 | 20 | def get_logger(log_file): 21 | logger = logging.getLogger(log_file) 22 | logger.setLevel(logging.DEBUG) 23 | fh = logging.FileHandler(log_file) 24 | fh.setLevel(logging.DEBUG) 25 | ch = logging.StreamHandler() 26 | ch.setLevel(logging.INFO) 27 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 28 | ch.setFormatter(formatter) 29 | fh.setFormatter(formatter) 30 | logger.addHandler(ch) 31 | logger.addHandler(fh) 32 | return logger 33 | 34 | 35 | # def test_ner(results, path): 36 | # """ 37 | # Run perl script to evaluate model 38 | # """ 39 | # script_file = "conlleval" 40 | # output_file = os.path.join(path, "ner_predict.utf8") 41 | # result_file = os.path.join(path, "ner_result.utf8") 42 | # with open(output_file, "w") as f: 43 | # to_write = [] 44 | # for block in results: 45 | # for line in block: 46 | # to_write.append(line + "\n") 47 | # to_write.append("\n") 48 | # 49 | # f.writelines(to_write) 50 | # os.system("perl {} < {} > {}".format(script_file, output_file, result_file)) 51 | # eval_lines = [] 52 
| # with open(result_file) as f: 53 | # for line in f: 54 | # eval_lines.append(line.strip()) 55 | # return eval_lines 56 | 57 | 58 | def test_ner(results, path): 59 | """ 60 | Run perl script to evaluate model 61 | """ 62 | output_file = os.path.join(path, "ner_predict.utf8") 63 | with codecs.open(output_file, "w", 'utf8') as f: 64 | to_write = [] 65 | for block in results: 66 | for line in block: 67 | to_write.append(line + "\n") 68 | to_write.append("\n") 69 | 70 | f.writelines(to_write) 71 | eval_lines = return_report(output_file) 72 | return eval_lines 73 | 74 | 75 | def print_config(config, logger): 76 | """ 77 | Print configuration of the model 78 | """ 79 | for k, v in config.items(): 80 | logger.info("{}:\t{}".format(k.ljust(15), v)) 81 | 82 | 83 | def make_path(params): 84 | """ 85 | Make folders for training and evaluation 86 | """ 87 | if not os.path.isdir(params.result_path): 88 | os.makedirs(params.result_path) 89 | if not os.path.isdir(params.ckpt_path): 90 | os.makedirs(params.ckpt_path) 91 | if not os.path.isdir("log"): 92 | os.makedirs("log") 93 | 94 | 95 | def clean(params): 96 | """ 97 | Clean current folder 98 | remove saved model and training log 99 | """ 100 | if os.path.isfile(params.vocab_file): 101 | os.remove(params.vocab_file) 102 | 103 | if os.path.isfile(params.map_file): 104 | os.remove(params.map_file) 105 | 106 | if os.path.isdir(params.ckpt_path): 107 | shutil.rmtree(params.ckpt_path) 108 | 109 | if os.path.isdir(params.summary_path): 110 | shutil.rmtree(params.summary_path) 111 | 112 | if os.path.isdir(params.result_path): 113 | shutil.rmtree(params.result_path) 114 | 115 | if os.path.isdir("log"): 116 | shutil.rmtree("log") 117 | 118 | if os.path.isdir("__pycache__"): 119 | shutil.rmtree("__pycache__") 120 | 121 | if os.path.isfile(params.config_file): 122 | os.remove(params.config_file) 123 | 124 | if os.path.isfile(params.vocab_file): 125 | os.remove(params.vocab_file) 126 | 127 | 128 | def save_config(config, config_file): 129 | """ 130 | Save configuration of the model 131 | parameters are stored in json format 132 | """ 133 | with open(config_file, "w", encoding="utf8") as f: 134 | json.dump(config, f, ensure_ascii=False, indent=4) 135 | 136 | 137 | def load_config(config_file): 138 | """ 139 | Load configuration of the model 140 | parameters are stored in json format 141 | """ 142 | with open(config_file, encoding="utf8") as f: 143 | return json.load(f) 144 | 145 | 146 | def convert_to_text(line): 147 | """ 148 | Convert conll data to text 149 | """ 150 | to_print = [] 151 | for item in line: 152 | 153 | try: 154 | if item[0] == " ": 155 | to_print.append(" ") 156 | continue 157 | word, gold, tag = item.split(" ") 158 | if tag[0] in "SB": 159 | to_print.append("[") 160 | to_print.append(word) 161 | if tag[0] in "SE": 162 | to_print.append("@" + tag.split("-")[-1]) 163 | to_print.append("]") 164 | except: 165 | print(list(item)) 166 | return "".join(to_print) 167 | 168 | 169 | def save_model(sess, model, path, logger, global_steps): 170 | checkpoint_path = os.path.join(path, "ner.ckpt") 171 | model.saver.save(sess, checkpoint_path, global_step = global_steps) 172 | logger.info("model saved") 173 | 174 | 175 | def create_model(session, Model_class, path, config, logger): 176 | # create model, reuse parameters if exists 177 | model = Model_class(config) 178 | 179 | ckpt = tf.train.get_checkpoint_state(path) 180 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 181 | logger.info("Reading model parameters from %s" % 
ckpt.model_checkpoint_path) 182 | #saver = tf.train.import_meta_graph('ckpt/ner.ckpt.meta') 183 | #saver.restore(session, tf.train.latest_checkpoint("ckpt/")) 184 | model.saver.restore(session, ckpt.model_checkpoint_path) 185 | else: 186 | logger.info("Created model with fresh parameters.") 187 | session.run(tf.global_variables_initializer()) 188 | return model 189 | 190 | 191 | def result_to_json(string, tags): 192 | item = {"string": string, "entities": []} 193 | entity_name = "" 194 | entity_start = 0 195 | idx = 0 196 | for char, tag in zip(string, tags): 197 | if tag[0] == "S": 198 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]}) 199 | elif tag[0] == "B": 200 | entity_name += char 201 | entity_start = idx 202 | elif tag[0] == "I": 203 | entity_name += char 204 | elif tag[0] == "E": 205 | entity_name += char 206 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx + 1, "type": tag[2:]}) 207 | entity_name = "" 208 | else: 209 | entity_name = "" 210 | entity_start = idx 211 | idx += 1 212 | return item 213 | 214 | def bio_to_json(string, tags): 215 | item = {"string": string, "entities": []} 216 | entity_name = "" 217 | entity_start = 0 218 | iCount = 0 219 | entity_tag = "" 220 | #assert len(string)==len(tags), "string length is: {}, tags length is: {}".format(len(string), len(tags)) 221 | 222 | for c_idx in range(len(tags)): 223 | c, tag = string[c_idx], tags[c_idx] 224 | if c_idx < len(tags)-1: 225 | tag_next = tags[c_idx+1] 226 | else: 227 | tag_next = '' 228 | 229 | if tag[0] == 'B': 230 | entity_tag = tag[2:] 231 | entity_name = c 232 | entity_start = iCount 233 | if tag_next[2:] != entity_tag: 234 | item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]}) 235 | elif tag[0] == "I": 236 | if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O': 237 | tags[c_idx] = 'O' 238 | pass 239 | else: 240 | entity_name = entity_name + c 241 | if tag_next[2:] != entity_tag: 242 | item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag}) 243 | entity_name = '' 244 | iCount += 1 245 | return item 246 | 247 | def convert_single_example(char_line, tag_to_id, max_seq_length, tokenizer, label_line): 248 | """ 249 | 将一个样本进行分析,然后将字转化为id, 标签转化为lb 250 | """ 251 | text_list = char_line.split(' ') 252 | label_list = label_line.split(' ') 253 | 254 | tokens = [] 255 | labels = [] 256 | for i, word in enumerate(text_list): 257 | token = tokenizer.tokenize(word) 258 | tokens.extend(token) 259 | label_1 = label_list[i] 260 | for m in range(len(token)): 261 | if m == 0: 262 | labels.append(label_1) 263 | else: 264 | labels.append("X") 265 | # 序列截断 266 | if len(tokens) >= max_seq_length - 1: 267 | tokens = tokens[0:(max_seq_length - 2)] 268 | labels = labels[0:(max_seq_length - 2)] 269 | ntokens = [] 270 | segment_ids = [] 271 | label_ids = [] 272 | ntokens.append("[CLS]") 273 | segment_ids.append(0) 274 | # append("O") or append("[CLS]") not sure! 275 | label_ids.append(tag_to_id["[CLS]"]) 276 | for i, token in enumerate(tokens): 277 | ntokens.append(token) 278 | segment_ids.append(0) 279 | label_ids.append(tag_to_id[labels[i]]) 280 | ntokens.append("[SEP]") 281 | segment_ids.append(0) 282 | # append("O") or append("[SEP]") not sure! 
283 | label_ids.append(tag_to_id["[SEP]"]) 284 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 285 | input_mask = [1] * len(input_ids) 286 | 287 | # padding 288 | while len(input_ids) < max_seq_length: 289 | input_ids.append(0) 290 | input_mask.append(0) 291 | segment_ids.append(0) 292 | # labels at padded positions are not used, so any id works here 293 | label_ids.append(0) 294 | ntokens.append("**NULL**") 295 | 296 | return input_ids, input_mask, segment_ids, label_ids --------------------------------------------------------------------------------
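Usage note: `loader.input_from_line`, which `predict.py` relies on, is not included in this dump. The snippet below is a minimal, illustrative sketch of how `utils.convert_single_example` could be driven directly for a single raw sentence. It assumes the official BERT repo's `FullTokenizer`, the downloaded `chinese_L-12_H-768_A-12` vocab, and the `maps.pkl` written by `train.py`; the placeholder `"O"` labels stand in for gold tags at inference time. It is not the project's actual `input_from_line` implementation.

```python
# Illustrative sketch only (not the project's loader.input_from_line):
# turn one raw sentence into the padded id sequences the model expects.
import pickle

from bert import tokenization          # from the official google-research/bert repo
from utils import convert_single_example

# Assumes the pretrained Chinese BERT vocab and the maps.pkl produced by train.py.
tokenizer = tokenization.FullTokenizer(
    vocab_file="chinese_L-12_H-768_A-12/vocab.txt", do_lower_case=True)
with open("maps.pkl", "rb") as f:
    tag_to_id, id_to_tag = pickle.load(f)

sentence = "北京欢迎你"
char_line = " ".join(list(sentence))            # one character per token
label_line = " ".join(["O"] * len(sentence))    # placeholder labels at inference time

input_ids, input_mask, segment_ids, label_ids = convert_single_example(
    char_line, tag_to_id, max_seq_length=128,
    tokenizer=tokenizer, label_line=label_line)

# All four lists are padded to max_seq_length with [CLS]/[SEP] already inserted;
# model.create_feed_dict feeds them as input_ids, input_mask, segment_ids and targets.
print(len(input_ids), len(input_mask), len(segment_ids), len(label_ids))
```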