├── README.md
├── bert_base
    ├── __init__.py
    ├── bert
    │   ├── CONTRIBUTING.md
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── create_pretraining_data.py
    │   ├── extract_features.py
    │   ├── modeling.py
    │   ├── modeling_test.py
    │   ├── multilingual.md
    │   ├── optimization.py
    │   ├── optimization_test.py
    │   ├── requirements.txt
    │   ├── run_classifier.py
    │   ├── run_pretraining.py
    │   ├── run_squad.py
    │   ├── sample_text.txt
    │   ├── tokenization.py
    │   └── tokenization_test.py
    ├── client
    │   └── __init__.py
    ├── runs
    │   └── __init__.py
    ├── server
    │   ├── __init__.py
    │   ├── graph.py
    │   ├── helper.py
    │   ├── http.py
    │   ├── simple_flask_http_service.py
    │   └── zmq_decor.py
    └── train
    │   ├── __init__.py
    │   ├── bert_lstm_ner.py
    │   ├── conlleval.pl
    │   ├── conlleval.py
    │   ├── lstm_crf_layer.py
    │   ├── models.py
    │   ├── tf_metrics.py
    │   └── train_helper.py
├── build.sh
├── client_test.py
├── data_process.py
├── pictures
    ├── 03E18A6A9C16082CF22A9E8837F7E35F.png
    ├── ner_help.png
    ├── picture1.png
    ├── picture2.png
    ├── predict.png
    ├── server_help.png
    ├── server_ner_rst.png
    ├── server_run.png
    ├── service_1.png
    ├── service_2.png
    └── text_class_rst.png
├── requirement.txt
├── run.py
├── setup.py
├── terminal_predict.py
└── thu_classification.py

/README.md:
--------------------------------------------------------------------------------
# BERT-BiLSTM-CRF-NER
TensorFlow solution for named entity recognition (NER), primarily in Chinese, using a BiLSTM-CRF model on top of fine-tuned Google BERT.

For the Chinese documentation, see https://blog.csdn.net/macanv/article/details/85684284. If this project helps you, please give it a star. Thanks!

The Chinese training data ($PATH/NERdata/) comes from: https://github.com/zjy-ucas/ChineseNER

The CoNLL-2003 data ($PATH/NERdata/ori/) comes from: https://github.com/kyzhouhzau/BERT-NER

The evaluation code comes from: https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py


This project implements NER on top of Google's BERT code and a BiLSTM-CRF network.
It is geared toward Chinese data, but other languages only require modifying a small amount of code.

THIS PROJECT ONLY SUPPORTS Python 3.

## Download project and install
You can install this project with:
```
pip install bert-base==0.0.9 -i https://pypi.python.org/simple
```

OR
```bash
git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER
cd BERT-BiLSTM-CRF-NER/
python3 setup.py install
```

If you do not want to install it, you can simply clone this project and use its files directly to train the model or start the service.

## UPDATE:
- 2020.2.6: add simple Flask NER service code
- 2019.2.25: fix some bugs in the NER service
- 2019.2.19: add text classification service
- fix missing loss error
- add the label_list parameter to the training process, so you can use -label_list xxx to specify labels during training.


## Train model:
You can use -help to view the parameters for training the named entity recognition model; data_dir, bert_config_file, output_dir, init_checkpoint and vocab_file must be specified.
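Because those five arguments are mandatory, a wrapper script can check them before launching a long training run. The following is a minimal sketch, not part of bert_base: `build_train_command` is a hypothetical helper that only assembles the argv list for the `bert-base-ner-train` CLI from the flag names listed above, and the paths in the usage section are placeholders.

```python
import os
import shlex

# The five arguments this README lists as mandatory for bert-base-ner-train.
REQUIRED_ARGS = ("data_dir", "bert_config_file", "output_dir",
                 "init_checkpoint", "vocab_file")


def build_train_command(**args):
    """Return an argv list for bert-base-ner-train, failing early on missing args."""
    missing = [name for name in REQUIRED_ARGS if not args.get(name)]
    if missing:
        raise ValueError("missing required arguments: " + ", ".join(missing))
    argv = ["bert-base-ner-train"]
    for name, value in args.items():
        # Each option is passed as a single-dash flag followed by its value.
        argv.extend(["-" + name, str(value)])
    return argv


if __name__ == "__main__":
    bert_dir = "chinese_L-12_H-768_A-12"  # placeholder path to the Google BERT model
    cmd = build_train_command(
        data_dir="NERdata",
        output_dir="output",
        init_checkpoint=os.path.join(bert_dir, "bert_model.ckpt"),
        bert_config_file=os.path.join(bert_dir, "bert_config.json"),
        vocab_file=os.path.join(bert_dir, "vocab.txt"),
    )
    # Print the shell-quoted command; subprocess.run(cmd, check=True) would start training.
    print(shlex.join(cmd))
```

Failing before the CLI starts avoids loading TensorFlow only to discover a missing path; `shlex.join` needs Python 3.8 or later.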
```bash
bert-base-ner-train -help
```
![](./pictures/ner_help.png)


The train/dev/test datasets look like this:
```
海 O
钓 O
比 O
赛 O
地 O
点 O
在 O
厦 B-LOC
门 I-LOC
与 O
金 B-LOC
门 I-LOC
之 O
间 O
的 O
海 O
域 O
。 O
```
Each line contains a token followed by its label, and sentences are separated by a blank line. The maximum sentence length is controlled by the max_seq_length parameter.
You can get training data from the two repositories mentioned above.
You can train the NER model by running the command below:
```bash
bert-base-ner-train \
    -data_dir {your dataset dir} \
    -output_dir {training output dir} \
    -init_checkpoint {bert_model.ckpt under the Google BERT model dir} \
    -bert_config_file {bert_config.json under the Google BERT model dir} \
    -vocab_file {vocab.txt under the Google BERT model dir}
```
For example, my init_checkpoint is:
```
init_checkpoint = F:\chinese_L-12_H-768_A-12\bert_model.ckpt
```
You can specify the labels with the -label_list parameter; otherwise, the project collects the labels from the training data.
```bash
# as a comma-separated string
-label_list 'B-LOC, I-LOC ...'
# OR as a file such as labels.txt, with one label per line
-label_list labels.txt
```

After training, the NER model is saved in the {output_dir} you specified on the command line above.
##### My training environment: Tesla P40 (24 GB memory)

## As Service
Much of the server and client code comes from an excellent open-source project: [bert as service of hanxiao](https://github.com/hanxiao/bert-as-service). If my code violates any license agreement, please let me know and I will correct it right away.
~~and NER server/client service code can be applied to other tasks with simple modifications, such as text categorization, which I will provide later.~~
This project provides Named Entity Recognition and Text Classification services.
You are welcome to submit requests or share your models, whether on GitHub or through this project.

You can use -help to view the parameters of the NER service; model_dir and bert_model_dir are required.
```
bert-base-serving-start -help
```
![](./pictures/server_help.png)

Then you can start the NER service with the command below:
```bash
bert-base-serving-start \
    -model_dir C:\workspace\python\BERT_Base\output\ner2 \
    -bert_model_dir F:\chinese_L-12_H-768_A-12 \
    -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir \
    -mode NER
```
Or start the text classification service:
```bash
bert-base-serving-start \
    -model_dir C:\workspace\python\BERT_Base\output\ner2 \
    -bert_model_dir F:\chinese_L-12_H-768_A-12 \
    -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir \
    -mode CLASS \
    -max_seq_len 202
```

The parameters are:
- mode: NER starts the named entity recognition service and CLASS starts the text classification service. BERT behaves the same as the [bert as service] project.
- bert_model_dir: the BERT model directory, which you can download from https://github.com/google-research/bert
- model_dir (ner_model_dir in older versions): your NER model checkpoint directory
- model_pb_dir: the directory for the frozen model; after the optimization step runs, it contains a binary file such as ner_model.pb

> You can download my NER model from: https://pan.baidu.com/s/1m9VcueQ5gF-TJc00sFD88w, extraction code: guqq
> Or the text classification model from: https://pan.baidu.com/s/1oFPsOUh1n5AM2HjDIo2XCw, extraction code: bbu8

Put ner_model.pb/classification_model.pb in model_pb_dir and the other files in model_dir. Different models must be stored separately: for example, put the NER model's label_list.pkl and label2id.pkl in model_dir/ner/ and the text classification files in model_dir/text_classification. The text classification model classifies Chinese text into 12 categories: '游戏' (games), '娱乐' (entertainment), '财经' (finance), '时政' (current affairs), '股票' (stocks), '教育' (education), '社会' (society), '体育' (sports), '家居' (home), '时尚' (fashion), '房产' (real estate), '彩票' (lottery).

You can see the service startup info below:
![](./pictures/service_1.png)
![](./pictures/service_2.png)


You can test the client with the code below:
#### 1. NER Client
```python
import time
from bert_base.client import BertClient

with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
    start_t = time.perf_counter()
    text = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
    rst = bc.encode([text, text])
    print('rst:', rst)
    print(time.perf_counter() - start_t)
```
Running the code above produces:
![](./pictures/server_ner_rst.png)
If you want to use your own word segmentation, you only need to make the following simple change to the client code:

```python
rst = bc.encode([list(text), list(text)], is_tokenized=True)
```

#### 2. Text Classification Client
```python
import time
from bert_base.client import BertClient

with BertClient(show_server_config=False, check_version=False, check_length=False, mode='CLASS') as bc:
    start_t = time.perf_counter()
    str1 = '北京时间2月17日凌晨,第69届柏林国际电影节公布主竞赛单元获奖名单,王景春、咏梅凭借王小帅执导的中国影片《地久天长》连夺最佳男女演员双银熊大奖,这是中国演员首次包揽柏林电影节最佳男女演员奖,为华语影片刷新纪录。与此同时,由青年导演王丽娜执导的影片《第一次的别离》也荣获了本届柏林电影节新生代单元国际评审团最佳影片,可以说,在经历数个获奖小年之后,中国电影在柏林影展再次迎来了高光时刻。'
    str2 = '受粤港澳大湾区规划纲要提振,港股周二高开,恒指开盘上涨近百点,涨幅0.33%,报28440.49点,相关概念股亦集体上涨,电子元件、新能源车、保险、基建概念多数上涨。粤泰股份、珠江实业、深天地A等10余股涨停;中兴通讯、丘钛科技、舜宇光学分别高开1.4%、4.3%、1.6%。比亚迪电子、比亚迪股份、光宇国际分别高开1.7%、1.2%、1%。越秀交通基建涨近2%,粤海投资、碧桂园等多股涨超1%。其他方面,日本软银集团股价上涨超0.4%,推动日经225和东证指数齐齐高开,但随后均回吐涨幅转跌东证指数跌0.2%,日经225指数跌0.11%,报21258.4点。受芯片制造商SK海力士股价下跌1.34%拖累,韩国综指下跌0.34%至2203.9点。澳大利亚ASX 200指数早盘上涨0.39%至6089.8点,大多数行业板块均现涨势。在保健品品牌澳佳宝下调下半财年的销售预期后,其股价暴跌超过23%。澳佳宝CEO亨弗里(Richard Henfrey)认为,公司下半年的利润可能会低于上半年,主要是受到销售额疲弱的影响。同时,亚市早盘澳洲联储公布了2月会议纪要,政策委员将继续谨慎评估经济增长前景,因前景充满不确定性的影响,稳定当前的利率水平比贸然调整利率更为合适,而且当前利率水平将有利于趋向通胀目标及改善就业,当前劳动力市场数据表现强势于其他经济数据。另一方面,经济增长前景亦令消费者消费意愿下滑,如果房价出现下滑,消费可能会进一步疲弱。在澳洲联储公布会议纪要后,澳元兑美元下跌近30点,报0.7120 。美元指数在昨日触及96.65附近的低点之后反弹至96.904。日元兑美元报110.56,接近上一交易日的低点。'
    str3 = '新京报快讯 据国家市场监管总局消息,针对媒体报道水饺等猪肉制品检出非洲猪瘟病毒核酸阳性问题,市场监管总局、农业农村部已要求企业立即追溯猪肉原料来源并对猪肉制品进行了处置。两部门已派出联合督查组调查核实相关情况,要求猪肉制品生产企业进一步加强对猪肉原料的管控,落实检验检疫票证查验规定,完善非洲猪瘟检测和复核制度,防止染疫猪肉原料进入食品加工环节。市场监管总局、农业农村部等部门要求各地全面落实防控责任,强化防控措施,规范信息报告和发布,对不按要求履行防控责任的企业,一旦发现将严厉查处。专家认为,非洲猪瘟不是人畜共患病,虽然对猪有致命危险,但对人没有任何危害,属于只传猪不传人型病毒,不会影响食品安全。开展猪肉制品病毒核酸检测,可为防控溯源工作提供线索。'
    rst = bc.encode([str1, str2, str3])
    print('rst:', rst)
    print('time used:{}'.format(time.perf_counter() - start_t))
```
Running the code above produces:
![](./pictures/text_class_rst.png)

Note that the NER service and the text classification service cannot be started together, but you can run the start command twice to launch them on different ports.
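The NER client returns predicted tags rather than entity strings. As a minimal post-processing sketch, assuming each result is a list with one BIO tag per input character (check this against your own `rst`), the B-/I- tags can be grouped into entity spans. `bio_to_entities` is a hypothetical helper, not part of bert_base; labels such as "X", "[CLS]" and "[SEP]" are simply treated like "O".

```python
def bio_to_entities(tokens, tags):
    """Group BIO tags into (text, type, start, end) tuples, with end exclusive.

    Tags that are neither B-* nor I-* (for example "O", "X", "[CLS]", "[SEP]")
    close any open entity. A stray I-* tag starts a new entity.
    """
    entities = []
    start, ent_type = None, None
    # Append a sentinel "O" so an entity that ends at the last token is flushed.
    for i, tag in enumerate(list(tags) + ["O"]):
        if tag.startswith("I-") and tag[2:] == ent_type:
            continue  # the current entity continues
        if ent_type is not None:
            entities.append(("".join(tokens[start:i]), ent_type, start, i))
            start, ent_type = None, None
        if tag.startswith(("B-", "I-")):
            start, ent_type = i, tag[2:]
    return entities


if __name__ == "__main__":
    # The sample sentence and tags from the training-data example in this README.
    text = "海钓比赛地点在厦门与金门之间的海域。"
    tags = ["O"] * 7 + ["B-LOC", "I-LOC", "O", "B-LOC", "I-LOC"] + ["O"] * 6
    print(bio_to_entities(list(text), tags))
    # prints [('厦门', 'LOC', 7, 9), ('金门', 'LOC', 10, 12)]
```

With a live client you would pass `list(text)` and the tag list returned for that text; returning character offsets as well as the text makes it easy to highlight entities in the original string.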

### Flask server service
Sometimes a multi-threaded client/server (C/S) deployment is more than a deep learning model service needs; in that case you can replace it with a simple HTTP service, for example one built with Flask.
See bert_base/server/simple_flask_http_service.py for reference code to build your own simple HTTP service.

## License
MIT.

# The following tutorial is an old version and will be removed in the future.

## How to train
#### 1. Download the BERT Chinese model:
```
wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
```
#### 2. Create the output directory
Create an output directory in the project path:
```bash
mkdir output
```
#### 3. Train the model

##### First method
```
python3 bert_lstm_ner.py \
  --task_name="NER" \
  --do_train=True \
  --do_eval=True \
  --do_predict=True \
  --data_dir=NERdata \
  --vocab_file=checkpoint/vocab.txt \
  --bert_config_file=checkpoint/bert_config.json \
  --init_checkpoint=checkpoint/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=./output/result_dir/
```
##### OR replace the BERT path and project path in bert_lstm_ner.py
```python
if os.name == 'nt':  # Windows path config
    bert_path = '{your BERT model path}'
    root_path = '{project path}'
else:  # Linux path config
    bert_path = '{your BERT model path}'
    root_path = '{project path}'
```
Then run:
```bash
python3 bert_lstm_ner.py
```

### USING BLSTM-CRF OR ONLY CRF FOR DECODE!
Just change the crf_only argument (True or False) of the add_blstm_crf_layer call at line 450 of bert_lstm_ner.py.

Only a CRF output layer:
```python
blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=FLAGS.lstm_size, cell_type=FLAGS.cell, num_layers=FLAGS.num_layers,
                      dropout_rate=FLAGS.droupout_rate, initializers=initializers, num_labels=num_labels,
                      seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
rst = blstm_crf.add_blstm_crf_layer(crf_only=True)
```


BiLSTM with a CRF output layer:
```python
blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=FLAGS.lstm_size, cell_type=FLAGS.cell, num_layers=FLAGS.num_layers,
                      dropout_rate=FLAGS.droupout_rate, initializers=initializers, num_labels=num_labels,
                      seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
rst = blstm_crf.add_blstm_crf_layer(crf_only=False)
```

## Result:
All parameters use their default values.
#### On the dev set:
![](./pictures/picture1.png)

#### On the test set:
![](./pictures/picture2.png)

#### Entity-level result:
The two results above are label-level results. The entity-level result is computed at lines 796-798 of the code and is printed during prediction.
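To illustrate why the two kinds of scores differ, here is a minimal sketch of entity-level precision, recall and F1 in the spirit of conlleval: a predicted entity counts only if its type and both boundaries match a gold entity exactly. It is not the project's conlleval.py or tf_metrics code, and the function names are hypothetical.

```python
def spans(tags):
    """Return the set of (type, start, end) entity spans encoded by a BIO tag list."""
    found, start, ent_type = set(), None, None
    # A sentinel "O" closes an entity that runs to the end of the sequence.
    for i, tag in enumerate(list(tags) + ["O"]):
        if tag.startswith("I-") and tag[2:] == ent_type:
            continue
        if ent_type is not None:
            found.add((ent_type, start, i))
            start, ent_type = None, None
        if tag.startswith(("B-", "I-")):
            start, ent_type = i, tag[2:]
    return found


def label_accuracy(gold, pred):
    """Label-level score: fraction of positions whose predicted tag equals the gold tag."""
    if not gold:
        return 0.0
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)


def entity_prf(gold, pred):
    """Entity-level precision, recall and F1 over exactly matching spans."""
    gold_spans, pred_spans = spans(gold), spans(pred)
    correct = len(gold_spans & pred_spans)
    precision = correct / len(pred_spans) if pred_spans else 0.0
    recall = correct / len(gold_spans) if gold_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if correct else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    gold = ["B-LOC", "I-LOC", "O", "B-LOC", "I-LOC", "O"]
    pred = ["B-LOC", "O", "O", "B-LOC", "I-LOC", "O"]  # one I-LOC missed
    print(round(label_accuracy(gold, pred), 3))  # prints 0.833
    print(entity_prf(gold, pred))  # prints (0.5, 0.5, 0.5)
```

A single wrong tag leaves label accuracy at 5/6 but truncates the first entity, so it no longer matches and entity-level precision and recall both drop to 0.5. This is why entity-level numbers can be noticeably lower than label-level ones.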
My entity-level result:
![](./pictures/03E18A6A9C16082CF22A9E8837F7E35F.png)
> My model can be downloaded from Baidu Cloud:
> Link: https://pan.baidu.com/s/1GfDFleCcTv5393ufBYdgqQ Extraction code: 4cus

NOTE: my model was trained with crf_only=True.

## ONLINE PREDICT
Once the model has finished training, just run:
```bash
python3 terminal_predict.py
```
![](./pictures/predict.png)

## Using NER as Service

#### Service
Using NER as a service is simple; just run the Python script below from the project root:
```bash
python3 runs.py \
    -mode NER \
    -bert_model_dir /home/macan/ml/data/chinese_L-12_H-768_A-12 \
    -ner_model_dir /home/macan/ml/data/bert_ner \
    -model_pd_dir /home/macan/ml/workspace/BERT_Base/output/predict_optimizer \
    -num_worker 8
```


You can download my NER model from: https://pan.baidu.com/s/1m9VcueQ5gF-TJc00sFD88w, extraction code: guqq
Put ner_model.pb in model_pd_dir and the other files in ner_model_dir, then run the command above.
![](./pictures/service_1.png)
![](./pictures/service_2.png)


#### Client
For client usage, refer to the client_test.py script:
```python
import time
from client.client import BertClient

ner_model_dir = r'C:\workspace\python\BERT_Base\output\predict_ner'
with BertClient(ner_model_dir=ner_model_dir, show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
    start_t = time.perf_counter()
    text = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
    rst = bc.encode([text])
    print('rst:', rst)
    print(time.perf_counter() - start_t)
```
NOTE: for the input format, you can refer to the bert-as-service project.
Contributions of client code in other languages, such as Java, are welcome.
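On the input format: when you send pre-tokenized input with `is_tokenized=True` (as in the NER client example earlier in this README), each token list should fit the server's max_seq_len. The sketch below rests on two assumptions. First, it uses one token per character, matching the character-per-line training data. Second, it reserves two positions for [CLS] and [SEP], as in standard BERT input. `to_token_chunks` is a hypothetical helper. Note that splitting by character also splits Latin-script words into letters.

```python
def to_token_chunks(text, max_seq_len=128):
    """Split text into lists of character tokens that fit within max_seq_len.

    Whitespace is dropped because the training data has one visible character
    per line; two positions per sequence are left free for [CLS] and [SEP].
    """
    if max_seq_len < 3:
        raise ValueError("max_seq_len must leave room for [CLS], [SEP] and one token")
    tokens = [ch for ch in text if not ch.isspace()]
    size = max_seq_len - 2
    return [tokens[i:i + size] for i in range(0, len(tokens), size)]


if __name__ == "__main__":
    text = "1月24日,新华社对外发布了中央对雄安新区的指导意见"
    chunks = to_token_chunks(text, max_seq_len=16)
    print([len(chunk) for chunk in chunks])
    # With a running server (not shown): rst = bc.encode(chunks, is_tokenized=True)
```

A fixed-size split can cut an entity in two. Splitting at sentence punctuation first, and only falling back to fixed-size chunks for very long sentences, is a common refinement.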
## Using your own data to train
If you want to train the NER model on your own data, you only need to modify the get_labels function.
```python
def get_labels(self):
    return ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
```
NOTE: the three labels "X", "[CLS]" and "[SEP]" are required; just replace the other labels in the returned list with the labels of your data.
Alternatively, you can use the code below to let the program collect the labels from the training data automatically:
```python
def get_labels(self):
    # Getting the labels by reading the train file carries some risk.
    if os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
        with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'rb') as rf:
            self.labels = pickle.load(rf)
    else:
        if len(self.labels) > 0:
            self.labels = self.labels.union(set(["X", "[CLS]", "[SEP]"]))
            with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as wf:
                pickle.dump(self.labels, wf)
        else:
            self.labels = ["O", 'B-TIM', 'I-TIM', "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
    return self.labels
```


## NEW UPDATE
2019.1.30: Support pip install and command-line control

2019.1.30: Add Service/Client for the NER process

2019.1.9: Add code to remove the Adam-related parameters from the model, reducing the model file size from 1.3 GB to 400 MB.

2019.1.3: Add online prediction code



## Reference:
+ The evaluation code comes from: https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py

+ [https://github.com/google-research/bert](https://github.com/google-research/bert)

+ [https://github.com/kyzhouhzau/BERT-NER](https://github.com/kyzhouhzau/BERT-NER)

+ [https://github.com/zjy-ucas/ChineseNER](https://github.com/zjy-ucas/ChineseNER)

+ [https://github.com/hanxiao/bert-as-service](https://github.com/hanxiao/bert-as-service)
> If you run into any problem, please open an issue or email me (ma_cancan@163.com).
-------------------------------------------------------------------------------- /bert_base/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @Time : 2019/1/30 19:09 6 | @Author : MaCan (ma_cancan@163.com) 7 | @File : __init__.py.py 8 | """ -------------------------------------------------------------------------------- /bert_base/bert/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one.
16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /bert_base/bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert_base/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/bert_base/bert/__init__.py -------------------------------------------------------------------------------- /bert_base/bert/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | 24 | import tokenization 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("input_file", None, 32 | "Input raw text file (or comma-separated list of files).") 33 | 34 | flags.DEFINE_string( 35 | "output_file", None, 36 | "Output TF example file (or comma-separated list of files).") 37 | 38 | flags.DEFINE_string("vocab_file", None, 39 | "The vocabulary file that the BERT model was trained on.") 40 | 41 | flags.DEFINE_bool( 42 | "do_lower_case", True, 43 | "Whether to lower case the input text. 
Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 47 | 48 | flags.DEFINE_integer("max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence.") 50 | 51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 52 | 53 | flags.DEFINE_integer( 54 | "dupe_factor", 10, 55 | "Number of times to duplicate the input data (with different masks).") 56 | 57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 58 | 59 | flags.DEFINE_float( 60 | "short_seq_prob", 0.1, 61 | "Probability of creating sequences which are shorter than the " 62 | "maximum length.") 63 | 64 | 65 | class TrainingInstance(object): 66 | """A single training instance (sentence pair).""" 67 | 68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 69 | is_random_next): 70 | self.tokens = tokens 71 | self.segment_ids = segment_ids 72 | self.is_random_next = is_random_next 73 | self.masked_lm_positions = masked_lm_positions 74 | self.masked_lm_labels = masked_lm_labels 75 | 76 | def __str__(self): 77 | s = "" 78 | s += "tokens: %s\n" % (" ".join( 79 | [tokenization.printable_text(x) for x in self.tokens])) 80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 81 | s += "is_random_next: %s\n" % self.is_random_next 82 | s += "masked_lm_positions: %s\n" % (" ".join( 83 | [str(x) for x in self.masked_lm_positions])) 84 | s += "masked_lm_labels: %s\n" % (" ".join( 85 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 86 | s += "\n" 87 | return s 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | 93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 94 | max_predictions_per_seq, output_files): 95 | """Create TF example files from `TrainingInstance`s.""" 96 | writers = [] 97 | for output_file in output_files: 98 | 
writers.append(tf.python_io.TFRecordWriter(output_file)) 99 | 100 | writer_index = 0 101 | 102 | total_written = 0 103 | for (inst_index, instance) in enumerate(instances): 104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 105 | input_mask = [1] * len(input_ids) 106 | segment_ids = list(instance.segment_ids) 107 | assert len(input_ids) <= max_seq_length 108 | 109 | while len(input_ids) < max_seq_length: 110 | input_ids.append(0) 111 | input_mask.append(0) 112 | segment_ids.append(0) 113 | 114 | assert len(input_ids) == max_seq_length 115 | assert len(input_mask) == max_seq_length 116 | assert len(segment_ids) == max_seq_length 117 | 118 | masked_lm_positions = list(instance.masked_lm_positions) 119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 120 | masked_lm_weights = [1.0] * len(masked_lm_ids) 121 | 122 | while len(masked_lm_positions) < max_predictions_per_seq: 123 | masked_lm_positions.append(0) 124 | masked_lm_ids.append(0) 125 | masked_lm_weights.append(0.0) 126 | 127 | next_sentence_label = 1 if instance.is_random_next else 0 128 | 129 | features = collections.OrderedDict() 130 | features["input_ids"] = create_int_feature(input_ids) 131 | features["input_mask"] = create_int_feature(input_mask) 132 | features["segment_ids"] = create_int_feature(segment_ids) 133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 136 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 137 | 138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 139 | 140 | writers[writer_index].write(tf_example.SerializeToString()) 141 | writer_index = (writer_index + 1) % len(writers) 142 | 143 | total_written += 1 144 | 145 | if inst_index < 20: 146 | tf.logging.info("*** Example ***") 147 | tf.logging.info("tokens: %s" % " 
".join( 148 | [tokenization.printable_text(x) for x in instance.tokens])) 149 | 150 | for feature_name in features.keys(): 151 | feature = features[feature_name] 152 | values = [] 153 | if feature.int64_list.value: 154 | values = feature.int64_list.value 155 | elif feature.float_list.value: 156 | values = feature.float_list.value 157 | tf.logging.info( 158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 159 | 160 | for writer in writers: 161 | writer.close() 162 | 163 | tf.logging.info("Wrote %d total instances", total_written) 164 | 165 | 166 | def create_int_feature(values): 167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 168 | return feature 169 | 170 | 171 | def create_float_feature(values): 172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 173 | return feature 174 | 175 | 176 | def create_training_instances(input_files, tokenizer, max_seq_length, 177 | dupe_factor, short_seq_prob, masked_lm_prob, 178 | max_predictions_per_seq, rng): 179 | """Create `TrainingInstance`s from raw text.""" 180 | all_documents = [[]] 181 | 182 | # Input file format: 183 | # (1) One sentence per line. These should ideally be actual sentences, not 184 | # entire paragraphs or arbitrary spans of text. (Because we use the 185 | # sentence boundaries for the "next sentence prediction" task). 186 | # (2) Blank lines between documents. Document boundaries are needed so 187 | # that the "next sentence prediction" task doesn't span between documents. 
188 | for input_file in input_files: 189 | with tf.gfile.GFile(input_file, "r") as reader: 190 | while True: 191 | line = tokenization.convert_to_unicode(reader.readline()) 192 | if not line: 193 | break 194 | line = line.strip() 195 | 196 | # Empty lines are used as document delimiters 197 | if not line: 198 | all_documents.append([]) 199 | tokens = tokenizer.tokenize(line) 200 | if tokens: 201 | all_documents[-1].append(tokens) 202 | 203 | # Remove empty documents 204 | all_documents = [x for x in all_documents if x] 205 | rng.shuffle(all_documents) 206 | 207 | vocab_words = list(tokenizer.vocab.keys()) 208 | instances = [] 209 | for _ in range(dupe_factor): 210 | for document_index in range(len(all_documents)): 211 | instances.extend( 212 | create_instances_from_document( 213 | all_documents, document_index, max_seq_length, short_seq_prob, 214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 215 | 216 | rng.shuffle(instances) 217 | return instances 218 | 219 | 220 | def create_instances_from_document( 221 | all_documents, document_index, max_seq_length, short_seq_prob, 222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 223 | """Creates `TrainingInstance`s for a single document.""" 224 | document = all_documents[document_index] 225 | 226 | # Account for [CLS], [SEP], [SEP] 227 | max_num_tokens = max_seq_length - 3 228 | 229 | # We *usually* want to fill up the entire sequence since we are padding 230 | # to `max_seq_length` anyways, so short sequences are generally wasted 231 | # computation. However, we *sometimes* 232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 233 | # sequences to minimize the mismatch between pre-training and fine-tuning. 234 | # The `target_seq_length` is just a rough target however, whereas 235 | # `max_seq_length` is a hard limit. 
236 | target_seq_length = max_num_tokens 237 | if rng.random() < short_seq_prob: 238 | target_seq_length = rng.randint(2, max_num_tokens) 239 | 240 | # We DON'T just concatenate all of the tokens from a document into a long 241 | # sequence and choose an arbitrary split point because this would make the 242 | # next sentence prediction task too easy. Instead, we split the input into 243 | # segments "A" and "B" based on the actual "sentences" provided by the user 244 | # input. 245 | instances = [] 246 | current_chunk = [] 247 | current_length = 0 248 | i = 0 249 | while i < len(document): 250 | segment = document[i] 251 | current_chunk.append(segment) 252 | current_length += len(segment) 253 | if i == len(document) - 1 or current_length >= target_seq_length: 254 | if current_chunk: 255 | # `a_end` is how many segments from `current_chunk` go into the `A` 256 | # (first) sentence. 257 | a_end = 1 258 | if len(current_chunk) >= 2: 259 | a_end = rng.randint(1, len(current_chunk) - 1) 260 | 261 | tokens_a = [] 262 | for j in range(a_end): 263 | tokens_a.extend(current_chunk[j]) 264 | 265 | tokens_b = [] 266 | # Random next 267 | is_random_next = False 268 | if len(current_chunk) == 1 or rng.random() < 0.5: 269 | is_random_next = True 270 | target_b_length = target_seq_length - len(tokens_a) 271 | 272 | # This should rarely go for more than one iteration for large 273 | # corpora. However, just to be careful, we try to make sure that 274 | # the random document is not the same as the document 275 | # we're processing. 
276 | for _ in range(10): 277 | random_document_index = rng.randint(0, len(all_documents) - 1) 278 | if random_document_index != document_index: 279 | break 280 | 281 | random_document = all_documents[random_document_index] 282 | random_start = rng.randint(0, len(random_document) - 1) 283 | for j in range(random_start, len(random_document)): 284 | tokens_b.extend(random_document[j]) 285 | if len(tokens_b) >= target_b_length: 286 | break 287 | # We didn't actually use these segments so we "put them back" so 288 | # they don't go to waste. 289 | num_unused_segments = len(current_chunk) - a_end 290 | i -= num_unused_segments 291 | # Actual next 292 | else: 293 | is_random_next = False 294 | for j in range(a_end, len(current_chunk)): 295 | tokens_b.extend(current_chunk[j]) 296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 297 | 298 | assert len(tokens_a) >= 1 299 | assert len(tokens_b) >= 1 300 | 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | (tokens, masked_lm_positions, 319 | masked_lm_labels) = create_masked_lm_predictions( 320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 321 | instance = TrainingInstance( 322 | tokens=tokens, 323 | segment_ids=segment_ids, 324 | is_random_next=is_random_next, 325 | masked_lm_positions=masked_lm_positions, 326 | masked_lm_labels=masked_lm_labels) 327 | instances.append(instance) 328 | current_chunk = [] 329 | current_length = 0 330 | i += 1 331 | 332 | return instances 333 | 334 | 335 | def create_masked_lm_predictions(tokens, masked_lm_prob, 336 | max_predictions_per_seq, vocab_words, rng): 337 | """Creates the predictions for the masked LM 
objective.""" 338 | 339 | cand_indexes = [] 340 | for (i, token) in enumerate(tokens): 341 | if token == "[CLS]" or token == "[SEP]": 342 | continue 343 | cand_indexes.append(i) 344 | 345 | rng.shuffle(cand_indexes) 346 | 347 | output_tokens = list(tokens) 348 | 349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 350 | 351 | num_to_predict = min(max_predictions_per_seq, 352 | max(1, int(round(len(tokens) * masked_lm_prob)))) 353 | 354 | masked_lms = [] 355 | covered_indexes = set() 356 | for index in cand_indexes: 357 | if len(masked_lms) >= num_to_predict: 358 | break 359 | if index in covered_indexes: 360 | continue 361 | covered_indexes.add(index) 362 | 363 | masked_token = None 364 | # 80% of the time, replace with [MASK] 365 | if rng.random() < 0.8: 366 | masked_token = "[MASK]" 367 | else: 368 | # 10% of the time, keep original 369 | if rng.random() < 0.5: 370 | masked_token = tokens[index] 371 | # 10% of the time, replace with random word 372 | else: 373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 374 | 375 | output_tokens[index] = masked_token 376 | 377 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 378 | 379 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 380 | 381 | masked_lm_positions = [] 382 | masked_lm_labels = [] 383 | for p in masked_lms: 384 | masked_lm_positions.append(p.index) 385 | masked_lm_labels.append(p.label) 386 | 387 | return (output_tokens, masked_lm_positions, masked_lm_labels) 388 | 389 | 390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 391 | """Truncates a pair of sequences to a maximum sequence length.""" 392 | while True: 393 | total_length = len(tokens_a) + len(tokens_b) 394 | if total_length <= max_num_tokens: 395 | break 396 | 397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 398 | assert len(trunc_tokens) >= 1 399 | 400 | # We want to sometimes truncate from the front and 
sometimes from the 401 | # back to add more randomness and avoid biases. 402 | if rng.random() < 0.5: 403 | del trunc_tokens[0] 404 | else: 405 | trunc_tokens.pop() 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | tokenizer = tokenization.FullTokenizer( 412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 413 | 414 | input_files = [] 415 | for input_pattern in FLAGS.input_file.split(","): 416 | input_files.extend(tf.gfile.Glob(input_pattern)) 417 | 418 | tf.logging.info("*** Reading from input files ***") 419 | for input_file in input_files: 420 | tf.logging.info(" %s", input_file) 421 | 422 | rng = random.Random(FLAGS.random_seed) 423 | instances = create_training_instances( 424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 426 | rng) 427 | 428 | output_files = FLAGS.output_file.split(",") 429 | tf.logging.info("*** Writing to output files ***") 430 | for output_file in output_files: 431 | tf.logging.info(" %s", output_file) 432 | 433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 434 | FLAGS.max_predictions_per_seq, output_files) 435 | 436 | 437 | if __name__ == "__main__": 438 | flags.mark_flag_as_required("input_file") 439 | flags.mark_flag_as_required("output_file") 440 | flags.mark_flag_as_required("vocab_file") 441 | tf.app.run() 442 | -------------------------------------------------------------------------------- /bert_base/bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 
68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | 
filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | 
is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /bert_base/bert/multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 8 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 10 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 11 | parameters 12 | 13 | See the [list of languages](#list-of-languages) that the Multilingual model 14 | supports. The Multilingual model does include Chinese (and English), but if your 15 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 16 | better results. 
17 | 18 | ## Results 19 | 20 | To evaluate these systems, we use the 21 | [XNLI dataset](https://github.com/facebookresearch/XNLI), which is a 22 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 23 | dev and test sets have been translated (by humans) into 15 languages. Note that 24 | the training set was *machine* translated (we used the translations provided by 25 | XNLI, not Google NMT). For clarity, we only report on 6 languages below: 26 | 27 | 28 | 29 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 30 | | ------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 31 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 32 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 33 | | BERT - Translate Train | **81.4** | **74.2** | **77.3** | **75.2** | **70.5** | 61.7 | 34 | | BERT - Translate Test | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 35 | | BERT - Zero Shot | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 36 | 37 | 38 | 39 | The first two rows are baselines from the XNLI paper and the last three rows are 40 | our results with BERT. 41 | 42 | **Translate Train** means that the MultiNLI training set was machine translated 43 | from English into the foreign language. So training and evaluation were both 44 | done in the foreign language. Unfortunately, training was done on 45 | machine-translated data, so it is impossible to quantify how much of the lower 46 | accuracy (compared to English) is due to the quality of the machine translation 47 | vs. the quality of the pre-trained model. 48 | 49 | **Translate Test** means that the XNLI test set was machine translated from the 50 | foreign language into English. So training and evaluation were both done in 51 | English. However, test evaluation was done on machine-translated English, so the 52 | accuracy depends on the quality of the machine translation system.
53 | 54 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 55 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 56 | machine translation was not involved at all in either the pre-training or 57 | fine-tuning. 58 | 59 | Note that the English result is worse than the 84.2 MultiNLI baseline because 60 | this training used Multilingual BERT rather than English-only BERT. This implies 61 | that for high-resource languages, the Multilingual model is somewhat worse than 62 | a single-language model. However, it is not feasible for us to train and 63 | maintain dozens of single-language models. Therefore, if your goal is to maximize 64 | performance with a language other than English or Chinese, you might find it 65 | beneficial to run pre-training for additional steps starting from our 66 | Multilingual model on data from your language of interest. 67 | 68 | Here is a comparison of training Chinese models with the Multilingual 69 | `BERT-Base` and Chinese-only `BERT-Base`: 70 | 71 | System | Chinese 72 | ----------------------- | ------- 73 | XNLI Baseline | 67.0 74 | BERT Multilingual Model | 74.2 75 | BERT Chinese-only Model | 77.2 76 | 77 | Similar to English, the single-language model scores 3 accuracy points higher than the 78 | Multilingual model (77.2 vs. 74.2). 79 | 80 | ## Fine-tuning Example 81 | 82 | The multilingual model does **not** require any special consideration or API 83 | changes. We did update the implementation of `BasicTokenizer` in 84 | `tokenization.py` to support Chinese character tokenization, so please update if 85 | you forked it. However, we did not change the tokenization API. 86 | 87 | To test the new models, we did modify `run_classifier.py` to add support for the 88 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language 89 | version of MultiNLI where the dev/test sets have been human-translated, and the 90 | training set has been machine-translated.
91 | 92 | To run the fine-tuning code, please download the 93 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the 94 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip) 95 | and then unpack both .zip files into some directory `$XNLI_DIR`. 96 | 97 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py` 98 | (Chinese by default), so please modify `XnliProcessor` if you want to run on 99 | another language. 100 | 101 | This is a large dataset, so training will take a few hours on a GPU 102 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for 103 | debugging, just set `num_train_epochs` to a small value like `0.1`. 104 | 105 | ```shell 106 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12 107 | export XNLI_DIR=/path/to/xnli 108 | 109 | python run_classifier.py \ 110 | --task_name=XNLI \ 111 | --do_train=true \ 112 | --do_eval=true \ 113 | --data_dir=$XNLI_DIR \ 114 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 115 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 116 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 117 | --max_seq_length=128 \ 118 | --train_batch_size=32 \ 119 | --learning_rate=5e-5 \ 120 | --num_train_epochs=2.0 \ 121 | --output_dir=/tmp/xnli_output/ 122 | ``` 123 | 124 | With the Chinese-only model, the results should look something like this: 125 | 126 | ``` 127 | ***** Eval results ***** 128 | eval_accuracy = 0.774116 129 | eval_loss = 0.83554 130 | global_step = 24543 131 | loss = 0.74603 132 | ``` 133 | 134 | ## Details 135 | 136 | ### Data Source and Sampling 137 | 138 | The languages chosen were the 139 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).
140 | The entire Wikipedia dump for each language (excluding user and talk pages) was 141 | taken as the training data for each language. 142 | 143 | However, the size of the Wikipedia for a given language varies greatly, and 144 | therefore low-resource languages may be "under-represented" in terms of the 145 | neural network model (under the assumption that languages are "competing" for 146 | limited model capacity to some extent). 147 | 148 | On the other hand, the size of a Wikipedia also correlates with the number of speakers of 149 | a language, and we also don't want to overfit the model by performing thousands 150 | of epochs over a tiny Wikipedia for a particular language. 151 | 152 | To balance these two factors, we performed exponentially smoothed weighting of 153 | the data during pre-training data creation (and WordPiece vocab creation). In 154 | other words, let's say that the probability of a language is *P(L)*, e.g., 155 | *P(English) = 0.21* means that after concatenating all of the Wikipedias 156 | together, 21% of our data is English. We exponentiate each probability by some 157 | factor *S* and then re-normalize, and sample from that distribution. In our case 158 | we use *S=0.7*. So, high-resource languages like English will be under-sampled, 159 | and low-resource languages like Icelandic will be over-sampled. E.g., in the 160 | original distribution English would be sampled 1000x more than Icelandic, but 161 | after smoothing it's only sampled roughly 100x more (1000^0.7 ≈ 126). 162 | 163 | ### Tokenization 164 | 165 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 166 | weighted the same way as the data, so low-resource languages are upweighted by 167 | some factor. We intentionally do *not* use any marker to denote the input 168 | language (so that zero-shot training can work).
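The exponentially smoothed weighting described under "Data Source and Sampling" (and reused for the WordPiece word counts) can be sketched in a few lines of Python. This is a minimal illustration, not the code Google used: `smoothed_probs` and `sample_languages` are hypothetical helpers, and the two-language data sizes are made up to reproduce the 1000:1 English/Icelandic example from the text.

```python
# Hypothetical sketch of exponentially smoothed language sampling.
# Assumption: raw per-language data sizes are known; the language names
# and sizes below are illustrative only, not the real Wikipedia statistics.
import random


def smoothed_probs(sizes, s=0.7):
    """Return P'(L) proportional to P(L) ** s, re-normalized to sum to 1."""
    total = sum(sizes.values())
    raw = {lang: n / total for lang, n in sizes.items()}  # P(L)
    powered = {lang: p ** s for lang, p in raw.items()}   # P(L) ** S
    norm = sum(powered.values())
    return {lang: w / norm for lang, w in powered.items()}


def sample_languages(probs, k, rng):
    """Draw k language labels from the smoothed distribution."""
    langs = list(probs)
    weights = [probs[lang] for lang in langs]
    return rng.choices(langs, weights=weights, k=k)


if __name__ == "__main__":
    # Illustrative sizes: English has 1000x more data than Icelandic.
    sizes = {"English": 1000.0, "Icelandic": 1.0}
    probs = smoothed_probs(sizes, s=0.7)
    ratio = probs["English"] / probs["Icelandic"]
    print(f"English/Icelandic ratio after smoothing: {ratio:.1f}")  # 125.9
    print(sample_languages(probs, k=5, rng=random.Random(12345)))
```

Because re-normalization divides every weight by the same constant, the smoothed ratio between two languages is just the raw ratio raised to *S*: 1000^0.7 ≈ 126, which the text rounds to "100x". Setting `s=1.0` leaves the original distribution unchanged, while `s=0.0` would sample all languages uniformly.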
169 | 170 | Because Chinese does not have whitespace characters, we add spaces around every 171 | character in the 172 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 173 | before applying WordPiece. This means that Chinese is effectively 174 | character-tokenized. Note that the CJK Unicode block only includes 175 | Chinese-origin characters and does *not* include Hangul Korean or 176 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 177 | all other languages. 178 | 179 | For all other languages, we apply the 180 | [same recipe as English](https://github.com/google-research/bert#tokenization): 181 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 182 | tokenization. We understand that accent markers have substantial meaning in some 183 | languages, but felt that the benefits of reducing the effective vocabulary make 184 | up for this. Generally the strong contextual models of BERT should make up for 185 | any ambiguity introduced by stripping accent markers. 186 | 187 | ### List of Languages 188 | 189 | The multilingual model supports the following languages. 
These languages were 190 | chosen because they are the top 100 languages with the largest Wikipedias: 191 | 192 | * Afrikaans 193 | * Albanian 194 | * Arabic 195 | * Aragonese 196 | * Armenian 197 | * Asturian 198 | * Azerbaijani 199 | * Bashkir 200 | * Basque 201 | * Bavarian 202 | * Belarusian 203 | * Bengali 204 | * Bishnupriya Manipuri 205 | * Bosnian 206 | * Breton 207 | * Bulgarian 208 | * Burmese 209 | * Catalan 210 | * Cebuano 211 | * Chechen 212 | * Chinese (Simplified) 213 | * Chinese (Traditional) 214 | * Chuvash 215 | * Croatian 216 | * Czech 217 | * Danish 218 | * Dutch 219 | * English 220 | * Estonian 221 | * Finnish 222 | * French 223 | * Galician 224 | * Georgian 225 | * German 226 | * Greek 227 | * Gujarati 228 | * Haitian 229 | * Hebrew 230 | * Hindi 231 | * Hungarian 232 | * Icelandic 233 | * Ido 234 | * Indonesian 235 | * Irish 236 | * Italian 237 | * Japanese 238 | * Javanese 239 | * Kannada 240 | * Kazakh 241 | * Kirghiz 242 | * Korean 243 | * Latin 244 | * Latvian 245 | * Lithuanian 246 | * Lombard 247 | * Low Saxon 248 | * Luxembourgish 249 | * Macedonian 250 | * Malagasy 251 | * Malay 252 | * Malayalam 253 | * Marathi 254 | * Minangkabau 255 | * Nepali 256 | * Newar 257 | * Norwegian (Bokmal) 258 | * Norwegian (Nynorsk) 259 | * Occitan 260 | * Persian (Farsi) 261 | * Piedmontese 262 | * Polish 263 | * Portuguese 264 | * Punjabi 265 | * Romanian 266 | * Russian 267 | * Scots 268 | * Serbian 269 | * Serbo-Croatian 270 | * Sicilian 271 | * Slovak 272 | * Slovenian 273 | * South Azerbaijani 274 | * Spanish 275 | * Sundanese 276 | * Swahili 277 | * Swedish 278 | * Tagalog 279 | * Tajik 280 | * Tamil 281 | * Tatar 282 | * Telugu 283 | * Turkish 284 | * Ukrainian 285 | * Urdu 286 | * Uzbek 287 | * Vietnamese 288 | * Volapük 289 | * Waray-Waray 290 | * Welsh 291 | * West 292 | * Western Punjabi 293 | * Yoruba 294 | 295 | The only language which we had to unfortunately exclude was Thai, since it is 296 | the only language (other than Chinese) that 
does not use whitespace to delimit 297 | words, and it has too many characters-per-word to use character-based 298 | tokenization. Our WordPiece algorithm is quadratic with respect to the size of 299 | the input token so very long character strings do not work with it. 300 | -------------------------------------------------------------------------------- /bert_base/bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. 
I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs an AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update.
128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /bert_base/bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert_base/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 
2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 3 | -------------------------------------------------------------------------------- /bert_base/bert/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /bert_base/bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 
52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | # TODO: modified for OOV handling: an item missing from `vocab` falls back to id 100, the index of [UNK] in the Chinese BERT vocab. If you are using English, do not change this. 90 | # original: output.append(vocab[item]) 91 | output.append(vocab.get(item, 100)) 92 | return output 93 | 94 | 95 | def convert_tokens_to_ids(vocab, tokens): 96 | return convert_by_vocab(vocab, tokens) 97 | 98 | 99 | def convert_ids_to_tokens(inv_vocab, ids): 100 | return convert_by_vocab(inv_vocab, ids) 101 | 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | 112 | class FullTokenizer(object): 113 | """Runs end-to-end tokenization.""" 114 | 115 | def __init__(self, vocab_file, do_lower_case=True): 116 | self.vocab = load_vocab(vocab_file) 117 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 118 | self.basic_tokenizer =
BasicTokenizer(do_lower_case=do_lower_case) 119 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 120 | 121 | def tokenize(self, text): 122 | split_tokens = [] 123 | for token in self.basic_tokenizer.tokenize(text): 124 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 125 | split_tokens.append(sub_token) 126 | 127 | return split_tokens 128 | 129 | def convert_tokens_to_ids(self, tokens): 130 | return convert_by_vocab(self.vocab, tokens) 131 | 132 | def convert_ids_to_tokens(self, ids): 133 | return convert_by_vocab(self.inv_vocab, ids) 134 | 135 | 136 | class BasicTokenizer(object): 137 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 138 | 139 | def __init__(self, do_lower_case=True): 140 | """Constructs a BasicTokenizer. 141 | 142 | Args: 143 | do_lower_case: Whether to lower case the input. 144 | """ 145 | self.do_lower_case = do_lower_case 146 | 147 | def tokenize(self, text): 148 | """Tokenizes a piece of text.""" 149 | text = convert_to_unicode(text) 150 | text = self._clean_text(text) 151 | 152 | # This was added on November 1st, 2018 for the multilingual and Chinese 153 | # models. This is also applied to the English models now, but it doesn't 154 | # matter since the English models were not trained on any Chinese data 155 | # and generally don't have any Chinese data in them (there are Chinese 156 | # characters in the vocabulary because Wikipedia does have some Chinese 157 | # words in the English Wikipedia.). 
158 | text = self._tokenize_chinese_chars(text) 159 | 160 | orig_tokens = whitespace_tokenize(text) 161 | split_tokens = [] 162 | for token in orig_tokens: 163 | if self.do_lower_case: 164 | token = token.lower() 165 | token = self._run_strip_accents(token) 166 | split_tokens.extend(self._run_split_on_punc(token)) 167 | 168 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 169 | return output_tokens 170 | 171 | def _run_strip_accents(self, text): 172 | """Strips accents from a piece of text.""" 173 | text = unicodedata.normalize("NFD", text) 174 | output = [] 175 | for char in text: 176 | cat = unicodedata.category(char) 177 | if cat == "Mn": 178 | continue 179 | output.append(char) 180 | return "".join(output) 181 | 182 | def _run_split_on_punc(self, text): 183 | """Splits punctuation on a piece of text.""" 184 | chars = list(text) 185 | i = 0 186 | start_new_word = True 187 | output = [] 188 | while i < len(chars): 189 | char = chars[i] 190 | if _is_punctuation(char): 191 | output.append([char]) 192 | start_new_word = True 193 | else: 194 | if start_new_word: 195 | output.append([]) 196 | start_new_word = False 197 | output[-1].append(char) 198 | i += 1 199 | 200 | return ["".join(x) for x in output] 201 | 202 | def _tokenize_chinese_chars(self, text): 203 | """Adds whitespace around any CJK character.""" 204 | output = [] 205 | for char in text: 206 | cp = ord(char) 207 | if self._is_chinese_char(cp): 208 | output.append(" ") 209 | output.append(char) 210 | output.append(" ") 211 | else: 212 | output.append(char) 213 | return "".join(output) 214 | 215 | def _is_chinese_char(self, cp): 216 | """Checks whether CP is the codepoint of a CJK character.""" 217 | # This defines a "chinese character" as anything in the CJK Unicode block: 218 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 219 | # 220 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 221 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 222 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 223 | # space-separated words, so they are not treated specially and are handled 224 | # like all of the other languages. 225 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 226 | (cp >= 0x3400 and cp <= 0x4DBF) or # 227 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 228 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 229 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 230 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 231 | (cp >= 0xF900 and cp <= 0xFAFF) or # 232 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 233 | return True 234 | 235 | return False 236 | 237 | def _clean_text(self, text): 238 | """Performs invalid character removal and whitespace cleanup on text.""" 239 | output = [] 240 | for char in text: 241 | cp = ord(char) 242 | if cp == 0 or cp == 0xfffd or _is_control(char): 243 | continue 244 | if _is_whitespace(char): 245 | output.append(" ") 246 | else: 247 | output.append(char) 248 | return "".join(output) 249 | 250 | 251 | class WordpieceTokenizer(object): 252 | """Runs WordPiece tokenization.""" 253 | 254 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 255 | self.vocab = vocab 256 | self.unk_token = unk_token 257 | self.max_input_chars_per_word = max_input_chars_per_word 258 | 259 | def tokenize(self, text): 260 | """Tokenizes a piece of text into its word pieces. 261 | 262 | This uses a greedy longest-match-first algorithm to perform tokenization 263 | using the given vocabulary. 264 | 265 | For example: 266 | input = "unaffable" 267 | output = ["un", "##aff", "##able"] 268 | 269 | Args: 270 | text: A single token or whitespace separated tokens. This should have 271 | already been passed through `BasicTokenizer`. 272 | 273 | Returns: 274 | A list of wordpiece tokens.
275 | """ 276 | 277 | text = convert_to_unicode(text) 278 | 279 | output_tokens = [] 280 | for token in whitespace_tokenize(text): 281 | chars = list(token) 282 | if len(chars) > self.max_input_chars_per_word: 283 | output_tokens.append(self.unk_token) 284 | continue 285 | 286 | is_bad = False 287 | start = 0 288 | sub_tokens = [] 289 | while start < len(chars): 290 | end = len(chars) 291 | cur_substr = None 292 | while start < end: 293 | substr = "".join(chars[start:end]) 294 | if start > 0: 295 | substr = "##" + substr 296 | if substr in self.vocab: 297 | cur_substr = substr 298 | break 299 | end -= 1 300 | if cur_substr is None: 301 | is_bad = True 302 | break 303 | sub_tokens.append(cur_substr) 304 | start = end 305 | 306 | if is_bad: 307 | output_tokens.append(self.unk_token) 308 | else: 309 | output_tokens.extend(sub_tokens) 310 | return output_tokens 311 | 312 | 313 | def _is_whitespace(char): 314 | """Checks whether `char` is a whitespace character.""" 315 | # \t, \n, and \r are technically control characters but we treat them 316 | # as whitespace since they are generally considered as such. 317 | if char == " " or char == "\t" or char == "\n" or char == "\r": 318 | return True 319 | cat = unicodedata.category(char) 320 | if cat == "Zs": 321 | return True 322 | return False 323 | 324 | 325 | def _is_control(char): 326 | """Checks whether `char` is a control character.""" 327 | # These are technically control characters but we count them as whitespace 328 | # characters. 329 | if char == "\t" or char == "\n" or char == "\r": 330 | return False 331 | cat = unicodedata.category(char) 332 | if cat.startswith("C"): 333 | return True 334 | return False 335 | 336 | 337 | def _is_punctuation(char): 338 | """Checks whether `char` is a punctuation character.""" 339 | cp = ord(char) 340 | # We treat all non-letter/number ASCII as punctuation.
341 | # Characters such as "^", "$", and "`" are not in the Unicode 342 | # Punctuation class but we treat them as punctuation anyways, for 343 | # consistency. 344 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 345 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 346 | return True 347 | cat = unicodedata.category(char) 348 | if cat.startswith("P"): 349 | return True 350 | return False 351 | 352 | -------------------------------------------------------------------------------- /bert_base/bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | 22 | import tokenization 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | 36 | vocab_file = vocab_writer.name 37 | 38 | tokenizer = tokenization.FullTokenizer(vocab_file) 39 | os.unlink(vocab_file) 40 | 41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 43 | 44 | self.assertAllEqual( 45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 46 | 47 | def test_chinese(self): 48 | tokenizer = tokenization.BasicTokenizer() 49 | 50 | self.assertAllEqual( 51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 52 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 53 | 54 | def test_basic_tokenizer_lower(self): 55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 56 | 57 | self.assertAllEqual( 58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 59 | ["hello", "!", "how", "are", "you", "?"]) 60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 61 | 62 | def test_basic_tokenizer_no_lower(self): 63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 64 | 65 | self.assertAllEqual( 66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 67 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 68 | 69 | def test_wordpiece_tokenizer(self): 70 | vocab_tokens = [ 71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 72 | "##ing" 73 | ] 74 | 75 | vocab = {} 76 | for (i, token) in enumerate(vocab_tokens): 77 | vocab[token] = i 78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 79 | 80 | self.assertAllEqual(tokenizer.tokenize(""), []) 81 | 82 | self.assertAllEqual( 83 | tokenizer.tokenize("unwanted running"), 84 | ["un", "##want", "##ed", "runn", "##ing"]) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 88 | 89 | def test_convert_tokens_to_ids(self): 90 | vocab_tokens = [ 91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 92 | "##ing" 93 | ] 94 | 95 | vocab = {} 96 | for (i, token) in enumerate(vocab_tokens): 97 | vocab[token] = i 98 | 99 | self.assertAllEqual( 100 | tokenization.convert_tokens_to_ids( 101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 102 | 103 | def test_is_whitespace(self): 104 | self.assertTrue(tokenization._is_whitespace(u" ")) 105 | self.assertTrue(tokenization._is_whitespace(u"\t")) 106 | self.assertTrue(tokenization._is_whitespace(u"\r")) 107 | self.assertTrue(tokenization._is_whitespace(u"\n")) 108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 109 | 110 | self.assertFalse(tokenization._is_whitespace(u"A")) 111 | self.assertFalse(tokenization._is_whitespace(u"-")) 112 | 113 | def test_is_control(self): 114 | self.assertTrue(tokenization._is_control(u"\u0005")) 115 | 116 | self.assertFalse(tokenization._is_control(u"A")) 117 | self.assertFalse(tokenization._is_control(u" ")) 118 | self.assertFalse(tokenization._is_control(u"\t")) 119 | self.assertFalse(tokenization._is_control(u"\r")) 120 | 121 | def test_is_punctuation(self): 122 | self.assertTrue(tokenization._is_punctuation(u"-")) 123 | 
self.assertTrue(tokenization._is_punctuation(u"$")) 124 | self.assertTrue(tokenization._is_punctuation(u"`")) 125 | self.assertTrue(tokenization._is_punctuation(u".")) 126 | 127 | self.assertFalse(tokenization._is_punctuation(u"A")) 128 | self.assertFalse(tokenization._is_punctuation(u" ")) 129 | 130 | 131 | if __name__ == "__main__": 132 | tf.test.main() 133 | -------------------------------------------------------------------------------- /bert_base/runs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @Time : 2019/1/30 16:47 6 | @Author : MaCan (ma_cancan@163.com) 7 | @File : __init__.py.py 8 | """ 9 | 10 | 11 | def start_server(): 12 | from bert_base.server import BertServer 13 | from bert_base.server.helper import get_run_args 14 | 15 | args = get_run_args() 16 | # print(args) 17 | server = BertServer(args) 18 | server.start() 19 | server.join() 20 | 21 | 22 | def start_client(): 23 | pass 24 | 25 | 26 | def train_ner(): 27 | import os 28 | from bert_base.train.train_helper import get_args_parser 29 | from bert_base.train.bert_lstm_ner import train 30 | 31 | args = get_args_parser() 32 | if True: 33 | import sys 34 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 35 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 36 | # print(args) 37 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_map 38 | train(args=args) 39 | 40 | # if __name__ == '__main__': 41 | # # start_server() 42 | # train_ner() -------------------------------------------------------------------------------- /bert_base/server/graph.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | import os 4 | from enum import Enum 5 | 6 | from termcolor import colored 7 | 8 | from .helper import import_tf, set_logger 9 | 10 | import sys 11 | 
sys.path.append('..') 12 | from bert_base.bert import modeling 13 | 14 | __all__ = ['PoolingStrategy', 'optimize_bert_graph', 'optimize_ner_model', 'optimize_class_model'] 15 | 16 | 17 | class PoolingStrategy(Enum): 18 | NONE = 0 19 | REDUCE_MAX = 1 20 | REDUCE_MEAN = 2 21 | REDUCE_MEAN_MAX = 3 22 | FIRST_TOKEN = 4 # corresponds to [CLS] for single sequences 23 | LAST_TOKEN = 5 # corresponds to [SEP] for single sequences 24 | CLS_TOKEN = 4 # corresponds to the first token for single seq. 25 | SEP_TOKEN = 5 # corresponds to the last token for single seq. 26 | 27 | def __str__(self): 28 | return self.name 29 | 30 | @staticmethod 31 | def from_string(s): 32 | try: 33 | return PoolingStrategy[s] 34 | except KeyError: 35 | raise ValueError() 36 | 37 | 38 | def optimize_bert_graph(args, logger=None): 39 | if not logger: 40 | logger = set_logger(colored('GRAPHOPT', 'cyan'), args.verbose) 41 | try: 42 | if not os.path.exists(args.model_pb_dir): 43 | os.mkdir(args.model_pb_dir) 44 | pb_file = os.path.join(args.model_pb_dir, 'bert_model.pb') 45 | if os.path.exists(pb_file): 46 | return pb_file 47 | # we don't need GPU for optimizing the graph 48 | tf = import_tf(verbose=args.verbose) 49 | from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference 50 | 51 | config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) 52 | 53 | config_fp = os.path.join(args.model_dir, args.config_name) 54 | init_checkpoint = os.path.join(args.tuned_model_dir or args.bert_model_dir, args.ckpt_name) 55 | if args.fp16: 56 | logger.warning('fp16 is turned on! 
' 57 | 'Note that not all CPUs/GPUs support fast fp16 instructions; ' 58 | 'in the worst case, performance will degrade!') 59 | logger.info('model config: %s' % config_fp) 60 | logger.info( 61 | 'checkpoint%s: %s' % ( 62 | ' (overridden by the fine-tuned model)' if args.tuned_model_dir else '', init_checkpoint)) 63 | with tf.gfile.GFile(config_fp, 'r') as f: 64 | bert_config = modeling.BertConfig.from_dict(json.load(f)) 65 | 66 | logger.info('build graph...') 67 | # input placeholders, not sure if they are friendly to XLA 68 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids') 69 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask') 70 | input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_type_ids') 71 | 72 | jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress 73 | 74 | with jit_scope(): 75 | input_tensors = [input_ids, input_mask, input_type_ids] 76 | 77 | model = modeling.BertModel( 78 | config=bert_config, 79 | is_training=False, 80 | input_ids=input_ids, 81 | input_mask=input_mask, 82 | token_type_ids=input_type_ids, 83 | use_one_hot_embeddings=False) 84 | 85 | tvars = tf.trainable_variables() 86 | 87 | (assignment_map, initialized_variable_names 88 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 89 | 90 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 91 | 92 | minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30 93 | mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1) 94 | masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1) 95 | masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / ( 96 | tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10) 97 | 98 | with tf.variable_scope("pooling"): 99 | if len(args.pooling_layer) == 1: 100 | encoder_layer = model.all_encoder_layers[args.pooling_layer[0]] 101 | else: 102 | all_layers = 
[model.all_encoder_layers[l] for l in args.pooling_layer] 103 | encoder_layer = tf.concat(all_layers, -1) 104 | 105 | input_mask = tf.cast(input_mask, tf.float32) 106 | if args.pooling_strategy == PoolingStrategy.REDUCE_MEAN: 107 | pooled = masked_reduce_mean(encoder_layer, input_mask) 108 | elif args.pooling_strategy == PoolingStrategy.REDUCE_MAX: 109 | pooled = masked_reduce_max(encoder_layer, input_mask) 110 | elif args.pooling_strategy == PoolingStrategy.REDUCE_MEAN_MAX: 111 | pooled = tf.concat([masked_reduce_mean(encoder_layer, input_mask), 112 | masked_reduce_max(encoder_layer, input_mask)], axis=1) 113 | elif args.pooling_strategy == PoolingStrategy.FIRST_TOKEN or \ 114 | args.pooling_strategy == PoolingStrategy.CLS_TOKEN: 115 | pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1) 116 | elif args.pooling_strategy == PoolingStrategy.LAST_TOKEN or \ 117 | args.pooling_strategy == PoolingStrategy.SEP_TOKEN: 118 | seq_len = tf.cast(tf.reduce_sum(input_mask, axis=1), tf.int32) 119 | rng = tf.range(0, tf.shape(seq_len)[0]) 120 | indexes = tf.stack([rng, seq_len - 1], 1) 121 | pooled = tf.gather_nd(encoder_layer, indexes) 122 | elif args.pooling_strategy == PoolingStrategy.NONE: 123 | pooled = mul_mask(encoder_layer, input_mask) 124 | else: 125 | raise NotImplementedError() 126 | 127 | if args.fp16: 128 | pooled = tf.cast(pooled, tf.float16) 129 | 130 | pooled = tf.identity(pooled, 'final_encodes') 131 | output_tensors = [pooled] 132 | tmp_g = tf.get_default_graph().as_graph_def() 133 | 134 | with tf.Session(config=config) as sess: 135 | logger.info('load parameters from checkpoint...') 136 | 137 | sess.run(tf.global_variables_initializer()) 138 | dtypes = [n.dtype for n in input_tensors] 139 | logger.info('optimize...') 140 | tmp_g = optimize_for_inference( 141 | tmp_g, 142 | [n.name[:-2] for n in input_tensors], 143 | [n.name[:-2] for n in output_tensors], 144 | [dtype.as_datatype_enum for dtype in dtypes], 145 | False) 146 | 147 | logger.info('freeze...') 148 
| tmp_g = convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors], 149 | use_fp16=args.fp16) 150 | 151 | logger.info('write graph to a tmp file: %s' % pb_file) 152 | with tf.gfile.GFile(pb_file, 'wb') as f: 153 | f.write(tmp_g.SerializeToString()) 154 | return pb_file 155 | except Exception: 156 | logger.error('fail to optimize the graph!', exc_info=True) 157 | 158 | def convert_variables_to_constants(sess, 159 | input_graph_def, 160 | output_node_names, 161 | variable_names_whitelist=None, 162 | variable_names_blacklist=None, 163 | use_fp16=False): 164 | from tensorflow.python.framework.graph_util_impl import extract_sub_graph 165 | from tensorflow.core.framework import graph_pb2 166 | from tensorflow.core.framework import node_def_pb2 167 | from tensorflow.core.framework import attr_value_pb2 168 | from tensorflow.core.framework import types_pb2 169 | from tensorflow.python.framework import tensor_util 170 | 171 | def patch_dtype(input_node, field_name, output_node): 172 | if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT): 173 | output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF)) 174 | 175 | inference_graph = extract_sub_graph(input_graph_def, output_node_names) 176 | 177 | variable_names = [] 178 | variable_dict_names = [] 179 | for node in inference_graph.node: 180 | if node.op in ["Variable", "VariableV2", "VarHandleOp"]: 181 | variable_name = node.name 182 | if ((variable_names_whitelist is not None and 183 | variable_name not in variable_names_whitelist) or 184 | (variable_names_blacklist is not None and 185 | variable_name in variable_names_blacklist)): 186 | continue 187 | variable_dict_names.append(variable_name) 188 | if node.op == "VarHandleOp": 189 | variable_names.append(variable_name + "/Read/ReadVariableOp:0") 190 | else: 191 | variable_names.append(variable_name + ":0") 192 | if variable_names: 193 | returned_variables = 
sess.run(variable_names) 194 | else: 195 | returned_variables = [] 196 | found_variables = dict(zip(variable_dict_names, returned_variables)) 197 | 198 | output_graph_def = graph_pb2.GraphDef() 199 | how_many_converted = 0 200 | for input_node in inference_graph.node: 201 | output_node = node_def_pb2.NodeDef() 202 | if input_node.name in found_variables: 203 | output_node.op = "Const" 204 | output_node.name = input_node.name 205 | dtype = input_node.attr["dtype"] 206 | data = found_variables[input_node.name] 207 | 208 | if use_fp16 and dtype.type == types_pb2.DT_FLOAT: 209 | output_node.attr["value"].CopyFrom( 210 | attr_value_pb2.AttrValue( 211 | tensor=tensor_util.make_tensor_proto(data.astype('float16'), 212 | dtype=types_pb2.DT_HALF, 213 | shape=data.shape))) 214 | else: 215 | output_node.attr["dtype"].CopyFrom(dtype) 216 | output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue( 217 | tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, 218 | shape=data.shape))) 219 | how_many_converted += 1 220 | elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables): 221 | # placeholder nodes 222 | # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"])) 223 | output_node.op = "Identity" 224 | output_node.name = input_node.name 225 | output_node.input.extend([input_node.input[0]]) 226 | output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) 227 | if "_class" in input_node.attr: 228 | output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) 229 | else: 230 | # mostly op nodes 231 | output_node.CopyFrom(input_node) 232 | 233 | patch_dtype(input_node, 'dtype', output_node) 234 | patch_dtype(input_node, 'T', output_node) 235 | patch_dtype(input_node, 'DstT', output_node) 236 | patch_dtype(input_node, 'SrcT', output_node) 237 | patch_dtype(input_node, 'Tparams', output_node) 238 | 239 | if use_fp16 and ('value' in output_node.attr) and ( 240 | output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT): 241 | # 
hard-coded values need to be converted as well 242 | output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue( 243 | tensor=tensor_util.make_tensor_proto( 244 | output_node.attr['value'].tensor.float_val[0], 245 | dtype=types_pb2.DT_HALF))) 246 | 247 | output_graph_def.node.extend([output_node]) 248 | 249 | output_graph_def.library.CopyFrom(inference_graph.library) 250 | return output_graph_def 251 | 252 | 253 | def optimize_ner_model(args, num_labels, logger=None): 254 | """ 255 | Load the Chinese NER model, freeze it into a PB file and return the PB file path 256 | :param args: 257 | :param num_labels: 258 | :param logger: 259 | :return: 260 | """ 261 | if not logger: 262 | logger = set_logger(colored('NER_MODEL, Loading...', 'cyan'), args.verbose) 263 | try: 264 | # If the PB file already exists, return its path; otherwise convert the model into a PB file and return the path where it is stored 265 | if args.model_pb_dir is None: 266 | # use the current working directory 267 | tmp_file = os.path.join(os.getcwd(), 'predict_optimizer') 268 | if not os.path.exists(tmp_file): 269 | os.mkdir(tmp_file) 270 | else: 271 | tmp_file = args.model_pb_dir 272 | pb_file = os.path.join(tmp_file, 'ner_model.pb') 273 | if os.path.exists(pb_file): 274 | print('pb_file exists', pb_file) 275 | return pb_file 276 | 277 | import tensorflow as tf 278 | 279 | graph = tf.Graph() 280 | with graph.as_default(): 281 | with tf.Session() as sess: 282 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids') 283 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask') 284 | 285 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json')) 286 | from bert_base.train.models import create_model 287 | (total_loss, logits, trans, pred_ids) = create_model( 288 | bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=None, 289 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0, lstm_size=args.lstm_size) 290 | pred_ids = tf.identity(pred_ids, 'pred_ids') 291 | saver = tf.train.Saver() 292 | 293 | with 
tf.Session() as sess: 294 | sess.run(tf.global_variables_initializer()) 295 | saver.restore(sess, tf.train.latest_checkpoint(args.model_dir)) 296 | logger.info('freeze...') 297 | from tensorflow.python.framework import graph_util 298 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_ids']) 299 | logger.info('model cut finished !!!') 300 | # write the binary model to file 301 | logger.info('write graph to a tmp file: %s' % pb_file) 302 | with tf.gfile.GFile(pb_file, 'wb') as f: 303 | f.write(tmp_g.SerializeToString()) 304 | return pb_file 305 | except Exception as e: 306 | logger.error('fail to optimize the graph! %s' % e, exc_info=True) 307 | 308 | 309 | def optimize_class_model(args, num_labels, logger=None): 310 | """ 311 | Load the Chinese classification model, freeze it into a PB file and return the PB file path 312 | :param args: 313 | :param num_labels: 314 | :param logger: 315 | :return: 316 | """ 317 | if not logger: 318 | logger = set_logger(colored('CLASSIFICATION_MODEL, Loading...', 'cyan'), args.verbose) 319 | try: 320 | # If the PB file already exists, return its path; otherwise convert the model into a PB file and return the path where it is stored 321 | if args.model_pb_dir is None: 322 | # use the current working directory 323 | tmp_file = os.path.join(os.getcwd(), 'predict_optimizer') 324 | if not os.path.exists(tmp_file): 325 | os.mkdir(tmp_file) 326 | else: 327 | tmp_file = args.model_pb_dir 328 | pb_file = os.path.join(tmp_file, 'classification_model.pb') 329 | if os.path.exists(pb_file): 330 | print('pb_file exists', pb_file) 331 | return pb_file 332 | import tensorflow as tf 333 | 334 | graph = tf.Graph() 335 | with graph.as_default(): 336 | with tf.Session() as sess: 337 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids') 338 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask') 339 | 340 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json')) 341 | from bert_base.train.models import create_classification_model 342 | # For compatibility with multiple inputs, add a segment_id feature (the input_type_ids feature in the training code). 343 | #loss, per_example_loss, logits, 
probabilities = create_classification_model(bert_config=bert_config, is_training=False, 344 | #input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels) 345 | segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids') 346 | loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels) 347 | # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids') 348 | # pred_ids = tf.identity(pred_ids, 'pred_ids') 349 | probabilities = tf.identity(probabilities, 'pred_prob') 350 | saver = tf.train.Saver() 351 | 352 | with tf.Session() as sess: 353 | sess.run(tf.global_variables_initializer()) 354 | saver.restore(sess, tf.train.latest_checkpoint(args.model_dir)) 355 | logger.info('freeze...') 356 | from tensorflow.python.framework import graph_util 357 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob']) 358 | logger.info('predict cut finished !!!') 359 | # write the binary model to file 360 | logger.info('write graph to a tmp file: %s' % pb_file) 361 | with tf.gfile.GFile(pb_file, 'wb') as f: 362 | f.write(tmp_g.SerializeToString()) 363 | return pb_file 364 | except Exception as e: 365 | logger.error('fail to optimize the graph! 
%s' % e, exc_info=True) 366 | 367 | 368 | 369 | -------------------------------------------------------------------------------- /bert_base/server/helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | import uuid 6 | import pickle 7 | import zmq 8 | from zmq.utils import jsonapi 9 | 10 | __all__ = ['set_logger', 'send_ndarray', 'get_args_parser', 11 | 'check_tf_version', 'auto_bind', 'import_tf'] 12 | 13 | 14 | def set_logger(context, verbose=False): 15 | #if os.name == 'nt': # for Windows 16 | # return NTLogger(context, verbose) 17 | 18 | logger = logging.getLogger(context) 19 | logger.setLevel(logging.DEBUG if verbose else logging.INFO) 20 | formatter = logging.Formatter( 21 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt= 22 | '%m-%d %H:%M:%S') 23 | console_handler = logging.StreamHandler() 24 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) 25 | console_handler.setFormatter(formatter) 26 | logger.handlers = [] 27 | logger.addHandler(console_handler) 28 | return logger 29 | 30 | 31 | class NTLogger: 32 | def __init__(self, context, verbose): 33 | self.context = context 34 | self.verbose = verbose 35 | 36 | def info(self, msg, **kwargs): 37 | print('I:%s:%s' % (self.context, msg), flush=True) 38 | 39 | def debug(self, msg, **kwargs): 40 | if self.verbose: 41 | print('D:%s:%s' % (self.context, msg), flush=True) 42 | 43 | def error(self, msg, **kwargs): 44 | print('E:%s:%s' % (self.context, msg), flush=True) 45 | 46 | def warning(self, msg, **kwargs): 47 | print('W:%s:%s' % (self.context, msg), flush=True) 48 | 49 | 50 | def send_ndarray(src, dest, X, req_id=b'', flags=0, copy=True, track=False): 51 | """send a numpy array with metadata""" 52 | # md = dict(dtype=str(X.dtype), shape=X.shape) 53 | if type(X) == list and type(X[0]) == dict: # classification: handle messages sent by the sink 54 | md = 
dict(dtype='json', shape=(len(X[0]['pred_label']), 1)) 55 | elif type(X) == dict: # classification: handle messages sent by a BertWorker 56 | md = dict(dtype='json', shape=(len(X['pred_label']), 1)) 57 | else: 58 | md = dict(dtype='str', shape=(len(X), len(X[0]))) 59 | # print('md', md) 60 | return src.send_multipart([dest, jsonapi.dumps(md), pickle.dumps(X), req_id], flags, copy=copy, track=track) 61 | 62 | 63 | def get_args_parser(): 64 | from . import __version__ 65 | from .graph import PoolingStrategy 66 | 67 | parser = argparse.ArgumentParser() 68 | 69 | group1 = parser.add_argument_group('File Paths', 70 | 'configure the path, checkpoint and filename of a pretrained/fine-tuned BERT model') 71 | 72 | group1.add_argument('-bert_model_dir', type=str, required=True, 73 | help='directory of the pretrained Chinese Google BERT model') 74 | 75 | group1.add_argument('-model_dir', type=str, required=True, 76 | help='directory of a pretrained BERT model') 77 | group1.add_argument('-model_pb_dir', type=str, default=None, 78 | help='directory where the frozen .pb model is stored (for NER/classification it defaults to ./predict_optimizer)') 79 | 80 | group1.add_argument('-tuned_model_dir', type=str, 81 | help='directory of a fine-tuned BERT model') 82 | group1.add_argument('-ckpt_name', type=str, default='bert_model.ckpt', 83 | help='filename of the checkpoint file. By default it is "bert_model.ckpt", but \ 84 | for a fine-tuned model the name could be different.') 85 | group1.add_argument('-config_name', type=str, default='bert_config.json', 86 | help='filename of the JSON config file for BERT model.') 87 | 88 | group2 = parser.add_argument_group('BERT Parameters', 89 | 'configure how the BERT model and pooling work') 90 | group2.add_argument('-max_seq_len', type=int, default=128, 91 | help='maximum length of a sequence') 92 | group2.add_argument('-pooling_layer', type=int, nargs='+', default=[-2], 93 | help='the encoder layer(s) that receive pooling. 
\ 94 | Give a list in order to concatenate several layers into one') 95 | group2.add_argument('-pooling_strategy', type=PoolingStrategy.from_string, 96 | default=PoolingStrategy.REDUCE_MEAN, choices=list(PoolingStrategy), 97 | help='the pooling strategy for generating encoding vectors') 98 | group2.add_argument('-mask_cls_sep', action='store_true', default=False, 99 | help='masking the embedding on [CLS] and [SEP] with zero. \ 100 | When pooling_strategy is in {CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN} \ 101 | then the embedding is preserved, otherwise the embedding is masked to zero before pooling') 102 | group2.add_argument('-lstm_size', type=int, default=128, 103 | help='size of lstm units.') 104 | 105 | group3 = parser.add_argument_group('Serving Configs', 106 | 'configure how the server utilizes GPU/CPU resources') 107 | group3.add_argument('-port', '-port_in', '-port_data', type=int, default=5555, 108 | help='server port for receiving data from client') 109 | group3.add_argument('-port_out', '-port_result', type=int, default=5556, 110 | help='server port for sending result to client') 111 | group3.add_argument('-http_port', type=int, default=None, 112 | help='server port for receiving HTTP requests') 113 | group3.add_argument('-http_max_connect', type=int, default=10, 114 | help='maximum number of concurrent HTTP connections') 115 | group3.add_argument('-cors', type=str, default='*', 116 | help='setting "Access-Control-Allow-Origin" for HTTP requests') 117 | group3.add_argument('-num_worker', type=int, default=1, 118 | help='number of server instances') 119 | group3.add_argument('-max_batch_size', type=int, default=1024, 120 | help='maximum number of sequences handled by each worker') 121 | group3.add_argument('-priority_batch_size', type=int, default=16, 122 | help='batches smaller than this size are labeled as high priority ' 123 | 'and jump forward in the job queue') 124 | group3.add_argument('-cpu', action='store_true', default=False, 125 | help='running 
on CPU (default on GPU)') 126 | group3.add_argument('-xla', action='store_true', default=False, 127 | help='enable XLA compiler (experimental)') 128 | group3.add_argument('-fp16', action='store_true', default=False, 129 | help='use float16 precision (experimental)') 130 | group3.add_argument('-gpu_memory_fraction', type=float, default=0.5, 131 | help='determine the fraction of the overall amount of memory \ 132 | that each visible GPU should be allocated per worker. \ 133 | Should be in range [0.0, 1.0]') 134 | group3.add_argument('-device_map', type=int, nargs='+', default=[], 135 | help='specify the list of GPU device ids that will be used (id starts from 0). \ 136 | If num_worker > len(device_map), then device will be reused; \ 137 | if num_worker < len(device_map), then device_map[:num_worker] will be used') 138 | group3.add_argument('-prefetch_size', type=int, default=10, 139 | help='the number of batches to prefetch on each worker. When running on a CPU-only machine, \ 140 | this is set to 0 for compatibility') 141 | 142 | parser.add_argument('-verbose', action='store_true', default=False, 143 | help='turn on tensorflow logging for debug') 144 | parser.add_argument('-mode', type=str, default='NER') 145 | parser.add_argument('-version', action='version', version='%(prog)s ' + __version__) 146 | return parser 147 | 148 | 149 | def check_tf_version(): 150 | import tensorflow as tf 151 | tf_ver = tf.__version__.split('.') 152 | assert (int(tf_ver[0]), int(tf_ver[1])) >= (1, 10), 'TensorFlow >= 1.10 is required!'
153 | return tf_ver 154 | 155 | 156 | def import_tf(device_id=-1, verbose=False, use_fp16=False): 157 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id) 158 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3' 159 | os.environ['TF_FP16_MATMUL_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1' 160 | os.environ['TF_FP16_CONV_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1' 161 | import tensorflow as tf 162 | tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR) 163 | return tf 164 | 165 | 166 | def auto_bind(socket): 167 | """ 168 | Automatically bind the socket to an available endpoint 169 | :param socket: 170 | :return: 171 | """ 172 | if os.name == 'nt': # for Windows 173 | socket.bind_to_random_port('tcp://127.0.0.1') 174 | else: 175 | # Get the location for tmp file for sockets 176 | try: 177 | tmp_dir = os.environ['ZEROMQ_SOCK_TMP_DIR'] 178 | if not os.path.exists(tmp_dir): 179 | raise ValueError('This directory for sockets ({}) does not seem to exist.'.format(tmp_dir)) 180 | # append a randomly generated socket name 181 | tmp_dir = os.path.join(tmp_dir, str(uuid.uuid1())[:8]) 182 | except KeyError: 183 | tmp_dir = '*' 184 | 185 | socket.bind('ipc://{}'.format(tmp_dir)) 186 | return socket.getsockopt(zmq.LAST_ENDPOINT).decode('ascii') 187 | 188 | 189 | def get_run_args(parser_fn=get_args_parser, printed=True): 190 | args = parser_fn().parse_args() 191 | if printed: 192 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 193 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 194 | return args 195 | 196 | 197 | def get_benchmark_parser(): 198 | parser = get_args_parser() 199 | 200 | parser.set_defaults(num_client=1, client_batch_size=4096) 201 | 202 | group = parser.add_argument_group('Benchmark parameters', 'config the experiments of the benchmark') 203 | 204 | group.add_argument('-test_client_batch_size', type=int, nargs='*', default=[1, 16, 256, 4096]) 205 | group.add_argument('-test_max_batch_size', 
type=int, nargs='*', default=[8, 32, 128, 512]) 206 | group.add_argument('-test_max_seq_len', type=int, nargs='*', default=[32, 64, 128, 256]) 207 | group.add_argument('-test_num_client', type=int, nargs='*', default=[1, 4, 16, 64]) 208 | group.add_argument('-test_pooling_layer', type=int, nargs='*', default=[[-j] for j in range(1, 13)]) 209 | 210 | group.add_argument('-wait_till_ready', type=int, default=30, 211 | help='seconds to wait until server is ready to serve') 212 | group.add_argument('-client_vocab_file', type=str, default='README.md', 213 | help='file path for building client vocabulary') 214 | group.add_argument('-num_repeat', type=int, default=10, 215 | help='number of repeats per experiment (must be > 2), ' 216 | 'as the first two results are omitted for warm-up effect') 217 | return parser 218 | -------------------------------------------------------------------------------- /bert_base/server/http.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | 3 | from termcolor import colored 4 | 5 | from .helper import set_logger 6 | 7 | 8 | class BertHTTPProxy(Process): 9 | def __init__(self, args): 10 | super().__init__() 11 | self.args = args 12 | 13 | def create_flask_app(self): 14 | try: 15 | from flask import Flask, request 16 | from flask_compress import Compress 17 | from flask_cors import CORS 18 | from flask_json import FlaskJSON, as_json, JsonError 19 | from bert_base.client import ConcurrentBertClient 20 | except ImportError: 21 | raise ImportError('BertClient or Flask or its dependencies are not fully installed, ' 22 | 'they are required for serving HTTP requests. '
23 | 'Please use "pip install -U bert-serving-server[http]" to install it.') 24 | 25 | # support up to 10 concurrent HTTP requests 26 | bc = ConcurrentBertClient(max_concurrency=self.args.http_max_connect, 27 | port=self.args.port, port_out=self.args.port_out, 28 | output_fmt='list', mode=self.args.mode) 29 | app = Flask(__name__) 30 | logger = set_logger(colored('PROXY', 'red')) 31 | 32 | @app.route('/status/server', methods=['GET']) 33 | @as_json 34 | def get_server_status(): 35 | return bc.server_status 36 | 37 | @app.route('/status/client', methods=['GET']) 38 | @as_json 39 | def get_client_status(): 40 | return bc.status 41 | 42 | @app.route('/encode', methods=['POST']) 43 | @as_json 44 | def encode_query(): 45 | data = request.form if request.form else request.json 46 | try: 47 | logger.info('new request from %s' % request.remote_addr) 48 | print(data) 49 | return {'id': data['id'], 50 | 'result': bc.encode(data['texts'], is_tokenized=bool( 51 | data['is_tokenized']) if 'is_tokenized' in data else False)} 52 | 53 | except Exception as e: 54 | logger.error('error when handling HTTP request', exc_info=True) 55 | raise JsonError(description=str(e), type=str(type(e).__name__)) 56 | 57 | CORS(app, origins=self.args.cors) 58 | FlaskJSON(app) 59 | Compress().init_app(app) 60 | return app 61 | 62 | def run(self): 63 | app = self.create_flask_app() 64 | app.run(port=self.args.http_port, threaded=True, host='0.0.0.0') 65 | -------------------------------------------------------------------------------- /bert_base/server/simple_flask_http_service.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | #@Time : ${DATE} ${TIME} 6 | # @Author : MaCan (ma_cancan@163.com) 7 | # @File : ${NAME}.py 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | import flask 16 | from 
flask import request, jsonify 17 | import json 18 | import pickle 19 | from datetime import datetime 20 | import tensorflow as tf 21 | from tensorflow import keras as K 22 | import numpy as np 23 | 24 | import sys 25 | sys.path.append('../..') 26 | from bert_base.train.models import create_model, InputFeatures 27 | from bert_base.bert import tokenization, modeling 28 | 29 | 30 | model_dir = r'../../output' 31 | bert_dir = r'H:\models\chinese_L-12_H-768_A-12' 32 | 33 | is_training=False 34 | use_one_hot_embeddings=False 35 | batch_size=1 36 | max_seq_length = 202 37 | 38 | gpu_config = tf.ConfigProto() 39 | gpu_config.gpu_options.allow_growth = True 40 | sess=tf.Session(config=gpu_config) 41 | model=None 42 | 43 | global graph 44 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None 45 | 46 | 47 | print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint"))) 48 | if not os.path.exists(os.path.join(model_dir, "checkpoint")): 49 | raise Exception("failed to get checkpoint.
going to return ") 50 | 51 | # load the label -> id dictionary 52 | with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf: 53 | label2id = pickle.load(rf) 54 | id2label = {value: key for key, value in label2id.items()} 55 | 56 | with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf: 57 | label_list = pickle.load(rf) 58 | num_labels = len(label_list) + 1 59 | 60 | 61 | graph = tf.get_default_graph() 62 | with graph.as_default(): 63 | print("going to restore checkpoint") 64 | #sess.run(tf.global_variables_initializer()) 65 | input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids") 66 | input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask") 67 | 68 | bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json')) 69 | (total_loss, logits, trans, pred_ids) = create_model( 70 | bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None, 71 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0) 72 | 73 | saver = tf.train.Saver() 74 | saver.restore(sess, tf.train.latest_checkpoint(model_dir)) 75 | 76 | tokenizer = tokenization.FullTokenizer( 77 | vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True) 78 | 79 | app = flask.Flask(__name__) 80 | 81 | 82 | @app.route('/ner_predict_service', methods=['GET']) 83 | def ner_predict_service(): 84 | """ 85 | Do online prediction; each call predicts a single instance. 86 | You can change it to batch prediction if you want. 87 | 88 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 89 | :return: 90 | """ 91 | def convert(line): 92 | feature = convert_single_example(0, line, label_list, max_seq_length, tokenizer, 'p') 93 | input_ids = np.reshape([feature.input_ids],(batch_size, max_seq_length)) 94 | input_mask = np.reshape([feature.input_mask],(batch_size, max_seq_length)) 95 | segment_ids = np.reshape([feature.segment_ids],(batch_size, max_seq_length)) 96 | label_ids =np.reshape([feature.label_ids],(batch_size, max_seq_length)) 97 | return input_ids, input_mask, segment_ids, label_ids 98 | 99 | global graph 100 | with graph.as_default(): 101 | result = {} 102 | result['code'] = 0 103 | try: 104 | sentence = request.args['query'] 105 | result['query'] = sentence 106 | start = datetime.now() 107 | if len(sentence) < 2: 108 | print(sentence) 109 | result['data'] = ['O'] * len(sentence) 110 | return json.dumps(result) 111 | sentence = tokenizer.tokenize(sentence) 112 | # print('your input is:{}'.format(sentence)) 113 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 114 | 115 | 116 | feed_dict = {input_ids_p: input_ids, 117 | input_mask_p: input_mask} 118 | # run session get current feed_dict result 119 | pred_ids_result = sess.run([pred_ids], feed_dict) 120 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 121 | print(pred_label_result) 122 | #todo: 组合策略 123 | result['data'] = pred_label_result 124 | print('time used: {} sec'.format((datetime.now() - start).total_seconds())) 125 | return json.dumps(result) 126 | except: 127 | result['code'] = -1 128 | result['data'] = 'error' 129 | return json.dumps(result) 130 | 131 | def online_predict(): 132 | """ 133 | do online prediction. each time make prediction for one instance. 134 | you can change to a batch if you want. 135 | 136 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 137 | :return: 138 | """ 139 | def convert(line): 140 | feature = convert_single_example(0, line, label_list, max_seq_length, tokenizer, 'p') 141 | input_ids = np.reshape([feature.input_ids],(batch_size, max_seq_length)) 142 | input_mask = np.reshape([feature.input_mask],(batch_size, max_seq_length)) 143 | segment_ids = np.reshape([feature.segment_ids],(batch_size, max_seq_length)) 144 | label_ids =np.reshape([feature.label_ids],(batch_size, max_seq_length)) 145 | return input_ids, input_mask, segment_ids, label_ids 146 | 147 | global graph 148 | with graph.as_default(): 149 | 150 | sentence = '北京天安门' 151 | 152 | start = datetime.now() 153 | if len(sentence) < 2: 154 | print(sentence) 155 | 156 | sentence = tokenizer.tokenize(sentence) 157 | # print('your input is:{}'.format(sentence)) 158 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 159 | 160 | 161 | feed_dict = {input_ids_p: input_ids, 162 | input_mask_p: input_mask} 163 | # run session get current feed_dict result 164 | pred_ids_result = sess.run([pred_ids], feed_dict) 165 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 166 | print(pred_label_result) 167 | 168 | print('time used: {} sec'.format((datetime.now() - start).total_seconds())) 169 | 170 | 171 | 172 | 173 | 174 | def convert_id_to_label(pred_ids_result, idx2label): 175 | """ 176 | 将id形式的结果转化为真实序列结果 177 | :param pred_ids_result: 178 | :param idx2label: 179 | :return: 180 | """ 181 | result = [] 182 | for row in range(batch_size): 183 | curr_seq = [] 184 | for ids in pred_ids_result[row][0]: 185 | if ids == 0: 186 | break 187 | curr_label = idx2label[ids] 188 | if curr_label in ['[CLS]', '[SEP]']: 189 | continue 190 | curr_seq.append(curr_label) 191 | result.append(curr_seq) 192 | return result 193 | 194 | 195 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 196 | """ 197 | 将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中 198 
| :param ex_index: index 199 | :param example: 一个样本 200 | :param label_list: 标签列表 201 | :param max_seq_length: 202 | :param tokenizer: 203 | :param mode: 204 | :return: 205 | """ 206 | label_map = {} 207 | # 1表示从1开始对label进行index化 208 | for (i, label) in enumerate(label_list, 1): 209 | label_map[label] = i 210 | # 保存label->index 的map 211 | if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')): 212 | with open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w: 213 | pickle.dump(label_map, w) 214 | 215 | tokens = example 216 | # tokens = tokenizer.tokenize(example.text) 217 | # 序列截断 218 | if len(tokens) >= max_seq_length - 1: 219 | tokens = tokens[0:(max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志 220 | ntokens = [] 221 | segment_ids = [] 222 | label_ids = [] 223 | ntokens.append("[CLS]") # 句子开始设置CLS 标志 224 | segment_ids.append(0) 225 | # append("O") or append("[CLS]") not sure! 226 | label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病 227 | for i, token in enumerate(tokens): 228 | ntokens.append(token) 229 | segment_ids.append(0) 230 | label_ids.append(0) 231 | ntokens.append("[SEP]") # 句尾添加[SEP] 标志 232 | segment_ids.append(0) 233 | # append("O") or append("[SEP]") not sure! 234 | label_ids.append(label_map["[SEP]"]) 235 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) # 将序列中的字(ntokens)转化为ID形式 236 | input_mask = [1] * len(input_ids) 237 | 238 | # padding, 使用 239 | while len(input_ids) < max_seq_length: 240 | input_ids.append(0) 241 | input_mask.append(0) 242 | segment_ids.append(0) 243 | # we don't concerned about it! 
244 |         label_ids.append(0)
245 |         ntokens.append("**NULL**")
246 |         # label_mask.append(0)
247 |     # print(len(input_ids))
248 |     assert len(input_ids) == max_seq_length
249 |     assert len(input_mask) == max_seq_length
250 |     assert len(segment_ids) == max_seq_length
251 |     assert len(label_ids) == max_seq_length
252 |     # assert len(label_mask) == max_seq_length
253 |
254 |     # pack everything into an InputFeatures object
255 |     feature = InputFeatures(
256 |         input_ids=input_ids,
257 |         input_mask=input_mask,
258 |         segment_ids=segment_ids,
259 |         label_ids=label_ids,
260 |         # label_mask = label_mask
261 |     )
262 |     return feature
263 |
264 |
265 | if __name__ == "__main__":
266 |     app.run(host='0.0.0.0', port=12345)
267 |     #online_predict()
268 |
269 |
270 |
--------------------------------------------------------------------------------
/bert_base/server/zmq_decor.py:
--------------------------------------------------------------------------------
1 | from contextlib import ExitStack
2 |
3 | from zmq.decorators import _Decorator
4 |
5 | __all__ = ['multi_socket']
6 |
7 | from functools import wraps
8 |
9 | import zmq
10 |
11 |
12 | class _MyDecorator(_Decorator):
13 |     def __call__(self, *dec_args, **dec_kwargs):
14 |         kw_name, dec_args, dec_kwargs = self.process_decorator_args(*dec_args, **dec_kwargs)
15 |         num_socket_str = dec_kwargs.pop('num_socket')
16 |
17 |         def decorator(func):
18 |             @wraps(func)
19 |             def wrapper(*args, **kwargs):
20 |                 num_socket = getattr(args[0], num_socket_str)
21 |                 targets = [self.get_target(*args, **kwargs) for _ in range(num_socket)]
22 |                 with ExitStack() as stack:
23 |                     for target in targets:
24 |                         obj = stack.enter_context(target(*dec_args, **dec_kwargs))
25 |                         args = args + (obj,)
26 |
27 |                     return func(*args, **kwargs)
28 |
29 |             return wrapper
30 |
31 |         return decorator
32 |
33 |
34 | class _SocketDecorator(_MyDecorator):
35 |     def process_decorator_args(self, *args, **kwargs):
36 |         """Also grab context_name out of kwargs"""
37 |         kw_name, args, kwargs = super(_SocketDecorator, self).process_decorator_args(*args, **kwargs)
38 |         self.context_name = kwargs.pop('context_name', 'context')
39 |         return kw_name, args, kwargs
40 |
41 |     def get_target(self, *args, **kwargs):
42 |         """Get context, based on call-time args"""
43 |         context = self._get_context(*args, **kwargs)
44 |         return context.socket
45 |
46 |     def _get_context(self, *args, **kwargs):
47 |         if self.context_name in kwargs:
48 |             ctx = kwargs[self.context_name]
49 |
50 |             if isinstance(ctx, zmq.Context):
51 |                 return ctx
52 |
53 |         for arg in args:
54 |             if isinstance(arg, zmq.Context):
55 |                 return arg
56 |         # not specified by any decorator
57 |         return zmq.Context.instance()
58 |
59 |
60 | def multi_socket(*args, **kwargs):
61 |     return _SocketDecorator()(*args, **kwargs)
62 |
--------------------------------------------------------------------------------
/bert_base/train/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/30 16:53
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File : __init__.py.py
8 | """
--------------------------------------------------------------------------------
/bert_base/train/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage:     conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | #            README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options:   l: generate LaTeX output for tables like in
6 | #               http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | #            r: accept raw result tags (without B- and I- prefix;
8 | #                                       assumes one word per chunk)
9 | #            d: alternative delimiter tag (default is single space)
10 | #            o: alternative outside tag (default is O)
11 | # note:      the file should contain lines with items separated
12 | #            by $delimiter characters (default space). The final
13 | #            two items should contain the correct tag and the
14 | #            guessed tag in that order. Sentences should be
15 | #            separated from each other by empty lines or lines
16 | #            with $boundary fields (default -X-).
17 | # url:       http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started:   1998-09-25
19 | # version:   2004-01-26
20 | # author:    Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-";      # sentence boundary
28 | my $correct;               # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0;      # number of correctly identified chunks
30 | my $correctTags = 0;       # number of correct chunk tags
31 | my $correctType;           # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " ";       # field delimiter
33 | my $FB1 = 0.0;             # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem;             # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0;      # number of chunks in corpus
36 | my $foundGuessed = 0;      # number of identified chunks
37 | my $guessed;               # current guessed chunk tag
38 | my $guessedType;           # type of current guessed chunk tag
39 | my $i;                     # miscellaneous counter
40 | my $inCorrect = $false;    # currently processed chunk is correct until now
41 | my $lastCorrect = "O";     # previous chunk tag in corpus
42 | my $latex = 0;             # generate LaTeX formatted output
43 | my $lastCorrectType = "";  # type of previous chunk tag in corpus
44 | my $lastGuessed = "O";     # previously identified chunk tag
45 | my $lastGuessedType = "";  # type of previously identified chunk tag
46 | my $lastType;              # temporary storage for detecting duplicates
47 | my $line;                  # line
48 | my $nbrOfFeatures = -1;    # number of features per line
49 | my $precision = 0.0;       # precision score
50 | my $oTag = "O";            # outside tag, default O
51 | my $raw = 0;               # raw input: add B to every token
52 | my $recall = 0.0;          # recall score
53 | my $tokenCounter = 0;      # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = ();     # number of correctly identified chunks per type
56 | my %foundCorrect = ();     # number of chunks in corpus per type
57 | my %foundGuessed = ();     # number of identified chunks per type
58 |
59 | my @features;              # features on line
60 | my @sortedTypes;           # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 |    if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 |    elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 |    elsif ($ARGV[0] eq "-d") {
67 |       shift(@ARGV);
68 |       if (not defined $ARGV[0]) {
69 |          die "conlleval: -d requires delimiter character";
70 |       }
71 |       $delimiter = shift(@ARGV);
72 |    } elsif ($ARGV[0] eq "-o") {
73 |       shift(@ARGV);
74 |       if (not defined $ARGV[0]) {
75 |          die "conlleval: -o requires outside tag";
76 |       }
77 |       $oTag = shift(@ARGV);
78 |    } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 |    chomp($line = $_);
84 |    @features = split(/$delimiter/,$line);
85 |    # @features = split(/\t/,$line);
86 |    if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
87 |    elsif ($nbrOfFeatures != $#features and @features != 0) {
88 |       printf STDERR "unexpected number of features: %d (%d)\n",
89 |          $#features+1,$nbrOfFeatures+1;
90 |       exit(1);
91 |    }
92 |    if (@features == 0 or
93 |        $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
94 |    if (@features < 2) {
95 |       printf STDERR "feature length is %d. \n", scalar(@features);
96 |       die "conlleval: unexpected number of features in line $line\n";
97 |    }
98 |    if ($raw) {
99 |       if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
100 |       if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
101 |       if ($features[$#features] ne "O") {
102 |          $features[$#features] = "B-$features[$#features]";
103 |       }
104 |       if ($features[$#features-1] ne "O") {
105 |          $features[$#features-1] = "B-$features[$#features-1]";
106 |       }
107 |    }
108 |    # 20040126 ET code which allows hyphens in the types
109 |    if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
110 |       $guessed = $1;
111 |       $guessedType = $2;
112 |    } else {
113 |       $guessed = $features[$#features];
114 |       $guessedType = "";
115 |    }
116 |    pop(@features);
117 |    if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
118 |       $correct = $1;
119 |       $correctType = $2;
120 |    } else {
121 |       $correct = $features[$#features];
122 |       $correctType = "";
123 |    }
124 |    pop(@features);
125 | #  ($guessed,$guessedType) = split(/-/,pop(@features));
126 | #  ($correct,$correctType) = split(/-/,pop(@features));
127 |    $guessedType = $guessedType ? $guessedType : "";
128 |    $correctType = $correctType ? $correctType : "";
129 |    $firstItem = shift(@features);
130 |
131 |    # 1999-06-26 sentence breaks should always be counted as out of chunk
132 |    if ( $firstItem eq $boundary ) { $guessed = "O"; }
133 |
134 |    if ($inCorrect) {
135 |       if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
136 |            &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
137 |            $lastGuessedType eq $lastCorrectType) {
138 |          $inCorrect=$false;
139 |          $correctChunk++;
140 |          $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
141 |             $correctChunk{$lastCorrectType}+1 : 1;
142 |       } elsif (
143 |            &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
144 |            &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
145 |            $guessedType ne $correctType ) {
146 |          $inCorrect=$false;
147 |       }
148 |    }
149 |
150 |    if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
151 |         &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
152 |         $guessedType eq $correctType) { $inCorrect = $true; }
153 |
154 |    if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
155 |       $foundCorrect++;
156 |       $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
157 |           $foundCorrect{$correctType}+1 : 1;
158 |    }
159 |    if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
160 |       $foundGuessed++;
161 |       $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
162 |           $foundGuessed{$guessedType}+1 : 1;
163 |    }
164 |    if ( $firstItem ne $boundary ) {
165 |       if ( $correct eq $guessed and $guessedType eq $correctType ) {
166 |          $correctTags++;
167 |       }
168 |       $tokenCounter++;
169 |    }
170 |
171 |    $lastGuessed = $guessed;
172 |    $lastCorrect = $correct;
173 |    $lastGuessedType = $guessedType;
174 |    $lastCorrectType = $correctType;
175 | }
176 | if ($inCorrect) {
177 |    $correctChunk++;
178 |    $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
179 |       $correctChunk{$lastCorrectType}+1 : 1;
180 | }
181 |
182 | if (not $latex) {
183 |    # compute overall precision, recall and FB1 (default values are 0.0)
184 |    $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
185 |    $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
186 |    $FB1 = 2*$precision*$recall/($precision+$recall)
187 |       if ($precision+$recall > 0);
188 |
189 |    # print overall performance
190 |    printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
191 |    printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
192 |    if ($tokenCounter>0) {
193 |       printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
194 |       printf "precision: %6.2f%%; ",$precision;
195 |       printf "recall: %6.2f%%; ",$recall;
196 |       printf "FB1: %6.2f\n",$FB1;
197 |    }
198 | }
199 |
200 | # sort chunk type names
201 | undef($lastType);
202 | @sortedTypes = ();
203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
204 |    if (not($lastType) or $lastType ne $i) {
205 |       push(@sortedTypes,($i));
206 |    }
207 |    $lastType = $i;
208 | }
209 | # print performance per chunk type
210 | if (not $latex) {
211 |    for $i (@sortedTypes) {
212 |       $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
213 |       if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
214 |       else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
215 |       if (not($foundCorrect{$i})) { $recall = 0.0; }
216 |       else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
217 |       if ($precision+$recall == 0.0) { $FB1 = 0.0; }
218 |       else { $FB1 = 2*$precision*$recall/($precision+$recall); }
219 |       printf "%17s: ",$i;
220 |       printf "precision: %6.2f%%; ",$precision;
221 |       printf "recall: %6.2f%%; ",$recall;
222 |       printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
223 |    }
224 | } else {
225 |    print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
226 |    for $i (@sortedTypes) {
227 |       $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
228 |       if (not($foundGuessed{$i})) { $precision = 0.0; }
229 |       else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
230 |       if (not($foundCorrect{$i})) { $recall = 0.0; }
231 |       else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
232 |       if ($precision+$recall == 0.0) { $FB1 = 0.0; }
233 |       else { $FB1 = 2*$precision*$recall/($precision+$recall); }
234 |       printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
235 |              $i,$precision,$recall,$FB1;
236 |    }
237 |    print "\\hline\n";
238 |    $precision = 0.0;
239 |    $recall = 0;
240 |    $FB1 = 0.0;
241 |    $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
242 |    $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
243 |    $FB1 = 2*$precision*$recall/($precision+$recall)
244 |       if ($precision+$recall > 0);
245 |    printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
246 |           $precision,$recall,$FB1;
247 | }
248 |
249 | exit 0;
250 |
251 | # endOfChunk: checks if a chunk ended between the previous and current word
252 | # arguments:  previous and current chunk tags, previous and current types
253 | # note:       this code is capable of handling other chunk representations
254 | #             than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
255 | #             Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
256 |
257 | sub endOfChunk {
258 |    my $prevTag = shift(@_);
259 |    my $tag = shift(@_);
260 |    my $prevType = shift(@_);
261 |    my $type = shift(@_);
262 |    my $chunkEnd = $false;
263 |
264 |    if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
265 |    if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
266 |    if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
267 |    if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
268 |
269 |    if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
270 |    if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
271 |    if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
272 |    if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
273 |
274 |    if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
275 |       $chunkEnd = $true;
276 |    }
277 |
278 |    # corrected 1998-12-22: these chunks are assumed to have length 1
279 |    if ( $prevTag eq "]" ) { $chunkEnd = $true; }
280 |    if ( $prevTag eq "[" ) { $chunkEnd = $true; }
281 |
282 |    return($chunkEnd);
283 | }
284 |
285 | # startOfChunk: checks if a chunk started between the previous and current word
286 | # arguments:    previous and current chunk tags, previous and current types
287 | # note:         this code is capable of handling other chunk representations
288 | #               than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
289 | #               Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
290 |
291 | sub startOfChunk {
292 |    my $prevTag = shift(@_);
293 |    my $tag = shift(@_);
294 |    my $prevType = shift(@_);
295 |    my $type = shift(@_);
296 |    my $chunkStart = $false;
297 |
298 |    if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
299 |    if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
300 |    if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
301 |    if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
302 |
303 |    if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
304 |    if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
305 |    if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
306 |    if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
307 |
308 |    if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
309 |       $chunkStart = $true;
310 |    }
311 |
312 |    # corrected 1998-12-22: these chunks are assumed to have length 1
313 |    if ( $tag eq "[" ) { $chunkStart = $true; }
314 |    if ( $tag eq "]" ) { $chunkStart = $true; }
315 |
316 |    return($chunkStart);
317 | }
--------------------------------------------------------------------------------
/bert_base/train/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | # add function :evaluate(predicted_label, ori_label): which will not read from file
13 |
14 | import sys
15 | import re
16 | import codecs
17 | from collections import defaultdict, namedtuple
18 |
19 | ANY_SPACE = '<SPACE>'
20 |
21 |
22 | class FormatError(Exception):
23 |     pass
24 |
25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
26 |
27 |
28 | class EvalCounts(object):
29 |     def __init__(self):
30 |         self.correct_chunk = 0    # number of correctly identified chunks
31 |         self.correct_tags = 0     # number of correct chunk tags
32 |         self.found_correct = 0    # number of chunks in corpus
33 |         self.found_guessed = 0    # number of identified chunks
34 |         self.token_counter = 0    # token counter (ignores sentence breaks)
35 |
36 |         # counts by type
37 |         self.t_correct_chunk = defaultdict(int)
38 |         self.t_found_correct = defaultdict(int)
39 |         self.t_found_guessed = defaultdict(int)
40 |
41 |
42 | def parse_args(argv):
43 |     import argparse
44 |     parser = argparse.ArgumentParser(
45 |         description='evaluate tagging results using CoNLL criteria',
46 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter
47 |     )
48 |     arg = parser.add_argument
49 |     arg('-b', '--boundary', metavar='STR', default='-X-',
50 |         help='sentence boundary')
51 |     arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
52 |         help='character delimiting items in input')
53 |     arg('-o', '--otag', metavar='CHAR', default='O',
54 |         help='alternative outside tag')
55 |     arg('file', nargs='?', default=None)
56 |     return parser.parse_args(argv)
57 |
58 |
59 | def parse_tag(t):
60 |     m = re.match(r'^([^-]*)-(.*)$', t)
61 |     return m.groups() if m else (t, '')
62 |
63 |
64 | def evaluate(iterable, options=None):
65 |     if options is None:
66 |         options = parse_args([])    # use defaults
67 |
68 |     counts = EvalCounts()
69 |     num_features = None       # number of features per line
70 |     in_correct = False        # currently processed chunk is correct until now
71 |     last_correct = 'O'        # previous chunk tag in corpus
72 |     last_correct_type = ''    # type of previous chunk tag in corpus
73 |     last_guessed = 'O'        # previously identified chunk tag
74 |     last_guessed_type = ''    # type of previously identified chunk tag
75 |
76 |     for line in iterable:
77 |         line = line.rstrip('\r\n')
78 |
79 |         if options.delimiter == ANY_SPACE:
80 |             features = line.split()
81 |         else:
82 |             features = line.split(options.delimiter)
83 |
84 |         if num_features is None:
85 |             num_features = len(features)
86 |         elif num_features != len(features) and len(features) != 0:
87 |             raise FormatError('unexpected number of features: %d (%d)' %
88 |                               (len(features), num_features))
89 |
90 |         if len(features) == 0 or features[0] == options.boundary:
91 |             features = [options.boundary, 'O', 'O']
92 |         if len(features) < 3:
93 |             raise FormatError('unexpected number of features in line %s' % line)
94 |
95 |         guessed, guessed_type = parse_tag(features.pop())
96 |         correct, correct_type = parse_tag(features.pop())
97 |         first_item = features.pop(0)
98 |
99 |         if first_item == options.boundary:
100 |             guessed = 'O'
101 |
102 |         end_correct = end_of_chunk(last_correct, correct,
103 |                                    last_correct_type, correct_type)
104 |         end_guessed = end_of_chunk(last_guessed, guessed,
105 |                                    last_guessed_type, guessed_type)
106 |         start_correct = start_of_chunk(last_correct, correct,
107 |                                        last_correct_type, correct_type)
108 |         start_guessed = start_of_chunk(last_guessed, guessed,
109 |                                        last_guessed_type, guessed_type)
110 |
111 |         if in_correct:
112 |             if (end_correct and end_guessed and
113 |                 last_guessed_type == last_correct_type):
114 |                 in_correct = False
115 |                 counts.correct_chunk += 1
116 |                 counts.t_correct_chunk[last_correct_type] += 1
117 |             elif (end_correct != end_guessed or guessed_type != correct_type):
118 |                 in_correct = False
119 |
120 |         if start_correct and start_guessed and guessed_type == correct_type:
121 |             in_correct = True
122 |
123 |         if start_correct:
124 |             counts.found_correct += 1
125 |             counts.t_found_correct[correct_type] += 1
126 |         if start_guessed:
127 |             counts.found_guessed += 1
128 |             counts.t_found_guessed[guessed_type] += 1
129 |         if first_item != options.boundary:
130 |             if correct == guessed and guessed_type == correct_type:
131 |                 counts.correct_tags += 1
132 |             counts.token_counter += 1
133 |
134 |         last_guessed = guessed
135 |         last_correct = correct
136 |         last_guessed_type = guessed_type
137 |         last_correct_type = correct_type
138 |
139 |     if in_correct:
140 |         counts.correct_chunk += 1
141 |         counts.t_correct_chunk[last_correct_type] += 1
142 |
143 |     return counts
144 |
145 |
146 |
147 | def uniq(iterable):
148 |     seen = set()
149 |     return [i for i in iterable if not (i in seen or seen.add(i))]
150 |
151 |
152 | def calculate_metrics(correct, guessed, total):
153 |     tp, fp, fn = correct, guessed-correct, total-correct
154 |     p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
155 |     r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
156 |     f = 0 if p + r == 0 else 2 * p * r / (p + r)
157 |     return Metrics(tp, fp, fn, p, r, f)
158 |
159 |
160 | def metrics(counts):
161 |     c = counts
162 |     overall = calculate_metrics(
163 |         c.correct_chunk, c.found_guessed, c.found_correct
164 |     )
165 |     by_type = {}
166 |     for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
167 |         by_type[t] = calculate_metrics(
168 |             c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
169 |         )
170 |     return overall, by_type
171 |
172 |
173 | def report(counts, out=None):
174 |     if out is None:
175 |         out = sys.stdout
176 |
177 |     overall, by_type = metrics(counts)
178 |
179 |     c = counts
180 |     out.write('processed %d tokens with %d phrases; ' %
181 |               (c.token_counter, c.found_correct))
182 |     out.write('found: %d phrases; correct: %d.\n' %
183 |               (c.found_guessed, c.correct_chunk))
184 |
185 |     if c.token_counter > 0:
186 |         out.write('accuracy: %6.2f%%; ' %
187 |                   (100.*c.correct_tags/c.token_counter))
188 |         out.write('precision: %6.2f%%; ' % (100.*overall.prec))
189 |         out.write('recall: %6.2f%%; ' % (100.*overall.rec))
190 |         out.write('FB1: %6.2f\n' % (100.*overall.fscore))
191 |
192 |     for i, m in sorted(by_type.items()):
193 |         out.write('%17s: ' % i)
194 |         out.write('precision: %6.2f%%; ' % (100.*m.prec))
195 |         out.write('recall: %6.2f%%; ' % (100.*m.rec))
196 |         out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
197 |
198 |
199 | def report_notprint(counts, out=None):
200 |     if out is None:
201 |         out = sys.stdout
202 |
203 |     overall, by_type = metrics(counts)
204 |
205 |     c = counts
206 |     final_report = []
207 |     line = []
208 |     line.append('processed %d tokens with %d phrases; ' %
209 |                 (c.token_counter, c.found_correct))
210 |     line.append('found: %d phrases; correct: %d.\n' %
211 |                 (c.found_guessed, c.correct_chunk))
212 |     final_report.append("".join(line))
213 |
214 |     if c.token_counter > 0:
215 |         line = []
216 |         line.append('accuracy: %6.2f%%; ' %
217 |                     (100.*c.correct_tags/c.token_counter))
218 |         line.append('precision: %6.2f%%; ' % (100.*overall.prec))
219 |         line.append('recall: %6.2f%%; ' % (100.*overall.rec))
220 |         line.append('FB1: %6.2f\n' % (100.*overall.fscore))
221 |         final_report.append("".join(line))
222 |
223 |     for i, m in sorted(by_type.items()):
224 |         line = []
225 |         line.append('%17s: ' % i)
226 |         line.append('precision: %6.2f%%; ' % (100.*m.prec))
227 |         line.append('recall: %6.2f%%; ' % (100.*m.rec))
228 |         line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
229 |         final_report.append("".join(line))
230 |     return final_report
231 |
232 |
233 | def end_of_chunk(prev_tag, tag, prev_type, type_):
234 |     # check if a chunk ended between the previous and current word
235 |     # arguments: previous and current chunk tags, previous and current types
236 |     chunk_end = False
237 |
238 |     if prev_tag == 'E': chunk_end = True
239 |     if prev_tag == 'S': chunk_end = True
240 |
241 |     if prev_tag == 'B' and tag == 'B': chunk_end = True
242 |     if prev_tag == 'B' and tag == 'S': chunk_end = True
243 |     if prev_tag == 'B' and tag == 'O': chunk_end = True
244 |     if prev_tag == 'I' and tag == 'B': chunk_end = True
245 |     if prev_tag == 'I' and tag == 'S': chunk_end = True
246 |     if prev_tag == 'I' and tag == 'O': chunk_end = True
247 |
248 |     if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
249 |         chunk_end = True
250 |
251 |     # these chunks are assumed to have length 1
252 |     if prev_tag == ']': chunk_end = True
253 |     if prev_tag == '[': chunk_end = True
254 |
255 |     return chunk_end
256 |
257 |
258 | def start_of_chunk(prev_tag, tag, prev_type, type_):
259 |     # check if a chunk started between the previous and current word
260 |     # arguments: previous and current chunk tags, previous and current types
261 |     chunk_start = False
262 |
263 |     if tag == 'B': chunk_start = True
264 |     if tag == 'S': chunk_start = True
265 |
266 |     if prev_tag == 'E' and tag == 'E': chunk_start = True
267 |     if prev_tag == 'E' and tag == 'I': chunk_start = True
268 |     if prev_tag == 'S' and tag == 'E': chunk_start = True
269 |     if prev_tag == 'S' and tag == 'I': chunk_start = True
270 |     if prev_tag == 'O' and tag == 'E': chunk_start = True
271 |     if prev_tag == 'O' and tag == 'I': chunk_start = True
272 |
273 |     if tag != 'O' and tag != '.' and prev_type != type_:
274 |         chunk_start = True
275 |
276 |     # these chunks are assumed to have length 1
277 |     if tag == '[': chunk_start = True
278 |     if tag == ']': chunk_start = True
279 |
280 |     return chunk_start
281 |
282 |
283 | def return_report(input_file):
284 |     with codecs.open(input_file, "r", "utf8") as f:
285 |         counts = evaluate(f)
286 |     return report_notprint(counts)
287 |
288 |
289 | def main(argv):
290 |     args = parse_args(argv[1:])
291 |
292 |     if args.file is None:
293 |         counts = evaluate(sys.stdin, args)
294 |     else:
295 |         with open(args.file) as f:
296 |             counts = evaluate(f, args)
297 |     report(counts)
298 |
299 | if __name__ == '__main__':
300 |     sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
/bert_base/train/lstm_crf_layer.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | """
4 | bert-blstm-crf layer
5 | @Author:Macan
6 | """
7 |
8 | import tensorflow as tf
9 | from tensorflow.contrib import rnn
10 | from tensorflow.contrib import crf
11 |
12 |
13 | class BLSTM_CRF(object):
14 |     def __init__(self, embedded_chars, hidden_unit, cell_type, num_layers, dropout_rate,
15 |                  initializers, num_labels, seq_length, labels, lengths, is_training):
16 |         """
17 |         BLSTM-CRF network
18 |         :param embedded_chars: Fine-tuning embedding input
19 |         :param hidden_unit: number of LSTM hidden units
20 |         :param cell_type: RNN type (LSTM or GRU; DICNN will be added in the future)
21 |         :param num_layers: number of RNN layers
22 |         :param dropout_rate: dropout rate, used as the keep probability
23 |         :param initializers: variable init class
24 |         :param num_labels: number of labels
25 |         :param seq_length: maximum sequence length
26 |         :param labels: gold labels
27 |         :param lengths: [batch_size] actual length of each sequence in the batch
28 |         :param is_training: whether this is the training phase
29 |         """
30 |         self.hidden_unit = hidden_unit
31 |         self.dropout_rate = dropout_rate
32 |         self.cell_type = cell_type
33 |         self.num_layers = num_layers
34 |         self.embedded_chars = embedded_chars
35 |         self.initializers = initializers
36 |         self.seq_length = seq_length
37 |         self.num_labels = num_labels
38 |         self.labels = labels
39 |         self.lengths = lengths
40 |         self.embedding_dims = embedded_chars.shape[-1].value
41 |         self.is_training = is_training
42 |
43 |     def add_blstm_crf_layer(self, crf_only):
44 |         """
45 |         blstm-crf network
46 |         :return:
47 |         """
48 |         if self.is_training:
49 |             # for the LSTM input, a dropout rate (keep prob) of 0.9 gave the best score
50 |             self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate)
51 |
52 |         if crf_only:
53 |             logits = self.project_crf_layer(self.embedded_chars)
54 |         else:
55 |             # blstm
56 |             lstm_output = self.blstm_layer(self.embedded_chars)
57 |             # project
58 |             logits = self.project_bilstm_layer(lstm_output)
59 |         # crf
60 |         loss, trans = self.crf_layer(logits)
61 |         # CRF decode; pred_ids is the highest-probability label path
62 |         pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths)
63 |         return (loss, logits, trans, pred_ids)
64 |
65 |     def _witch_cell(self):
66 |         """
67 |         RNN type
68 |         :return:
69 |         """
        cell_tmp = None
        if self.cell_type == 'lstm':
            cell_tmp = rnn.LSTMCell(self.hidden_unit)
        elif self.cell_type == 'gru':
            cell_tmp = rnn.GRUCell(self.hidden_unit)
        return cell_tmp

    def _bi_dir_rnn(self):
        """
        Build one forward and one backward RNN cell
        :return:
        """
        cell_fw = self._witch_cell()
        cell_bw = self._witch_cell()
        if self.dropout_rate is not None:
            cell_bw = rnn.DropoutWrapper(cell_bw, output_keep_prob=self.dropout_rate)
            cell_fw = rnn.DropoutWrapper(cell_fw, output_keep_prob=self.dropout_rate)
        return cell_fw, cell_bw

    def blstm_layer(self, embedding_chars):
        """
        Bidirectional RNN over the input embeddings
        :return: [batch_size, num_steps, 2 * hidden_unit]
        """
        with tf.variable_scope('rnn_layer'):
            if self.num_layers > 1:
                # each layer needs its own cell objects: [cell] * num_layers would reuse
                # one cell for layers whose input sizes differ
                layers = [self._bi_dir_rnn() for _ in range(self.num_layers)]
                cell_fw = rnn.MultiRNNCell([fw for fw, _ in layers], state_is_tuple=True)
                cell_bw = rnn.MultiRNNCell([bw for _, bw in layers], state_is_tuple=True)
            else:
                cell_fw, cell_bw = self._bi_dir_rnn()

            outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedding_chars,
                                                         dtype=tf.float32)
            outputs = tf.concat(outputs, axis=2)
        return outputs

    def project_bilstm_layer(self, lstm_outputs, name=None):
        """
        hidden layer between lstm layer and logits
        :param lstm_outputs: [batch_size, num_steps, 2 * hidden_unit]
        :return: [batch_size, num_steps, num_tags]
        """
        with tf.variable_scope("project" if not name else name):
            with tf.variable_scope("hidden"):
                W = tf.get_variable("W", shape=[self.hidden_unit * 2, self.hidden_unit],
                                    dtype=tf.float32, initializer=self.initializers.xavier_initializer())

                b = tf.get_variable("b", shape=[self.hidden_unit], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())
                output = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2])
                hidden = tf.nn.xw_plus_b(output, W, b)

            # project to score of tags
            with tf.variable_scope("logits"):
                W = tf.get_variable("W", shape=[self.hidden_unit, self.num_labels],
                                    dtype=tf.float32, initializer=self.initializers.xavier_initializer())

                b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())

                pred = tf.nn.xw_plus_b(hidden, W, b)
            return tf.reshape(pred, [-1, self.seq_length, self.num_labels])

    def project_crf_layer(self, embedding_chars, name=None):
        """
        hidden layer between input layer and logits
        :param embedding_chars: [batch_size, num_steps, embedding_dims]
        :return: [batch_size, num_steps, num_tags]
        """
        with tf.variable_scope("project" if not name else name):
            with tf.variable_scope("logits"):
                W = tf.get_variable("W", shape=[self.embedding_dims, self.num_labels],
                                    dtype=tf.float32, initializer=self.initializers.xavier_initializer())

                b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())
                output = tf.reshape(embedding_chars,
                                    shape=[-1, self.embedding_dims])  # [batch_size * num_steps, embedding_dims]
                pred = tf.tanh(tf.nn.xw_plus_b(output, W, b))
            return tf.reshape(pred, [-1, self.seq_length, self.num_labels])

    def crf_layer(self, logits):
        """
        calculate crf loss
        :param logits: [batch_size, num_steps, num_tags]
        :return: scalar loss (None when no labels are given), transition matrix
        """
        with tf.variable_scope("crf_loss"):
            trans = tf.get_variable(
                "transitions",
                shape=[self.num_labels, self.num_labels],
                initializer=self.initializers.xavier_initializer())
            if self.labels is None:
                return None, trans
            else:
                log_likelihood, trans = tf.contrib.crf.crf_log_likelihood(
                    inputs=logits,
                    tag_indices=self.labels,
                    transition_params=trans,
                    sequence_lengths=self.lengths)
                return tf.reduce_mean(-log_likelihood), trans
--------------------------------------------------------------------------------
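The `crf_layer` and `add_blstm_crf_layer` methods in lstm_crf_layer.py above rely on `tf.contrib.crf.crf_log_likelihood` for the training loss and `crf.crf_decode` for Viterbi decoding. As a rough illustration of what those two calls compute for one unpadded sequence, here is a pure-Python sketch. It is an assumption-based reimplementation for reading purposes, not TensorFlow's code: `unary` stands in for one sequence of `logits` with shape [num_steps, num_tags], `trans` for the [num_tags, num_tags] transition matrix (`trans[a][b]` scores moving from tag a to tag b), and the helper names `path_score`, `log_partition`, `log_likelihood` and `viterbi_decode` are hypothetical. Batching, `sequence_lengths` masking and gradients are left out.

```python
import math
from typing import List, Sequence, Tuple

# Illustrative sketch only: single sequence, no padding, no batching.
Matrix = Sequence[Sequence[float]]


def _logsumexp(values: Sequence[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(unary: Matrix, trans: Matrix, tags: Sequence[int]) -> float:
    """Unnormalised score of one tag path: emission scores plus transition scores."""
    score = sum(unary[t][tag] for t, tag in enumerate(tags))
    score += sum(trans[a][b] for a, b in zip(tags, tags[1:]))
    return score


def log_partition(unary: Matrix, trans: Matrix) -> float:
    """log Z: log-sum-exp of path_score over every tag path (forward algorithm)."""
    num_tags = len(unary[0])
    alpha = list(unary[0])  # alpha[j]: log-sum of scores of all prefixes ending in tag j
    for t in range(1, len(unary)):
        alpha = [_logsumexp([alpha[i] + trans[i][j] for i in range(num_tags)]) + unary[t][j]
                 for j in range(num_tags)]
    return _logsumexp(alpha)


def log_likelihood(unary: Matrix, trans: Matrix, tags: Sequence[int]) -> float:
    """log p(tags | inputs); the CRF training loss is the batch mean of its negative."""
    return path_score(unary, trans, tags) - log_partition(unary, trans)


def viterbi_decode(unary: Matrix, trans: Matrix) -> Tuple[List[int], float]:
    """Highest-scoring tag path and its score (max-product version of the forward pass)."""
    num_tags = len(unary[0])
    best = list(unary[0])
    backpointers = []
    for t in range(1, len(unary)):
        ptrs, new_best = [], []
        for j in range(num_tags):
            i_best = max(range(num_tags), key=lambda i: best[i] + trans[i][j])
            ptrs.append(i_best)
            new_best.append(best[i_best] + trans[i_best][j] + unary[t][j])
        backpointers.append(ptrs)
        best = new_best
    last = max(range(num_tags), key=lambda j: best[j])
    path = [last]
    for ptrs in reversed(backpointers):  # follow the back-pointers from the end
        path.append(ptrs[path[-1]])
    path.reverse()
    return path, best[last]


if __name__ == "__main__":
    unary = [[1.0, 0.2, 0.0], [0.1, 1.5, 0.3], [0.4, 0.2, 0.9]]  # made-up scores, 3 steps x 3 tags
    trans = [[0.5, -0.2, 0.1], [0.0, 0.3, -0.5], [0.2, 0.1, 0.4]]
    path, score = viterbi_decode(unary, trans)
    print(path, score, log_likelihood(unary, trans, path))
```

In this reading, `crf_layer` returns the batch mean of `-log_likelihood(...)`, and `crf_decode` yields the per-sequence equivalent of `viterbi_decode(...)`. Viterbi replaces the log-sum-exp of the forward algorithm with a max, which is why both share the same loop structure.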
/bert_base/train/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Shared model code
@Time : 2019/1/30 12:46
@Author : MaCan (ma_cancan@163.com)
@File : models.py
"""

from bert_base.train.lstm_crf_layer import BLSTM_CRF
from tensorflow.contrib.layers.python.layers import initializers


__all__ = ['InputExample', 'InputFeatures', 'decode_labels', 'create_model', 'convert_id_str',
           'convert_id_to_label', 'result_to_json', 'create_classification_model']

class Model(object):
    def __init__(self, *args, **kwargs):
        pass


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid=None, text=None, label=None):
        """Constructs an InputExample.
        Args:
          guid: Unique id for the example.
          text: string. The untokenized text of the sequence.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text = text
        self.label = label

class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        # self.label_mask = label_mask


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()


def create_model(bert_config, is_training, input_ids, input_mask,
                 segment_ids, labels, num_labels, use_one_hot_embeddings,
                 dropout_rate=1.0, lstm_size=1, cell='lstm', num_layers=1):
    """
    Build the NER model (BERT encoder + CRF output layer)
    :param bert_config: BERT config
    :param is_training:
    :param input_ids: input token ids
    :param input_mask:
    :param segment_ids:
    :param labels: label ids
    :param num_labels: number of label classes
    :param use_one_hot_embeddings:
    :return: (loss, logits, trans, pred_ids)
    """
    # run the inputs through BertModel to get the per-character embeddings
    import tensorflow as tf
    from bert_base.bert import modeling
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings
    )
    # sequence embeddings, shape [batch_size, seq_length, embedding_size]
    embedding = model.get_sequence_output()
    max_seq_length = embedding.shape[1].value
    # compute the true (unpadded) length of each sequence
    used = tf.sign(tf.abs(input_ids))
    lengths = tf.reduce_sum(used, reduction_indices=1)  # [batch_size] vector holding the length of each sequence in the batch
    # add the CRF output layer
    blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=lstm_size, cell_type=cell, num_layers=num_layers,
                          dropout_rate=dropout_rate, initializers=initializers, num_labels=num_labels,
                          seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
    # note: crf_only=True projects the BERT output straight to the CRF and skips the BiLSTM
    rst = blstm_crf.add_blstm_crf_layer(crf_only=True)
    return rst


def create_classification_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels):
    """
    BERT text classifier on top of the pooled output
    :param bert_config:
    :param is_training:
    :param input_ids:
    :param input_mask:
    :param segment_ids:
    :param labels:
    :param num_labels:
    :return: (loss, per_example_loss, logits, probabilities)
    """
    import tensorflow as tf
    from bert_base.bert import modeling
    # encode the input data with BERT
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
    )

    embedding_layer = model.get_sequence_output()
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value

    # predict = CNN_Classification(embedding_chars=embedding_layer,
    #                              labels=labels,
    #                              num_tags=num_labels,
    #                              sequence_length=FLAGS.max_seq_length,
    #                              embedding_dims=embedding_layer.shape[-1].value,
    #                              vocab_size=0,
    #                              filter_sizes=[3, 4, 5],
    #                              num_filters=3,
    #                              dropout_keep_prob=FLAGS.dropout_keep_prob,
    #                              l2_reg_lambda=0.001)
    # loss, predictions, probabilities = predict.add_cnn_layer()

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        if is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        if labels is not None:
            one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
            loss = tf.reduce_mean(per_example_loss)
        else:
            loss, per_example_loss = None, None
    return (loss, per_example_loss, logits, probabilities)


def decode_labels(labels, batch_size):
    new_labels = []
    for row in range(batch_size):
        label = []
        for i in labels[row]:
            i = i.decode('utf-8')
            if i == '**PAD**':
                break
            if i in ['[CLS]', '[SEP]']:
                continue
            label.append(i)
        new_labels.append(label)
    return new_labels


def convert_id_str(input_ids, batch_size):
    res = []
    for row in range(batch_size):
        line = []
        for i in input_ids[row]:
            i = i.decode('utf-8')
            if i == '**PAD**':
                break
            if i in ['[CLS]', '[SEP]']:
                continue

            line.append(i)
        res.append(line)
    return res


def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """
    Convert the id-form predictions into label sequences
    :param pred_ids_result:
    :param idx2label:
    :param batch_size:
    :return: (label sequences, label id sequences)
    """
    result = []
    index_result = []
    for row in range(batch_size):
        curr_seq = []
        curr_idx = []
        ids = pred_ids_result[row]
        for idx, label_id in enumerate(ids):
            if label_id == 0:
                break
            curr_label = idx2label[label_id]
            if curr_label in ['[CLS]', '[SEP]']:
                # guard the look-ahead so ids[idx + 1] stays in range
                if label_id == 102 and (idx + 1 < len(ids) and ids[idx + 1] == 0):
                    break
                continue
            # elif curr_label == '[SEP]':
            #     break
            curr_seq.append(curr_label)
            curr_idx.append(label_id)
        result.append(curr_seq)
        index_result.append(curr_idx)
    return result, index_result


def result_to_json(self, string, tags):
    """
    Combine the input sequence with the model's tag sequence into entities
    :param self: collector object exposing append(word, start, end, type)
    :param string: input sequence
    :param tags: predicted tags
    :return: {"entities": [...]}
    """
    item = {"entities": []}
    entity_name = ""
    entity_start = 0
    idx = 0
    last_tag = ''

    for char, tag in zip(string, tags):
        if tag[0] == "S":
            self.append(char, idx, idx+1, tag[2:])
            item["entities"].append({"word": char, "start": idx, "end": idx+1, "type": tag[2:]})
        elif tag[0] == "B":
            if entity_name != '':
                self.append(entity_name, entity_start, idx, last_tag[2:])
                item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                entity_name = ""
            entity_name += char
            entity_start = idx
        elif tag[0] == "I":
            entity_name += char
        elif tag[0] == "O":
            if entity_name != '':
                self.append(entity_name, entity_start, idx, last_tag[2:])
                item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                entity_name = ""
        else:
            entity_name = ""
            entity_start = idx
        idx += 1
        last_tag = tag
    if entity_name != '':
        self.append(entity_name, entity_start, idx, last_tag[2:])
        item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
    return item
--------------------------------------------------------------------------------
/bert_base/train/tf_metrics.py:
--------------------------------------------------------------------------------
"""
Multiclass
from:
https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py

"""

__author__ = "Guillaume Genthial"

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix

__all__ = ['precision', 'recall', 'f1', 'fbeta', 'safe_div', 'pr_re_fbeta', 'metrics_from_confusion_matrix']


def precision(labels, predictions, num_classes, pos_indices=None,
              weights=None, average='micro'):
    """Multi-class precision metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infers the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    pr, _, _ = metrics_from_confusion_matrix(
        cm, pos_indices, average=average)
    op, _, _ = metrics_from_confusion_matrix(
        op, pos_indices, average=average)
    return (pr, op)


def recall(labels, predictions, num_classes, pos_indices=None, weights=None,
           average='micro'):
    """Multi-class recall metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infers the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    _, re, _ = metrics_from_confusion_matrix(
        cm, pos_indices, average=average)
    _, op, _ = metrics_from_confusion_matrix(
        op, pos_indices, average=average)
    return (re, op)


def f1(labels, predictions, num_classes, pos_indices=None, weights=None,
       average='micro'):
    return fbeta(labels, predictions, num_classes, pos_indices, weights,
                 average)


def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
          average='micro', beta=1):
    """Multi-class fbeta metric for Tensorflow
    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infers the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    beta : int, optional
        Weight of recall relative to precision in the harmonic mean
        (beta > 1 favours recall)
    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    cm, op = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    _, _, fbeta = metrics_from_confusion_matrix(
        cm, pos_indices, average=average, beta=beta)
    _, _, op = metrics_from_confusion_matrix(
        op, pos_indices, average=average, beta=beta)
    return (fbeta, op)


def safe_div(numerator, denominator):
    """Safe division, return 0 if denominator is 0"""
    numerator, denominator = tf.to_float(numerator), tf.to_float(denominator)
    zeros = tf.zeros_like(numerator, dtype=numerator.dtype)
    denominator_is_zero = tf.equal(denominator, zeros)
    return tf.where(denominator_is_zero, zeros, numerator / denominator)


def pr_re_fbeta(cm, pos_indices, beta=1):
    """Uses a confusion matrix to compute precision, recall and fbeta"""
    num_classes = cm.shape[0]
    neg_indices = [i for i in range(num_classes) if i not in pos_indices]
    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, neg_indices] = 0
    diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[:, neg_indices] = 0
    tot_pred = tf.reduce_sum(cm * cm_mask)

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, :] = 0
    tot_gold = tf.reduce_sum(cm * cm_mask)

    pr = safe_div(diag_sum, tot_pred)
    re = safe_div(diag_sum, tot_gold)
    fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re)

    return pr, re, fbeta


def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
                                  beta=1):
    """Precision, Recall and F1 from the confusion matrix
    Parameters
    ----------
    cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes)
        The streaming confusion matrix.
    pos_indices : list of int, optional
        The indices of the positive classes
    beta : int, optional
        Weight of recall relative to precision in the harmonic mean
        (beta > 1 favours recall)
    average : str, optional
        'micro', 'macro' or 'weighted'
    """
    num_classes = cm.shape[0]
    if pos_indices is None:
        pos_indices = [i for i in range(num_classes)]

    if average == 'micro':
        return pr_re_fbeta(cm, pos_indices, beta)
    elif average in {'macro', 'weighted'}:
        precisions, recalls, fbetas, n_golds = [], [], [], []
        for idx in pos_indices:
            pr, re, fbeta = pr_re_fbeta(cm, [idx], beta)
            precisions.append(pr)
            recalls.append(re)
            fbetas.append(fbeta)
            cm_mask = np.zeros([num_classes, num_classes])
            cm_mask[idx, :] = 1
            n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))

        if average == 'macro':
            pr = tf.reduce_mean(precisions)
            re = tf.reduce_mean(recalls)
            fbeta = tf.reduce_mean(fbetas)
            return pr, re, fbeta
        if average == 'weighted':
            n_gold = tf.reduce_sum(n_golds)
            pr_sum = sum(p * n for p, n in zip(precisions, n_golds))
            pr = safe_div(pr_sum, n_gold)
            re_sum = sum(r * n for r, n in zip(recalls, n_golds))
            re = safe_div(re_sum, n_gold)
            fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds))
            fbeta = safe_div(fbeta_sum, n_gold)
            return pr, re, fbeta

    else:
        raise NotImplementedError()
--------------------------------------------------------------------------------
/bert_base/train/train_helper.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""

@Time : 2019/1/30 14:01
@Author : MaCan (ma_cancan@163.com)
@File : train_helper.py
"""

import argparse
import os

__all__ = ['get_args_parser']

def get_args_parser():
    from .bert_lstm_ner import __version__
    parser = argparse.ArgumentParser()
    if os.name == 'nt':
        bert_path = r'F:\chinese_L-12_H-768_A-12'
        root_path = r'C:\workspace\python\BERT-BiLSTM-CRF-NER'
    else:
        bert_path = '/home/macan/ml/data/chinese_L-12_H-768_A-12/'
        root_path = '/home/macan/ml/workspace/BERT-BiLSTM-CRF-NER'

    group1 = parser.add_argument_group('File Paths',
                                       'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
    group1.add_argument('-data_dir', type=str, default=os.path.join(root_path, 'NERdata'),
                        help='train, dev and test data dir')
    group1.add_argument('-bert_config_file', type=str, default=os.path.join(bert_path, 'bert_config.json'))
    group1.add_argument('-output_dir', type=str, default=os.path.join(root_path, 'output'),
                        help='output directory for model checkpoints and other training outputs')
    group1.add_argument('-init_checkpoint', type=str, default=os.path.join(bert_path, 'bert_model.ckpt'),
                        help='Initial checkpoint (usually from a pre-trained BERT model).')
    group1.add_argument('-vocab_file', type=str, default=os.path.join(bert_path, 'vocab.txt'),
                        help='vocabulary file of the pre-trained BERT model')

    group2 = parser.add_argument_group('Model Config', 'config the model params')
    group2.add_argument('-max_seq_length', type=int, default=202,
                        help='The maximum total input sequence length after WordPiece tokenization.')
    group2.add_argument('-do_train', action='store_false', default=True,
                        help='Whether to run training (on by default; passing this flag turns it off).')
    group2.add_argument('-do_eval', action='store_false', default=True,
                        help='Whether to run eval on the dev set (on by default; passing this flag turns it off).')
    group2.add_argument('-do_predict', action='store_false', default=True,
                        help='Whether to run prediction in inference mode on the test set '
                             '(on by default; passing this flag turns it off).')
    group2.add_argument('-batch_size', type=int, default=64,
                        help='Total batch size for training, eval and predict.')
    group2.add_argument('-learning_rate', type=float, default=1e-5,
                        help='The initial learning rate for Adam.')
    group2.add_argument('-num_train_epochs', type=float, default=10,
                        help='Total number of training epochs to perform.')
    group2.add_argument('-dropout_rate', type=float, default=0.5,
                        help='Dropout rate')
    group2.add_argument('-clip', type=float, default=0.5,
                        help='Gradient clip')
    # '%' must be escaped as '%%' because argparse %-formats help strings
    group2.add_argument('-warmup_proportion', type=float, default=0.1,
                        help='Proportion of training to perform linear learning rate warmup for. '
                             'E.g., 0.1 = 10%% of training.')
    group2.add_argument('-lstm_size', type=int, default=128,
                        help='size of lstm units.')
    group2.add_argument('-num_layers', type=int, default=1,
                        help='number of rnn layers, default is 1.')
    group2.add_argument('-cell', type=str, default='lstm',
                        help="which rnn cell to use ('lstm' or 'gru').")
    group2.add_argument('-save_checkpoints_steps', type=int, default=500,
                        help='save_checkpoints_steps')
    group2.add_argument('-save_summary_steps', type=int, default=500,
                        help='save_summary_steps.')
    # note: argparse's type=bool turns any non-empty string (including 'False') into True
    group2.add_argument('-filter_adam_var', type=bool, default=False,
                        help='after training, filter the Adam params out of the model and save a checkpoint without them.')
    group2.add_argument('-do_lower_case', type=bool, default=True,
                        help='Whether to lower case the input text.')
    group2.add_argument('-clean', type=bool, default=True)
    group2.add_argument('-device_map', type=str, default='0',
                        help='which device(s) to use for training')

    # add labels
    group2.add_argument('-label_list', type=str, default=None,
                        help='User-defined labels: either a file with one label per line or a comma-separated string')

    parser.add_argument('-verbose', action='store_true', default=False,
                        help='turn on tensorflow logging for debug')
    parser.add_argument('-ner', type=str, default='ner', help='which model to train')
    parser.add_argument('-version', action='version', version='%(prog)s ' + __version__)
    return parser.parse_args()
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
python setup.py sdist bdist_wheel
python -m twine upload dist/*
--------------------------------------------------------------------------------
/client_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""

@Time : 2019/1/29 14:32
@Author : MaCan (ma_cancan@163.com)
@File : client_test.py
"""
import time
from bert_base.client import BertClient


def ner_test():
    with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
        start_t = time.perf_counter()
        str1 = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
        # rst = bc.encode([list(str1)], is_tokenized=True)
        # str1 = list(str1)
        rst = bc.encode([str1], is_tokenized=True)
        print('rst:', rst)
        print(len(rst[0]))
        print(time.perf_counter() - start_t)


def ner_cu_seg():
    """
    Custom character segmentation: the text is passed pre-split into characters
    :return:
    """
    with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
        start_t = time.perf_counter()
        str1 = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
        rst = bc.encode([list(str1)], is_tokenized=True)
        print('rst:', rst)
        print(len(rst[0]))
        print(time.perf_counter() - start_t)


def class_test():
    with BertClient(show_server_config=False, check_version=False, check_length=False, mode='CLASS') as bc:
        start_t = time.perf_counter()
        str1 = '北京时间2月17日凌晨,第69届柏林国际电影节公布主竞赛单元获奖名单,王景春、咏梅凭借王小帅执导的中国影片《地久天长》连夺最佳男女演员双银熊大奖,这是中国演员首次包揽柏林电影节最佳男女演员奖,为华语影片刷新纪录。与此同时,由青年导演王丽娜执导的影片《第一次的别离》也荣获了本届柏林电影节新生代单元国际评审团最佳影片,可以说,在经历数个获奖小年之后,中国电影在柏林影展再次迎来了高光时刻。'
        str2 = '受粤港澳大湾区规划纲要提振,港股周二高开,恒指开盘上涨近百点,涨幅0.33%,报28440.49点,相关概念股亦集体上涨,电子元件、新能源车、保险、基建概念多数上涨。粤泰股份、珠江实业、深天地A等10余股涨停;中兴通讯、丘钛科技、舜宇光学分别高开1.4%、4.3%、1.6%。比亚迪电子、比亚迪股份、光宇国际分别高开1.7%、1.2%、1%。越秀交通基建涨近2%,粤海投资、碧桂园等多股涨超1%。其他方面,日本软银集团股价上涨超0.4%,推动日经225和东证指数齐齐高开,但随后均回吐涨幅转跌东证指数跌0.2%,日经225指数跌0.11%,报21258.4点。受芯片制造商SK海力士股价下跌1.34%拖累,韩国综指下跌0.34%至2203.9点。澳大利亚ASX 200指数早盘上涨0.39%至6089.8点,大多数行业板块均现涨势。在保健品品牌澳佳宝下调下半财年的销售预期后,其股价暴跌超过23%。澳佳宝CEO亨弗里(Richard Henfrey)认为,公司下半年的利润可能会低于上半年,主要是受到销售额疲弱的影响。同时,亚市早盘澳洲联储公布了2月会议纪要,政策委员将继续谨慎评估经济增长前景,因前景充满不确定性的影响,稳定当前的利率水平比贸然调整利率更为合适,而且当前利率水平将有利于趋向通胀目标及改善就业,当前劳动力市场数据表现强势于其他经济数据。另一方面,经济增长前景亦令消费者消费意愿下滑,如果房价出现下滑,消费可能会进一步疲弱。在澳洲联储公布会议纪要后,澳元兑美元下跌近30点,报0.7120 。美元指数在昨日触及96.65附近的低点之后反弹至96.904。日元兑美元报110.56,接近上一交易日的低点。'
        str3 = '新京报快讯 据国家市场监管总局消息,针对媒体报道水饺等猪肉制品检出非洲猪瘟病毒核酸阳性问题,市场监管总局、农业农村部已要求企业立即追溯猪肉原料来源并对猪肉制品进行了处置。两部门已派出联合督查组调查核实相关情况,要求猪肉制品生产企业进一步加强对猪肉原料的管控,落实检验检疫票证查验规定,完善非洲猪瘟检测和复核制度,防止染疫猪肉原料进入食品加工环节。市场监管总局、农业农村部等部门要求各地全面落实防控责任,强化防控措施,规范信息报告和发布,对不按要求履行防控责任的企业,一旦发现将严厉查处。专家认为,非洲猪瘟不是人畜共患病,虽然对猪有致命危险,但对人没有任何危害,属于只传猪不传人型病毒,不会影响食品安全。开展猪肉制品病毒核酸检测,可为防控溯源工作提供线索。'
        rst = bc.encode([str1, str2, str3])
        print('rst:', rst)
        print('time used:{}'.format(time.perf_counter() - start_t))


if __name__ == '__main__':
    # class_test()
    ner_test()
    ner_cu_seg()
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
# encoding=utf-8

"""
Corpus pre-processing:
1. Cut every sentence into sequences shorter than max_seq_length, which avoids invalid data
   during decoding and out-of-range errors when the final results are computed.

@Author: Macan
"""


import os
import codecs
import argparse

def load_file(file_path):
    if not os.path.exists(file_path):
        return None
    with codecs.open(file_path, 'r', encoding='utf-8') as fd:
        for line in fd:
            yield line


def _cut(sentence):
    new_sentence = []
    sen = []
    for i in sentence:
        if i.split(' ')[0] in ['。', '!', '?'] and len(sen) != 0:
            sen.append(i)
            new_sentence.append(sen)
            sen = []
            continue
        sen.append(i)
    if len(sen) != 0:  # keep the tokens after the last sentence-ending mark
        new_sentence.append(sen)
    if len(new_sentence) == 1:
        # e.g. the "Loudi" sentence: longer than max_seq_length and without a full stop,
        # so split on ',' instead; anything still too long is not handled
        new_sentence = []
        sen = []
        for i in sentence:
            if i.split(' ')[0] in [','] and len(sen) != 0:
                sen.append(i)
                new_sentence.append(sen)
                sen = []
                continue
            sen.append(i)
        if len(sen) != 0:
            new_sentence.append(sen)
    return new_sentence


def cut_sentence(file, max_seq_length):
    """
    Split sentences that are longer than max_seq_length
    :param file:
    :param max_seq_length:
    :return:
    """
    context = []
    sentence = []
    cnt = 0
    for line in load_file(file):
        line = line.strip()
        if line == '':
            # a blank line ends the current sentence; extra blank lines are skipped
            if len(sentence) != 0:
                # check whether this sentence exceeds the maximum length
                if len(sentence) > max_seq_length:
                    context.extend(_cut(sentence))
                else:
                    context.append(sentence)
                sentence = []
            continue
        cnt += 1
        sentence.append(line)
    if len(sentence) != 0:  # the file may not end with a blank line
        context.extend(_cut(sentence) if len(sentence) > max_seq_length else [sentence])
    print('token cnt:{}'.format(cnt))
    return context

def write_to_file(file, context):
    # first rename the source file to <file>.bak so it is not overwritten
    os.rename(file, '{}.bak'.format(file))
    with codecs.open(file, 'w', encoding='utf-8') as fd:
        for sen in context:
            for token in sen:
                fd.write(token + '\n')
            fd.write('\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='data pre process')
    parser.add_argument('--train_data', type=str, default='./NERdata/train.txt')
parser.add_argument('--dev_data', type=str, default='./NERdata/dev.txt') 86 | parser.add_argument('--test_data', type=str, default='./NERdata/test.txt') 87 | parser.add_argument('--max_seq_length', type=int, default=126) 88 | args = parser.parse_args() 89 | 90 | print('cut train data to max sequence length:{}'.format(args.max_seq_length)) 91 | context = cut_sentence(args.train_data, args.max_seq_length) 92 | write_to_file(args.train_data, context) 93 | 94 | print('cut dev data to max sequence length:{}'.format(args.max_seq_length)) 95 | context = cut_sentence(args.dev_data, args.max_seq_length) 96 | write_to_file(args.dev_data, context) 97 | 98 | print('cut test data to max sequence length:{}'.format(args.max_seq_length)) 99 | context = cut_sentence(args.test_data, args.max_seq_length) 100 | write_to_file(args.test_data, context) -------------------------------------------------------------------------------- /pictures/03E18A6A9C16082CF22A9E8837F7E35F.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/03E18A6A9C16082CF22A9E8837F7E35F.png -------------------------------------------------------------------------------- /pictures/ner_help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/ner_help.png -------------------------------------------------------------------------------- /pictures/picture1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/picture1.png -------------------------------------------------------------------------------- /pictures/picture2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/picture2.png -------------------------------------------------------------------------------- /pictures/predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/predict.png -------------------------------------------------------------------------------- /pictures/server_help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/server_help.png -------------------------------------------------------------------------------- /pictures/server_ner_rst.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/server_ner_rst.png -------------------------------------------------------------------------------- /pictures/server_run.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/server_run.png -------------------------------------------------------------------------------- /pictures/service_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/service_1.png -------------------------------------------------------------------------------- /pictures/service_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/service_2.png -------------------------------------------------------------------------------- /pictures/text_class_rst.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/ccf3f093f0ac803e435cb8e8598fdddc2ba1105d/pictures/text_class_rst.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | # client-side requirements, pretty light-weight right? 2 | # tensorflow >= 1.12.0 3 | # tensorflow-gpu >= 1.12.0 # GPU version of TensorFlow. 4 | GPUtil >= 1.3.0 # no need if you dont have GPU 5 | pyzmq >= 17.1.0 # python zmq 6 | flask # no need if you do not need http 7 | flask_compress # no need if you do not need http 8 | flask_json # no need if you do not need http -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 运行 BERT NER Server 6 | #@Time : 2019/1/26 21:00 7 | # @Author : MaCan (ma_cancan@163.com) 8 | # @File : run.py 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | def start_server(): 17 | from bert_base.server import BertServer 18 | from bert_base.server.helper import get_run_args 19 | 20 | args = get_run_args() 21 | print(args) 22 | server = BertServer(args) 23 | server.start() 24 | server.join() 25 | 26 | 27 | def train_ner(): 28 | import os 29 | from bert_base.train.train_helper import get_args_parser 30 | from bert_base.train.bert_lstm_ner import train 31 | 32 | args = get_args_parser() 33 | if True: 34 | import sys 35 | param_str = 
'\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 36 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 37 | print(args) 38 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_map 39 | train(args=args) 40 | 41 | 42 | if __name__ == '__main__': 43 | """ 44 | 如果想训练,那么直接 指定参数跑,如果想启动服务,那么注释掉train,打开server即可 45 | """ 46 | train_ner() 47 | #start_server() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # encoding =utf-8 2 | 3 | from os import path 4 | import codecs 5 | from setuptools import setup, find_packages 6 | 7 | # setup metainfo 8 | # libinfo_py = 'bert_lstm_ner.py' 9 | # libinfo_content = open(libinfo_py, 'r', encoding='utf-8').readlines() 10 | # version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][0] 11 | # # exec(version_line) # produce __version__ 12 | # __version__ = version_line.split('=')[1].replace(' ', '') 13 | # print(__version__) 14 | setup( 15 | name='bert_base', 16 | version='0.0.9', 17 | description='Use Google\'s BERT for Chinese natural language processing tasks such as named entity recognition and provide server services', 18 | url='https://github.com/macanv/BERT-BiLSTM-CRF-NER', 19 | long_description=open('README.md', 'r', encoding='utf-8').read(), 20 | long_description_content_type='text/markdown', 21 | author='Ma Can', 22 | author_email='ma_cancan@163.com', 23 | license='MIT', 24 | packages=find_packages(), 25 | zip_safe=False, 26 | install_requires=[ 27 | 'numpy', 28 | 'six', 29 | 'pyzmq>=16.0.0', 30 | 'GPUtil>=1.3.0', 31 | 'termcolor>=1.1', 32 | ], 33 | extras_require={ 34 | 'cpu': ['tensorflow>=1.10.0'], 35 | 'gpu': ['tensorflow-gpu>=1.10.0'], 36 | 'http': ['flask', 'flask-compress', 'flask-cors', 'flask-json'] 37 | }, 38 | classifiers=( 39 | 'Programming Language :: Python :: 3.6', 40 | 'License :: OSI Approved :: 
MIT License', 41 | 'Operating System :: OS Independent', 42 | #'Topic :: Scientific/Engineering :: Artificial Intelligence :: Natural Language Processing :: Named Entity Recognition', 43 | ), 44 | entry_points={ 45 | 'console_scripts': ['bert-base-serving-start=bert_base.runs:start_server', 46 | 'bert-base-ner-train=bert_base.runs:train_ner'], 47 | }, 48 | keywords='bert nlp ner NER named entity recognition bilstm crf tensorflow machine learning sentence encoding embedding serving', 49 | ) 50 | -------------------------------------------------------------------------------- /terminal_predict.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | 基于命令行的在线预测方法 5 | @Author: Macan (ma_cancan@163.com) 6 | """ 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | import codecs 11 | import pickle 12 | import os 13 | from datetime import datetime 14 | 15 | from bert_base.train.models import create_model, InputFeatures 16 | from bert_base.bert import tokenization, modeling 17 | from bert_base.train.train_helper import get_args_parser 18 | args = get_args_parser() 19 | 20 | model_dir = r'C:\Users\C\Documents\Tencent Files\389631699\FileRecv\semi_corpus_people_2014' 21 | bert_dir = 'F:\chinese_L-12_H-768_A-12' 22 | 23 | is_training=False 24 | use_one_hot_embeddings=False 25 | batch_size=1 26 | 27 | gpu_config = tf.ConfigProto() 28 | gpu_config.gpu_options.allow_growth = True 29 | sess=tf.Session(config=gpu_config) 30 | model=None 31 | 32 | global graph 33 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None 34 | 35 | 36 | print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint"))) 37 | if not os.path.exists(os.path.join(model_dir, "checkpoint")): 38 | raise Exception("failed to get checkpoint. 
going to return ") 39 | 40 | # 加载label->id的词典 41 | with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf: 42 | label2id = pickle.load(rf) 43 | id2label = {value: key for key, value in label2id.items()} 44 | 45 | with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf: 46 | label_list = pickle.load(rf) 47 | num_labels = len(label_list) + 1 48 | 49 | graph = tf.get_default_graph() 50 | with graph.as_default(): 51 | print("going to restore checkpoint") 52 | #sess.run(tf.global_variables_initializer()) 53 | input_ids_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length], name="input_ids") 54 | input_mask_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length], name="input_mask") 55 | 56 | bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json')) 57 | (total_loss, logits, trans, pred_ids) = create_model( 58 | bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None, 59 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0) 60 | 61 | saver = tf.train.Saver() 62 | saver.restore(sess, tf.train.latest_checkpoint(model_dir)) 63 | 64 | 65 | tokenizer = tokenization.FullTokenizer( 66 | vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case) 67 | 68 | 69 | def predict_online(): 70 | """ 71 | do online prediction. each time make prediction for one instance. 72 | you can change to a batch if you want. 73 | 74 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 75 | :return: 76 | """ 77 | def convert(line): 78 | feature = convert_single_example(0, line, label_list, args.max_seq_length, tokenizer, 'p') 79 | input_ids = np.reshape([feature.input_ids],(batch_size, args.max_seq_length)) 80 | input_mask = np.reshape([feature.input_mask],(batch_size, args.max_seq_length)) 81 | segment_ids = np.reshape([feature.segment_ids],(batch_size, args.max_seq_length)) 82 | label_ids =np.reshape([feature.label_ids],(batch_size, args.max_seq_length)) 83 | return input_ids, input_mask, segment_ids, label_ids 84 | 85 | global graph 86 | with graph.as_default(): 87 | print(id2label) 88 | while True: 89 | print('input the test sentence:') 90 | sentence = str(input()) 91 | start = datetime.now() 92 | if len(sentence) < 2: 93 | print(sentence) 94 | continue 95 | sentence = tokenizer.tokenize(sentence) 96 | # print('your input is:{}'.format(sentence)) 97 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 98 | 99 | feed_dict = {input_ids_p: input_ids, 100 | input_mask_p: input_mask} 101 | # run session get current feed_dict result 102 | pred_ids_result = sess.run([pred_ids], feed_dict) 103 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 104 | print(pred_label_result) 105 | #todo: 组合策略 106 | result = strage_combined_link_org_loc(sentence, pred_label_result[0]) 107 | print('time used: {} sec'.format((datetime.now() - start).total_seconds())) 108 | 109 | def convert_id_to_label(pred_ids_result, idx2label): 110 | """ 111 | 将id形式的结果转化为真实序列结果 112 | :param pred_ids_result: 113 | :param idx2label: 114 | :return: 115 | """ 116 | result = [] 117 | for row in range(batch_size): 118 | curr_seq = [] 119 | for ids in pred_ids_result[row][0]: 120 | if ids == 0: 121 | break 122 | curr_label = idx2label[ids] 123 | if curr_label in ['[CLS]', '[SEP]']: 124 | continue 125 | curr_seq.append(curr_label) 126 | result.append(curr_seq) 127 | return result 128 | 129 | 130 | 131 | def 
strage_combined_link_org_loc(tokens, tags): 132 | """ 133 | 组合策略 134 | :param pred_label_result: 135 | :param types: 136 | :return: 137 | """ 138 | def print_output(data, type): 139 | line = [] 140 | line.append(type) 141 | for i in data: 142 | line.append(i.word) 143 | print(', '.join(line)) 144 | 145 | params = None 146 | eval = Result(params) 147 | if len(tokens) > len(tags): 148 | tokens = tokens[:len(tags)] 149 | person, loc, org = eval.get_result(tokens, tags) 150 | print_output(loc, 'LOC') 151 | print_output(person, 'PER') 152 | print_output(org, 'ORG') 153 | 154 | 155 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 156 | """ 157 | 将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中 158 | :param ex_index: index 159 | :param example: 一个样本 160 | :param label_list: 标签列表 161 | :param max_seq_length: 162 | :param tokenizer: 163 | :param mode: 164 | :return: 165 | """ 166 | label_map = {} 167 | # 1表示从1开始对label进行index化 168 | for (i, label) in enumerate(label_list, 1): 169 | label_map[label] = i 170 | # 保存label->index 的map 171 | if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')): 172 | with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w: 173 | pickle.dump(label_map, w) 174 | 175 | tokens = example 176 | # tokens = tokenizer.tokenize(example.text) 177 | # 序列截断 178 | if len(tokens) >= max_seq_length - 1: 179 | tokens = tokens[0:(max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志 180 | ntokens = [] 181 | segment_ids = [] 182 | label_ids = [] 183 | ntokens.append("[CLS]") # 句子开始设置CLS 标志 184 | segment_ids.append(0) 185 | # append("O") or append("[CLS]") not sure! 
186 | label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病 187 | for i, token in enumerate(tokens): 188 | ntokens.append(token) 189 | segment_ids.append(0) 190 | label_ids.append(0) 191 | ntokens.append("[SEP]") # 句尾添加[SEP] 标志 192 | segment_ids.append(0) 193 | # append("O") or append("[SEP]") not sure! 194 | label_ids.append(label_map["[SEP]"]) 195 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) # 将序列中的字(ntokens)转化为ID形式 196 | input_mask = [1] * len(input_ids) 197 | 198 | # padding, 使用 199 | while len(input_ids) < max_seq_length: 200 | input_ids.append(0) 201 | input_mask.append(0) 202 | segment_ids.append(0) 203 | # we don't concerned about it! 204 | label_ids.append(0) 205 | ntokens.append("**NULL**") 206 | # label_mask.append(0) 207 | # print(len(input_ids)) 208 | assert len(input_ids) == max_seq_length 209 | assert len(input_mask) == max_seq_length 210 | assert len(segment_ids) == max_seq_length 211 | assert len(label_ids) == max_seq_length 212 | # assert len(label_mask) == max_seq_length 213 | 214 | # 结构化为一个类 215 | feature = InputFeatures( 216 | input_ids=input_ids, 217 | input_mask=input_mask, 218 | segment_ids=segment_ids, 219 | label_ids=label_ids, 220 | # label_mask = label_mask 221 | ) 222 | return feature 223 | 224 | 225 | class Pair(object): 226 | def __init__(self, word, start, end, type, merge=False): 227 | self.__word = word 228 | self.__start = start 229 | self.__end = end 230 | self.__merge = merge 231 | self.__types = type 232 | 233 | @property 234 | def start(self): 235 | return self.__start 236 | @property 237 | def end(self): 238 | return self.__end 239 | @property 240 | def merge(self): 241 | return self.__merge 242 | @property 243 | def word(self): 244 | return self.__word 245 | 246 | @property 247 | def types(self): 248 | return self.__types 249 | @word.setter 250 | def word(self, word): 251 | self.__word = word 252 | @start.setter 253 | def start(self, start): 254 | self.__start = start 
255 | @end.setter 256 | def end(self, end): 257 | self.__end = end 258 | @merge.setter 259 | def merge(self, merge): 260 | self.__merge = merge 261 | 262 | @types.setter 263 | def types(self, type): 264 | self.__types = type 265 | 266 | def __str__(self) -> str: 267 | line = [] 268 | line.append('entity:{}'.format(self.__word)) 269 | line.append('start:{}'.format(self.__start)) 270 | line.append('end:{}'.format(self.__end)) 271 | line.append('merge:{}'.format(self.__merge)) 272 | line.append('types:{}'.format(self.__types)) 273 | return '\t'.join(line) 274 | 275 | 276 | class Result(object): 277 | def __init__(self, config): 278 | self.config = config 279 | self.person = [] 280 | self.loc = [] 281 | self.org = [] 282 | self.others = [] 283 | def get_result(self, tokens, tags, config=None): 284 | # 先获取标注结果 285 | self.result_to_json(tokens, tags) 286 | return self.person, self.loc, self.org 287 | 288 | def result_to_json(self, string, tags): 289 | """ 290 | 将模型标注序列和输入序列结合 转化为结果 291 | :param string: 输入序列 292 | :param tags: 标注结果 293 | :return: 294 | """ 295 | item = {"entities": []} 296 | entity_name = "" 297 | entity_start = 0 298 | idx = 0 299 | last_tag = '' 300 | 301 | for char, tag in zip(string, tags): 302 | if tag[0] == "S": 303 | self.append(char, idx, idx+1, tag[2:]) 304 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]}) 305 | elif tag[0] == "B": 306 | if entity_name != '': 307 | self.append(entity_name, entity_start, idx, last_tag[2:]) 308 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 309 | entity_name = "" 310 | entity_name += char 311 | entity_start = idx 312 | elif tag[0] == "I": 313 | entity_name += char 314 | elif tag[0] == "O": 315 | if entity_name != '': 316 | self.append(entity_name, entity_start, idx, last_tag[2:]) 317 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 318 | entity_name = "" 319 | else: 
320 | entity_name = "" 321 | entity_start = idx 322 | idx += 1 323 | last_tag = tag 324 | if entity_name != '': 325 | self.append(entity_name, entity_start, idx, last_tag[2:]) 326 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 327 | return item 328 | 329 | def append(self, word, start, end, tag): 330 | if tag == 'LOC': 331 | self.loc.append(Pair(word, start, end, 'LOC')) 332 | elif tag == 'PER': 333 | self.person.append(Pair(word, start, end, 'PER')) 334 | elif tag == 'ORG': 335 | self.org.append(Pair(word, start, end, 'ORG')) 336 | else: 337 | self.others.append(Pair(word, start, end, tag)) 338 | 339 | 340 | if __name__ == "__main__": 341 | predict_online() 342 | 343 | --------------------------------------------------------------------------------
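The sentence-cutting step in data_process.py splits an over-long sentence at sentence-final punctuation and falls back to commas when that produces only one piece. It can be tried in isolation with the standalone sketch below. This is a re-expression of that idea, not the repository's code: `split_at` and `cut_if_too_long` are hypothetical names, and the input is assumed to be the same "char label" lines that `_cut` receives.

```python
from typing import List

# Same delimiter characters that data_process._cut checks
SENTENCE_END = ["。", "!", "?"]
COMMA = [","]


def split_at(lines: List[str], delimiters: List[str]) -> List[List[str]]:
    """Split 'char label' lines after each delimiter character, keeping the remainder."""
    pieces: List[List[str]] = []
    current: List[str] = []
    for line in lines:
        char = line.split(" ")[0]
        current.append(line)
        # as in _cut, a delimiter only closes a piece if something precedes it
        if char in delimiters and len(current) > 1:
            pieces.append(current)
            current = []
    if current:  # tokens after the last delimiter form the final piece
        pieces.append(current)
    return pieces


def cut_if_too_long(lines: List[str], max_seq_length: int) -> List[List[str]]:
    """Return the sentence unchanged if it is short enough, otherwise split it."""
    if len(lines) <= max_seq_length:
        return [lines]
    pieces = split_at(lines, SENTENCE_END)
    if len(pieces) == 1:  # no usable sentence-final mark: fall back to commas
        pieces = split_at(lines, COMMA)
    return pieces  # a piece may still exceed max_seq_length, as in the original


sentence = ["我 O", "爱 O", "。 O", "北 B-LOC", "京 I-LOC"]
print(cut_if_too_long(sentence, 3))
```

Keeping the trailing remainder means every input token ends up in exactly one piece, so the rewritten file keeps the same token count as the original.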
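At inference time, `convert_single_example` builds a fixed-length layout: `[CLS]`, the tokens, `[SEP]`, then zero padding. Label ids are numbered from 1 so that 0 can mark padding. The following pure-Python sketch reproduces that layout without TensorFlow or the BERT tokenizer. It is illustrative only: `build_features` is a hypothetical name, the vocabulary ids below are made up rather than taken from `vocab.txt`, and the `[UNK]` fallback stands in for what the real tokenizer does before `convert_tokens_to_ids`.

```python
from typing import Dict, List, Tuple


def build_features(tokens: List[str], vocab: Dict[str, int], label_map: Dict[str, int],
                   max_seq_length: int) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Sketch of the inference-time feature layout used by convert_single_example."""
    # Truncate, leaving room for [CLS] and [SEP]
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[: max_seq_length - 2]
    ntokens = ["[CLS]"] + tokens + ["[SEP]"]
    # [CLS]/[SEP] get their own label ids; real tokens get 0 since labels are unknown at inference
    label_ids = [label_map["[CLS]"]] + [0] * len(tokens) + [label_map["[SEP]"]]
    # Hypothetical lookup: unknown tokens fall back to [UNK]
    input_ids = [vocab.get(t, vocab["[UNK]"]) for t in ntokens]
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)
    # Pad every list with zeros up to max_seq_length
    pad = max_seq_length - len(input_ids)
    input_ids += [0] * pad
    input_mask += [0] * pad
    segment_ids += [0] * pad
    label_ids += [0] * pad
    return input_ids, input_mask, segment_ids, label_ids


# Assumed label list (must contain [CLS] and [SEP]) and a toy vocabulary
label_list = ["O", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]
label_map = {label: i for i, label in enumerate(label_list, 1)}  # ids start at 1, 0 = padding
vocab = {"[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102, "北": 1, "京": 2}

print(build_features(["北", "京"], vocab, label_map, 6))
```

Because every list is padded to `max_seq_length`, the results can be reshaped to `(batch_size, max_seq_length)` exactly as `convert` does inside `predict_online`.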
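`convert_id_to_label` reverses that mapping. It stops at the first padding id 0 and skips the `[CLS]` and `[SEP]` labels. Below is a minimal standalone sketch. It assumes the batch of predicted ids is passed directly (that is, `pred_ids_result[0]` rather than the one-element list returned by `sess.run`). `ids_to_labels` and the label table are hypothetical.

```python
from typing import Dict, List, Sequence


def ids_to_labels(batch_pred_ids: Sequence[Sequence[int]], id2label: Dict[int, str]) -> List[List[str]]:
    """Turn predicted label ids into label strings, one list per sequence."""
    result = []
    for seq in batch_pred_ids:
        labels = []
        for label_id in seq:
            if label_id == 0:  # id 0 is padding: ignore the rest of the sequence
                break
            label = id2label[int(label_id)]
            if label in ("[CLS]", "[SEP]"):  # drop the boundary markers
                continue
            labels.append(label)
        result.append(labels)
    return result


id2label = {1: "O", 2: "B-LOC", 3: "I-LOC", 4: "[CLS]", 5: "[SEP]"}
print(ids_to_labels([[4, 2, 3, 1, 5, 0, 0]], id2label))
```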
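`Result.result_to_json` groups per-token tags into entity spans. The sketch below shows the same span-grouping idea for BIO tags, with `S-` and `E-` handling added for BIOES. It is a simplified variant, not a drop-in replacement, and differs from `result_to_json` in three ways:

- It takes the entity type from the opening `B-` tag rather than from the last tag.
- It keeps, rather than discards, an open entity when an unexpected tag follows it.
- It ignores a stray `I-` tag.

`extract_entities` and the `Entity` tuple are hypothetical names.

```python
from typing import List, Optional, Tuple

Entity = Tuple[str, int, int, str]  # (text, start, end exclusive, type)


def extract_entities(tokens: List[str], tags: List[str]) -> List[Entity]:
    entities: List[Entity] = []
    current: Optional[Tuple[str, int, str]] = None  # (text, start, type) being built
    n = min(len(tokens), len(tags))
    for idx in range(n):
        tok, tag = tokens[idx], tags[idx]
        prefix, etype = tag[:2], tag[2:]
        # I-/E- of the same type extend the open entity
        if prefix in ("I-", "E-") and current is not None and current[2] == etype:
            current = (current[0] + tok, current[1], etype)
            if prefix == "E-":  # E- closes the entity, including this token
                entities.append((current[0], current[1], idx + 1, etype))
                current = None
            continue
        # any other tag closes the open entity before this token
        if current is not None:
            entities.append((current[0], current[1], idx, current[2]))
            current = None
        if prefix == "B-":
            current = (tok, idx, etype)
        elif prefix == "S-":
            entities.append((tok, idx, idx + 1, etype))
    if current is not None:  # entity running to the end of the sequence
        entities.append((current[0], current[1], n, current[2]))
    return entities


print(extract_entities(list("王景春在北京"), ["B-PER", "I-PER", "I-PER", "O", "B-LOC", "I-LOC"]))
```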