├── bert_base
│   ├── bert
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   ├── CONTRIBUTING.md
│   │   ├── optimization_test.py
│   │   ├── sample_text.txt
│   │   ├── tokenization_test.py
│   │   ├── optimization.py
│   │   ├── modeling_test.py
│   │   ├── LICENSE
│   │   ├── multilingual.md
│   │   ├── tokenization.py
│   │   └── create_pretraining_data.py
│   ├── __init__.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── train_helper.py
│   │   ├── lstm_crf_layer.py
│   │   ├── tf_metrics.py
│   │   ├── models.py
│   │   ├── conlleval.py
│   │   └── conlleval.pl
│   ├── runs
│   │   └── __init__.py
│   └── server
│       ├── zmq_decor.py
│       ├── http.py
│       ├── simple_flask_http_service.py
│       ├── helper.py
│       └── graph.py
├── pictures
│   ├── predict.png
│   ├── ner_help.png
│   ├── picture1.png
│   ├── picture2.png
│   ├── server_run.png
│   ├── service_1.png
│   ├── service_2.png
│   ├── server_help.png
│   ├── server_ner_rst.png
│   ├── text_class_rst.png
│   └── 03E18A6A9C16082CF22A9E8837F7E35F.png
├── requirement.txt
├── run.py
├── setup.py
├── client_test.py
├── data_process.py
├── terminal_predict.py
└── README.md
/bert_base/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pictures/predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/predict.png
--------------------------------------------------------------------------------
/pictures/ner_help.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/ner_help.png
--------------------------------------------------------------------------------
/pictures/picture1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/picture1.png
--------------------------------------------------------------------------------
/pictures/picture2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/picture2.png
--------------------------------------------------------------------------------
/pictures/server_run.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/server_run.png
--------------------------------------------------------------------------------
/pictures/service_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/service_1.png
--------------------------------------------------------------------------------
/pictures/service_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/service_2.png
--------------------------------------------------------------------------------
/pictures/server_help.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/server_help.png
--------------------------------------------------------------------------------
/pictures/server_ner_rst.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/server_ner_rst.png
--------------------------------------------------------------------------------
/pictures/text_class_rst.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/text_class_rst.png
--------------------------------------------------------------------------------
/bert_base/bert/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
3 |
--------------------------------------------------------------------------------
/pictures/03E18A6A9C16082CF22A9E8837F7E35F.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macanv/BERT-BiLSTM-CRF-NER/HEAD/pictures/03E18A6A9C16082CF22A9E8837F7E35F.png
--------------------------------------------------------------------------------
/bert_base/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/30 19:09
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File    : __init__.py
8 | """
--------------------------------------------------------------------------------
/bert_base/train/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/30 16:53
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File    : __init__.py
8 | """
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | # client-side requirements, pretty light-weight right?
2 | # tensorflow >= 1.12.0
3 | # tensorflow-gpu >= 1.12.0 # GPU version of TensorFlow.
4 | GPUtil >= 1.3.0                 # no need if you don't have GPU
5 | pyzmq >= 17.1.0 # python zmq
6 | flask # no need if you do not need http
7 | flask_compress # no need if you do not need http
8 | flask_json # no need if you do not need http
--------------------------------------------------------------------------------
/bert_base/runs/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/30 16:47
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File    : __init__.py
8 | """
9 |
10 |
11 | def start_server():
12 | from bert_base.server import BertServer
13 | from bert_base.server.helper import get_run_args
14 |
15 | args = get_run_args()
16 | # print(args)
17 | server = BertServer(args)
18 | server.start()
19 | server.join()
20 |
21 |
22 | def start_client():
23 | pass
24 |
25 |
26 | def train_ner():
27 | import os
28 | from bert_base.train.train_helper import get_args_parser
29 | from bert_base.train.bert_lstm_ner import train
30 |
31 | args = get_args_parser()
32 | if True:
33 | import sys
34 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())])
35 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str))
36 | # print(args)
37 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_map
38 | train(args=args)
39 |
40 | # if __name__ == '__main__':
41 | # # start_server()
42 | # train_ner()
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Run the BERT NER server.
6 | #@Time : 2019/1/26 21:00
7 | # @Author : MaCan (ma_cancan@163.com)
8 | # @File : run.py
9 | """
10 |
11 | from __future__ import absolute_import
12 | from __future__ import division
13 | from __future__ import print_function
14 |
15 |
16 | def start_server():
17 | from bert_base.server import BertServer
18 | from bert_base.server.helper import get_run_args
19 |
20 | args = get_run_args()
21 | print(args)
22 | server = BertServer(args)
23 | server.start()
24 | server.join()
25 |
26 |
27 | def train_ner():
28 | import os
29 | from bert_base.train.train_helper import get_args_parser
30 | from bert_base.train.bert_lstm_ner import train
31 |
32 | args = get_args_parser()
33 | if True:
34 | import sys
35 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())])
36 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str))
37 | print(args)
38 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device_map
39 | train(args=args)
40 |
41 |
42 | if __name__ == '__main__':
43 | """
44 |     To train, run this script directly with the desired arguments; to start the serving process instead, comment out train_ner() and uncomment start_server().
45 | """
46 | train_ner()
47 | #start_server()
--------------------------------------------------------------------------------
/bert_base/bert/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | BERT needs to maintain permanent compatibility with the pre-trained model files,
4 | so we do not plan to make any major changes to this library (other than what was
5 | promised in the README). However, we can accept small patches related to
6 | re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/bert_base/bert/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import optimization
20 | import tensorflow as tf
21 |
22 |
23 | class OptimizationTest(tf.test.TestCase):
24 |
25 | def test_adam(self):
26 | with self.test_session() as sess:
27 | w = tf.get_variable(
28 | "w",
29 | shape=[3],
30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
31 | x = tf.constant([0.4, 0.2, -0.5])
32 | loss = tf.reduce_mean(tf.square(x - w))
33 | tvars = tf.trainable_variables()
34 | grads = tf.gradients(loss, tvars)
35 | global_step = tf.train.get_or_create_global_step()
36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
38 | init_op = tf.group(tf.global_variables_initializer(),
39 | tf.local_variables_initializer())
40 | sess.run(init_op)
41 | for _ in range(100):
42 | sess.run(train_op)
43 | w_np = sess.run(w)
44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
45 |
46 |
47 | if __name__ == "__main__":
48 | tf.test.main()
49 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # encoding =utf-8
2 |
3 | from os import path
4 | import codecs
5 | from setuptools import setup, find_packages
6 |
7 | # setup metainfo
8 | # libinfo_py = 'bert_lstm_ner.py'
9 | # libinfo_content = open(libinfo_py, 'r', encoding='utf-8').readlines()
10 | # version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][0]
11 | # # exec(version_line) # produce __version__
12 | # __version__ = version_line.split('=')[1].replace(' ', '')
13 | # print(__version__)
14 | setup(
15 | name='bert_base',
16 | version='0.0.9',
17 | description='Use Google\'s BERT for Chinese natural language processing tasks such as named entity recognition and provide server services',
18 | url='https://github.com/macanv/BERT-BiLSTM-CRF-NER',
19 | long_description=open('README.md', 'r', encoding='utf-8').read(),
20 | long_description_content_type='text/markdown',
21 | author='Ma Can',
22 | author_email='ma_cancan@163.com',
23 | license='MIT',
24 | packages=find_packages(),
25 | zip_safe=False,
26 | install_requires=[
27 | 'numpy',
28 | 'six',
29 | 'pyzmq>=16.0.0',
30 | 'GPUtil>=1.3.0',
31 | 'termcolor>=1.1',
32 | ],
33 | extras_require={
34 | 'cpu': ['tensorflow>=1.10.0'],
35 | 'gpu': ['tensorflow-gpu>=1.10.0'],
36 | 'http': ['flask', 'flask-compress', 'flask-cors', 'flask-json']
37 | },
38 | classifiers=(
39 | 'Programming Language :: Python :: 3.6',
40 | 'License :: OSI Approved :: MIT License',
41 | 'Operating System :: OS Independent',
42 | #'Topic :: Scientific/Engineering :: Artificial Intelligence :: Natural Language Processing :: Named Entity Recognition',
43 | ),
44 | entry_points={
45 | 'console_scripts': ['bert-base-serving-start=bert_base.runs:start_server',
46 | 'bert-base-ner-train=bert_base.runs:train_ner'],
47 | },
48 | keywords='bert nlp ner NER named entity recognition bilstm crf tensorflow machine learning sentence encoding embedding serving',
49 | )
50 |
--------------------------------------------------------------------------------
/bert_base/server/zmq_decor.py:
--------------------------------------------------------------------------------
1 | from contextlib import ExitStack
2 |
3 | from zmq.decorators import _Decorator
4 |
5 | __all__ = ['multi_socket']
6 |
7 | from functools import wraps
8 |
9 | import zmq
10 |
11 |
12 | class _MyDecorator(_Decorator):
13 | def __call__(self, *dec_args, **dec_kwargs):
14 | kw_name, dec_args, dec_kwargs = self.process_decorator_args(*dec_args, **dec_kwargs)
15 | num_socket_str = dec_kwargs.pop('num_socket')
16 |
17 | def decorator(func):
18 | @wraps(func)
19 | def wrapper(*args, **kwargs):
20 | num_socket = getattr(args[0], num_socket_str)
21 | targets = [self.get_target(*args, **kwargs) for _ in range(num_socket)]
22 | with ExitStack() as stack:
23 | for target in targets:
24 | obj = stack.enter_context(target(*dec_args, **dec_kwargs))
25 | args = args + (obj,)
26 |
27 | return func(*args, **kwargs)
28 |
29 | return wrapper
30 |
31 | return decorator
32 |
33 |
34 | class _SocketDecorator(_MyDecorator):
35 | def process_decorator_args(self, *args, **kwargs):
36 | """Also grab context_name out of kwargs"""
37 | kw_name, args, kwargs = super(_SocketDecorator, self).process_decorator_args(*args, **kwargs)
38 | self.context_name = kwargs.pop('context_name', 'context')
39 | return kw_name, args, kwargs
40 |
41 | def get_target(self, *args, **kwargs):
42 | """Get context, based on call-time args"""
43 | context = self._get_context(*args, **kwargs)
44 | return context.socket
45 |
46 | def _get_context(self, *args, **kwargs):
47 | if self.context_name in kwargs:
48 | ctx = kwargs[self.context_name]
49 |
50 | if isinstance(ctx, zmq.Context):
51 | return ctx
52 |
53 | for arg in args:
54 | if isinstance(arg, zmq.Context):
55 | return arg
56 | # not specified by any decorator
57 | return zmq.Context.instance()
58 |
59 |
60 | def multi_socket(*args, **kwargs):
61 | return _SocketDecorator()(*args, **kwargs)
62 |
--------------------------------------------------------------------------------
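
Note: a minimal, hypothetical sketch of how multi_socket is meant to be used (the class name Worker and the attribute name num_concurrent_socket are illustrative, not taken from this repository). The object owning the decorated method exposes an integer attribute named by num_socket; that many sockets are opened on the default zmq context, appended to the call's positional arguments, and closed when the call returns.

import zmq
from bert_base.server.zmq_decor import multi_socket

class Worker:
    def __init__(self, num_concurrent_socket=2):
        self.num_concurrent_socket = num_concurrent_socket

    @multi_socket(zmq.PUSH, num_socket='num_concurrent_socket')
    def run(self, *sockets):
        # `sockets` now holds num_concurrent_socket PUSH sockets, each created
        # on zmq.Context.instance() and closed automatically on return.
        for sock in sockets:
            print(sock)

Worker().run()
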
/bert_base/server/http.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Process
2 |
3 | from termcolor import colored
4 |
5 | from .helper import set_logger
6 |
7 |
8 | class BertHTTPProxy(Process):
9 | def __init__(self, args):
10 | super().__init__()
11 | self.args = args
12 |
13 | def create_flask_app(self):
14 | try:
15 | from flask import Flask, request
16 | from flask_compress import Compress
17 | from flask_cors import CORS
18 | from flask_json import FlaskJSON, as_json, JsonError
19 | from bert_base.client import ConcurrentBertClient
20 | except ImportError:
21 | raise ImportError('BertClient or Flask or its dependencies are not fully installed, '
22 |                               'they are required for serving HTTP requests. '
23 | 'Please use "pip install -U bert-serving-server[http]" to install it.')
24 |
25 | # support up to 10 concurrent HTTP requests
26 | bc = ConcurrentBertClient(max_concurrency=self.args.http_max_connect,
27 | port=self.args.port, port_out=self.args.port_out,
28 | output_fmt='list', mode=self.args.mode)
29 | app = Flask(__name__)
30 | logger = set_logger(colored('PROXY', 'red'))
31 |
32 | @app.route('/status/server', methods=['GET'])
33 | @as_json
34 | def get_server_status():
35 | return bc.server_status
36 |
37 | @app.route('/status/client', methods=['GET'])
38 | @as_json
39 | def get_client_status():
40 | return bc.status
41 |
42 | @app.route('/encode', methods=['POST'])
43 | @as_json
44 | def encode_query():
45 | data = request.form if request.form else request.json
46 | try:
47 | logger.info('new request from %s' % request.remote_addr)
48 | print(data)
49 | return {'id': data['id'],
50 | 'result': bc.encode(data['texts'], is_tokenized=bool(
51 | data['is_tokenized']) if 'is_tokenized' in data else False)}
52 |
53 | except Exception as e:
54 | logger.error('error when handling HTTP request', exc_info=True)
55 | raise JsonError(description=str(e), type=str(type(e).__name__))
56 |
57 | CORS(app, origins=self.args.cors)
58 | FlaskJSON(app)
59 | Compress().init_app(app)
60 | return app
61 |
62 | def run(self):
63 | app = self.create_flask_app()
64 | app.run(port=self.args.http_port, threaded=True, host='0.0.0.0')
65 |
--------------------------------------------------------------------------------
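
Note: a hedged client-side sketch for the /encode route defined above. The host, the port value and the use of the third-party requests package are assumptions for illustration; the proxy actually listens on args.http_port.

import requests  # assumed to be installed; not a dependency declared by this repo

# The payload keys mirror what encode_query() reads: 'id', 'texts' and an
# optional 'is_tokenized' flag.
payload = {'id': 1, 'texts': ['明天天气怎么样'], 'is_tokenized': False}
resp = requests.post('http://127.0.0.1:8091/encode', json=payload)  # port is illustrative
print(resp.json())  # -> {'id': 1, 'result': ...}
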
/client_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/29 14:32
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File : client_test.py
8 | """
9 | import time
10 | from bert_base.client import BertClient
11 |
12 |
13 | def ner_test():
14 | with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
15 | start_t = time.perf_counter()
16 | str1 = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
17 | # rst = bc.encode([list(str1)], is_tokenized=True)
18 | # str1 = list(str1)
19 | rst = bc.encode([str1], is_tokenized=True)
20 | print('rst:', rst)
21 | print(len(rst[0]))
22 | print(time.perf_counter() - start_t)
23 |
24 |
25 | def ner_cu_seg():
26 | """
27 |     Custom character-level segmentation: the caller splits the string into characters before calling encode.
28 | :return:
29 | """
30 | with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
31 | start_t = time.perf_counter()
32 | str1 = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
33 | rst = bc.encode([list(str1)], is_tokenized=True)
34 | print('rst:', rst)
35 | print(len(rst[0]))
36 | print(time.perf_counter() - start_t)
37 |
38 |
39 | def class_test():
40 | with BertClient(show_server_config=False, check_version=False, check_length=False, mode='CLASS') as bc:
41 | start_t = time.perf_counter()
42 | str = '北京时间2月17日凌晨,第69届柏林国际电影节公布主竞赛单元获奖名单,王景春、咏梅凭借王小帅执导的中国影片《地久天长》连夺最佳男女演员双银熊大奖,这是中国演员首次包揽柏林电影节最佳男女演员奖,为华语影片刷新纪录。与此同时,由青年导演王丽娜执导的影片《第一次的别离》也荣获了本届柏林电影节新生代单元国际评审团最佳影片,可以说,在经历数个获奖小年之后,中国电影在柏林影展再次迎来了高光时刻。'
43 | str2 = '受粤港澳大湾区规划纲要提振,港股周二高开,恒指开盘上涨近百点,涨幅0.33%,报28440.49点,相关概念股亦集体上涨,电子元件、新能源车、保险、基建概念多数上涨。粤泰股份、珠江实业、深天地A等10余股涨停;中兴通讯、丘钛科技、舜宇光学分别高开1.4%、4.3%、1.6%。比亚迪电子、比亚迪股份、光宇国际分别高开1.7%、1.2%、1%。越秀交通基建涨近2%,粤海投资、碧桂园等多股涨超1%。其他方面,日本软银集团股价上涨超0.4%,推动日经225和东证指数齐齐高开,但随后均回吐涨幅转跌东证指数跌0.2%,日经225指数跌0.11%,报21258.4点。受芯片制造商SK海力士股价下跌1.34%拖累,韩国综指下跌0.34%至2203.9点。澳大利亚ASX 200指数早盘上涨0.39%至6089.8点,大多数行业板块均现涨势。在保健品品牌澳佳宝下调下半财年的销售预期后,其股价暴跌超过23%。澳佳宝CEO亨弗里(Richard Henfrey)认为,公司下半年的利润可能会低于上半年,主要是受到销售额疲弱的影响。同时,亚市早盘澳洲联储公布了2月会议纪要,政策委员将继续谨慎评估经济增长前景,因前景充满不确定性的影响,稳定当前的利率水平比贸然调整利率更为合适,而且当前利率水平将有利于趋向通胀目标及改善就业,当前劳动力市场数据表现强势于其他经济数据。另一方面,经济增长前景亦令消费者消费意愿下滑,如果房价出现下滑,消费可能会进一步疲弱。在澳洲联储公布会议纪要后,澳元兑美元下跌近30点,报0.7120 。美元指数在昨日触及96.65附近的低点之后反弹至96.904。日元兑美元报110.56,接近上一交易日的低点。'
44 | str3 = '新京报快讯 据国家市场监管总局消息,针对媒体报道水饺等猪肉制品检出非洲猪瘟病毒核酸阳性问题,市场监管总局、农业农村部已要求企业立即追溯猪肉原料来源并对猪肉制品进行了处置。两部门已派出联合督查组调查核实相关情况,要求猪肉制品生产企业进一步加强对猪肉原料的管控,落实检验检疫票证查验规定,完善非洲猪瘟检测和复核制度,防止染疫猪肉原料进入食品加工环节。市场监管总局、农业农村部等部门要求各地全面落实防控责任,强化防控措施,规范信息报告和发布,对不按要求履行防控责任的企业,一旦发现将严厉查处。专家认为,非洲猪瘟不是人畜共患病,虽然对猪有致命危险,但对人没有任何危害,属于只传猪不传人型病毒,不会影响食品安全。开展猪肉制品病毒核酸检测,可为防控溯源工作提供线索。'
45 | rst = bc.encode([str, str2, str3])
46 | print('rst:', rst)
47 | print('time used:{}'.format(time.perf_counter() - start_t))
48 |
49 |
50 | if __name__ == '__main__':
51 | # class_test()
52 | ner_test()
53 | ner_cu_seg()
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | """
4 | Corpus preprocessing:
5 | 1. Cut everything into sequences shorter than max_seq_length; this avoids illegal data during decoding and out-of-range errors when the final results are computed.
6 |
7 | @Author: Macan
8 | """
9 |
10 |
11 | import os
12 | import codecs
13 | import argparse
14 |
15 | def load_file(file_path):
16 | if not os.path.exists(file_path):
17 | return None
18 | with codecs.open(file_path, 'r', encoding='utf-8') as fd:
19 | for line in fd:
20 | yield line
21 |
22 |
23 | def _cut(sentence):
24 | new_sentence = []
25 | sen = []
26 | for i in sentence:
27 | if i.split(' ')[0] in ['。', '!', '?'] and len(sen) != 0:
28 | sen.append(i)
29 | new_sentence.append(sen)
30 | sen = []
31 | continue
32 | sen.append(i)
33 |     if len(new_sentence) == 1:  # fallback: a sentence longer than max_seq_length with no full stop is split on commas instead; anything still longer is not handled
34 | new_sentence = []
35 | sen = []
36 | for i in sentence:
37 | if i.split(' ')[0] in [','] and len(sen) != 0:
38 | sen.append(i)
39 | new_sentence.append(sen)
40 | sen = []
41 | continue
42 | sen.append(i)
43 | return new_sentence
44 |
45 |
46 | def cut_sentence(file, max_seq_length):
47 | """
48 |     Split sentences that exceed the maximum sequence length.
49 | :param file:
50 | :param max_seq_length:
51 | :return:
52 | """
53 | context = []
54 | sentence = []
55 | cnt = 0
56 | for line in load_file(file):
57 | line = line.strip()
58 | if line == '' and len(sentence) != 0:
59 |             # check whether this sentence exceeds the maximum length
60 | if len(sentence) > max_seq_length:
61 | sentence = _cut(sentence)
62 | context.extend(sentence)
63 | else:
64 | context.append(sentence)
65 | sentence = []
66 | continue
67 | cnt += 1
68 | sentence.append(line)
69 | print('token cnt:{}'.format(cnt))
70 | return context
71 |
72 | def write_to_file(file, context):
73 |     # first rename the source file to a backup name to avoid overwriting it
74 | os.rename(file, '{}.bak'.format(file))
75 | with codecs.open(file, 'w', encoding='utf-8') as fd:
76 | for sen in context:
77 | for token in sen:
78 | fd.write(token + '\n')
79 | fd.write('\n')
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser(description='data pre process')
84 | parser.add_argument('--train_data', type=str, default='./NERdata/train.txt')
85 | parser.add_argument('--dev_data', type=str, default='./NERdata/dev.txt')
86 | parser.add_argument('--test_data', type=str, default='./NERdata/test.txt')
87 | parser.add_argument('--max_seq_length', type=int, default=126)
88 | args = parser.parse_args()
89 |
90 | print('cut train data to max sequence length:{}'.format(args.max_seq_length))
91 | context = cut_sentence(args.train_data, args.max_seq_length)
92 | write_to_file(args.train_data, context)
93 |
94 | print('cut dev data to max sequence length:{}'.format(args.max_seq_length))
95 | context = cut_sentence(args.dev_data, args.max_seq_length)
96 | write_to_file(args.dev_data, context)
97 |
98 | print('cut test data to max sequence length:{}'.format(args.max_seq_length))
99 | context = cut_sentence(args.test_data, args.max_seq_length)
100 | write_to_file(args.test_data, context)
--------------------------------------------------------------------------------
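
Note: the script assumes CoNLL-style input, i.e. one "character label" pair per line with a blank line between sentences. A toy illustration of _cut on an over-long block (the characters and labels below are made up):

from data_process import _cut

block = ['我 O', '爱 O', '北 B-LOC', '京 I-LOC', '。 O',
         '天 B-LOC', '津 I-LOC', '也 O', '好 O', '。 O']
print(_cut(block))
# -> [['我 O', '爱 O', '北 B-LOC', '京 I-LOC', '。 O'],
#     ['天 B-LOC', '津 I-LOC', '也 O', '好 O', '。 O']]
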
/bert_base/bert/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Guttenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/bert_base/train/train_helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 |
5 | @Time : 2019/1/30 14:01
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File : train_helper.py
8 | """
9 |
10 | import argparse
11 | import os
12 |
13 | __all__ = ['get_args_parser']
14 |
15 | def get_args_parser():
16 | from .bert_lstm_ner import __version__
17 | parser = argparse.ArgumentParser()
18 | if os.name == 'nt':
19 |         bert_path = r'F:\chinese_L-12_H-768_A-12'
20 | root_path = r'C:\workspace\python\BERT-BiLSTM-CRF-NER'
21 | else:
22 | bert_path = '/home/macan/ml/data/chinese_L-12_H-768_A-12/'
23 | root_path = '/home/macan/ml/workspace/BERT-BiLSTM-CRF-NER'
24 |
25 | group1 = parser.add_argument_group('File Paths',
26 | 'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
27 | group1.add_argument('-data_dir', type=str, default=os.path.join(root_path, 'NERdata'),
28 | help='train, dev and test data dir')
29 | group1.add_argument('-bert_config_file', type=str, default=os.path.join(bert_path, 'bert_config.json'))
30 | group1.add_argument('-output_dir', type=str, default=os.path.join(root_path, 'output'),
31 | help='directory of a pretrained BERT model')
32 | group1.add_argument('-init_checkpoint', type=str, default=os.path.join(bert_path, 'bert_model.ckpt'),
33 | help='Initial checkpoint (usually from a pre-trained BERT model).')
34 | group1.add_argument('-vocab_file', type=str, default=os.path.join(bert_path, 'vocab.txt'),
35 | help='')
36 |
37 | group2 = parser.add_argument_group('Model Config', 'config the model params')
38 | group2.add_argument('-max_seq_length', type=int, default=202,
39 | help='The maximum total input sequence length after WordPiece tokenization.')
40 | group2.add_argument('-do_train', action='store_false', default=True,
41 | help='Whether to run training.')
42 | group2.add_argument('-do_eval', action='store_false', default=True,
43 | help='Whether to run eval on the dev set.')
44 | group2.add_argument('-do_predict', action='store_false', default=True,
45 | help='Whether to run the predict in inference mode on the test set.')
46 | group2.add_argument('-batch_size', type=int, default=64,
47 | help='Total batch size for training, eval and predict.')
48 | group2.add_argument('-learning_rate', type=float, default=1e-5,
49 | help='The initial learning rate for Adam.')
50 | group2.add_argument('-num_train_epochs', type=float, default=10,
51 | help='Total number of training epochs to perform.')
52 | group2.add_argument('-dropout_rate', type=float, default=0.5,
53 | help='Dropout rate')
54 | group2.add_argument('-clip', type=float, default=0.5,
55 | help='Gradient clip')
56 | group2.add_argument('-warmup_proportion', type=float, default=0.1,
57 |                         help='Proportion of training to perform linear learning rate warmup for. '
58 |                              'E.g., 0.1 = 10%% of training.')
59 | group2.add_argument('-lstm_size', type=int, default=128,
60 | help='size of lstm units.')
61 | group2.add_argument('-num_layers', type=int, default=1,
62 | help='number of rnn layers, default is 1.')
63 | group2.add_argument('-cell', type=str, default='lstm',
64 |                         help='which RNN cell to use.')
65 | group2.add_argument('-save_checkpoints_steps', type=int, default=500,
66 | help='save_checkpoints_steps')
67 | group2.add_argument('-save_summary_steps', type=int, default=500,
68 | help='save_summary_steps.')
69 | group2.add_argument('-filter_adam_var', type=bool, default=False,
70 |                         help='after training, filter the Adam parameters out of the model and save a copy of the model without them.')
71 | group2.add_argument('-do_lower_case', type=bool, default=True,
72 | help='Whether to lower case the input text.')
73 | group2.add_argument('-clean', type=bool, default=True)
74 | group2.add_argument('-device_map', type=str, default='0',
75 |                         help='which device to use for training')
76 |
77 | # add labels
78 | group2.add_argument('-label_list', type=str, default=None,
79 |                         help='User-defined labels: either a file with one label per line, or a comma-separated string')
80 |
81 | parser.add_argument('-verbose', action='store_true', default=False,
82 | help='turn on tensorflow logging for debug')
83 |     parser.add_argument('-ner', type=str, default='ner', help='which model to train')
84 | parser.add_argument('-version', action='version', version='%(prog)s ' + __version__)
85 | return parser.parse_args()
86 |
--------------------------------------------------------------------------------
/bert_base/bert/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import tempfile
21 |
22 | import tokenization
23 | import tensorflow as tf
24 |
25 |
26 | class TokenizationTest(tf.test.TestCase):
27 |
28 | def test_full_tokenizer(self):
29 | vocab_tokens = [
30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
31 | "##ing", ","
32 | ]
33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
34 |       vocab_writer.write("".join([x + "\n" for x in vocab_tokens]).encode("utf-8"))  # bytes: the temp file is opened in binary mode
35 |
36 | vocab_file = vocab_writer.name
37 |
38 | tokenizer = tokenization.FullTokenizer(vocab_file)
39 | os.unlink(vocab_file)
40 |
41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
43 |
44 | self.assertAllEqual(
45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
46 |
47 | def test_chinese(self):
48 | tokenizer = tokenization.BasicTokenizer()
49 |
50 | self.assertAllEqual(
51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
52 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
53 |
54 | def test_basic_tokenizer_lower(self):
55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
56 |
57 | self.assertAllEqual(
58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
59 | ["hello", "!", "how", "are", "you", "?"])
60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
61 |
62 | def test_basic_tokenizer_no_lower(self):
63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
64 |
65 | self.assertAllEqual(
66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
67 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
68 |
69 | def test_wordpiece_tokenizer(self):
70 | vocab_tokens = [
71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
72 | "##ing"
73 | ]
74 |
75 | vocab = {}
76 | for (i, token) in enumerate(vocab_tokens):
77 | vocab[token] = i
78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
79 |
80 | self.assertAllEqual(tokenizer.tokenize(""), [])
81 |
82 | self.assertAllEqual(
83 | tokenizer.tokenize("unwanted running"),
84 | ["un", "##want", "##ed", "runn", "##ing"])
85 |
86 | self.assertAllEqual(
87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
88 |
89 | def test_convert_tokens_to_ids(self):
90 | vocab_tokens = [
91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
92 | "##ing"
93 | ]
94 |
95 | vocab = {}
96 | for (i, token) in enumerate(vocab_tokens):
97 | vocab[token] = i
98 |
99 | self.assertAllEqual(
100 | tokenization.convert_tokens_to_ids(
101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
102 |
103 | def test_is_whitespace(self):
104 | self.assertTrue(tokenization._is_whitespace(u" "))
105 | self.assertTrue(tokenization._is_whitespace(u"\t"))
106 | self.assertTrue(tokenization._is_whitespace(u"\r"))
107 | self.assertTrue(tokenization._is_whitespace(u"\n"))
108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
109 |
110 | self.assertFalse(tokenization._is_whitespace(u"A"))
111 | self.assertFalse(tokenization._is_whitespace(u"-"))
112 |
113 | def test_is_control(self):
114 | self.assertTrue(tokenization._is_control(u"\u0005"))
115 |
116 | self.assertFalse(tokenization._is_control(u"A"))
117 | self.assertFalse(tokenization._is_control(u" "))
118 | self.assertFalse(tokenization._is_control(u"\t"))
119 | self.assertFalse(tokenization._is_control(u"\r"))
120 |
121 | def test_is_punctuation(self):
122 | self.assertTrue(tokenization._is_punctuation(u"-"))
123 | self.assertTrue(tokenization._is_punctuation(u"$"))
124 | self.assertTrue(tokenization._is_punctuation(u"`"))
125 | self.assertTrue(tokenization._is_punctuation(u"."))
126 |
127 | self.assertFalse(tokenization._is_punctuation(u"A"))
128 | self.assertFalse(tokenization._is_punctuation(u" "))
129 |
130 |
131 | if __name__ == "__main__":
132 | tf.test.main()
133 |
--------------------------------------------------------------------------------
/bert_base/bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | new_global_step = global_step + 1
80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
81 | return train_op
82 |
83 |
84 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
85 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
86 |
87 | def __init__(self,
88 | learning_rate,
89 | weight_decay_rate=0.0,
90 | beta_1=0.9,
91 | beta_2=0.999,
92 | epsilon=1e-6,
93 | exclude_from_weight_decay=None,
94 | name="AdamWeightDecayOptimizer"):
95 | """Constructs a AdamWeightDecayOptimizer."""
96 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
97 |
98 | self.learning_rate = learning_rate
99 | self.weight_decay_rate = weight_decay_rate
100 | self.beta_1 = beta_1
101 | self.beta_2 = beta_2
102 | self.epsilon = epsilon
103 | self.exclude_from_weight_decay = exclude_from_weight_decay
104 |
105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
106 | """See base class."""
107 | assignments = []
108 | for (grad, param) in grads_and_vars:
109 | if grad is None or param is None:
110 | continue
111 |
112 | param_name = self._get_variable_name(param.name)
113 |
114 | m = tf.get_variable(
115 | name=param_name + "/adam_m",
116 | shape=param.shape.as_list(),
117 | dtype=tf.float32,
118 | trainable=False,
119 | initializer=tf.zeros_initializer())
120 | v = tf.get_variable(
121 | name=param_name + "/adam_v",
122 | shape=param.shape.as_list(),
123 | dtype=tf.float32,
124 | trainable=False,
125 | initializer=tf.zeros_initializer())
126 |
127 | # Standard Adam update.
128 | next_m = (
129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
130 | next_v = (
131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
132 | tf.square(grad)))
133 |
134 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
135 |
136 | # Just adding the square of the weights to the loss function is *not*
137 | # the correct way of using L2 regularization/weight decay with Adam,
138 | # since that will interact with the m and v parameters in strange ways.
139 | #
140 |       # Instead we want to decay the weights in a manner that doesn't interact
141 | # with the m/v parameters. This is equivalent to adding the square
142 | # of the weights to the loss with plain (non-momentum) SGD.
143 | if self._do_use_weight_decay(param_name):
144 | update += self.weight_decay_rate * param
145 |
146 | update_with_lr = self.learning_rate * update
147 |
148 | next_param = param - update_with_lr
149 |
150 | assignments.extend(
151 | [param.assign(next_param),
152 | m.assign(next_m),
153 | v.assign(next_v)])
154 | return tf.group(*assignments, name=name)
155 |
156 | def _do_use_weight_decay(self, param_name):
157 | """Whether to use L2 weight decay for `param_name`."""
158 | if not self.weight_decay_rate:
159 | return False
160 | if self.exclude_from_weight_decay:
161 | for r in self.exclude_from_weight_decay:
162 | if re.search(r, param_name) is not None:
163 | return False
164 | return True
165 |
166 | def _get_variable_name(self, param_name):
167 | """Get the variable name from the tensor name."""
168 | m = re.match("^(.*):\\d+$", param_name)
169 | if m is not None:
170 | param_name = m.group(1)
171 | return param_name
172 |
--------------------------------------------------------------------------------
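
Note: a hedged TF 1.x sketch of wiring create_optimizer into a toy training loop; the variable, the loss and the step counts are illustrative, not taken from the training code.

import tensorflow as tf
from bert_base.bert import optimization

w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.constant([0.4, 0.2, -0.5]) - w))
# linear warmup for the first 10 steps, then linear decay to 0 over 100 steps
train_op = optimization.create_optimizer(
    loss, init_lr=5e-5, num_train_steps=100, num_warmup_steps=10, use_tpu=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)
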
/bert_base/train/lstm_crf_layer.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | """
4 | bert-blstm-crf layer
5 | @Author:Macan
6 | """
7 |
8 | import tensorflow as tf
9 | from tensorflow.contrib import rnn
10 | from tensorflow.contrib import crf
11 |
12 |
13 | class BLSTM_CRF(object):
14 | def __init__(self, embedded_chars, hidden_unit, cell_type, num_layers, dropout_rate,
15 | initializers, num_labels, seq_length, labels, lengths, is_training):
16 | """
17 |         BLSTM-CRF network
18 |         :param embedded_chars: Fine-tuning embedding input
19 |         :param hidden_unit: number of hidden units in the LSTM
20 |         :param cell_type: RNN cell type (LSTM or GRU; DICNN may be added in the future)
21 |         :param num_layers: number of RNN layers
22 |         :param dropout_rate: dropout rate
23 |         :param initializers: variable initializer class
24 |         :param num_labels: number of labels
25 |         :param seq_length: maximum sequence length
26 |         :param labels: ground-truth labels
27 |         :param lengths: [batch_size], the true length of each sequence in the batch
28 |         :param is_training: whether this is the training phase
29 | """
30 | self.hidden_unit = hidden_unit
31 | self.dropout_rate = dropout_rate
32 | self.cell_type = cell_type
33 | self.num_layers = num_layers
34 | self.embedded_chars = embedded_chars
35 | self.initializers = initializers
36 | self.seq_length = seq_length
37 | self.num_labels = num_labels
38 | self.labels = labels
39 | self.lengths = lengths
40 | self.embedding_dims = embedded_chars.shape[-1].value
41 | self.is_training = is_training
42 |
43 | def add_blstm_crf_layer(self, crf_only):
44 | """
45 |         Build the BLSTM-CRF network.
46 | :return:
47 | """
48 | if self.is_training:
49 |             # dropout on the LSTM input; a keep probability of 0.9 gave the best score in the author's experiments
50 | self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate)
51 |
52 | if crf_only:
53 | logits = self.project_crf_layer(self.embedded_chars)
54 | else:
55 | # blstm
56 | lstm_output = self.blstm_layer(self.embedded_chars)
57 | # project
58 | logits = self.project_bilstm_layer(lstm_output)
59 | # crf
60 | loss, trans = self.crf_layer(logits)
61 |         # CRF decode: pred_ids is the highest-probability tag path
62 | pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths)
63 | return (loss, logits, trans, pred_ids)
64 |
65 | def _witch_cell(self):
66 | """
67 |         Select the RNN cell type.
68 | :return:
69 | """
70 | cell_tmp = None
71 | if self.cell_type == 'lstm':
72 | cell_tmp = rnn.LSTMCell(self.hidden_unit)
73 | elif self.cell_type == 'gru':
74 | cell_tmp = rnn.GRUCell(self.hidden_unit)
75 | return cell_tmp
76 |
77 | def _bi_dir_rnn(self):
78 | """
79 |         Build the forward and backward RNN cells.
80 | :return:
81 | """
82 | cell_fw = self._witch_cell()
83 | cell_bw = self._witch_cell()
84 | if self.dropout_rate is not None:
85 | cell_bw = rnn.DropoutWrapper(cell_bw, output_keep_prob=self.dropout_rate)
86 | cell_fw = rnn.DropoutWrapper(cell_fw, output_keep_prob=self.dropout_rate)
87 | return cell_fw, cell_bw
88 |
89 | def blstm_layer(self, embedding_chars):
90 | """
91 |         Bidirectional LSTM layer over the input embeddings.
92 | :return:
93 | """
94 | with tf.variable_scope('rnn_layer'):
95 | cell_fw, cell_bw = self._bi_dir_rnn()
96 | if self.num_layers > 1:
97 | cell_fw = rnn.MultiRNNCell([cell_fw] * self.num_layers, state_is_tuple=True)
98 | cell_bw = rnn.MultiRNNCell([cell_bw] * self.num_layers, state_is_tuple=True)
99 |
100 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedding_chars,
101 | dtype=tf.float32)
102 | outputs = tf.concat(outputs, axis=2)
103 | return outputs
104 |
105 | def project_bilstm_layer(self, lstm_outputs, name=None):
106 | """
107 | hidden layer between lstm layer and logits
108 | :param lstm_outputs: [batch_size, num_steps, emb_size]
109 | :return: [batch_size, num_steps, num_tags]
110 | """
111 | with tf.variable_scope("project" if not name else name):
112 | with tf.variable_scope("hidden"):
113 | W = tf.get_variable("W", shape=[self.hidden_unit * 2, self.hidden_unit],
114 | dtype=tf.float32, initializer=self.initializers.xavier_initializer())
115 |
116 | b = tf.get_variable("b", shape=[self.hidden_unit], dtype=tf.float32,
117 | initializer=tf.zeros_initializer())
118 | output = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2])
119 | hidden = tf.nn.xw_plus_b(output, W, b)
120 |
121 | # project to score of tags
122 | with tf.variable_scope("logits"):
123 | W = tf.get_variable("W", shape=[self.hidden_unit, self.num_labels],
124 | dtype=tf.float32, initializer=self.initializers.xavier_initializer())
125 |
126 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
127 | initializer=tf.zeros_initializer())
128 |
129 | pred = tf.nn.xw_plus_b(hidden, W, b)
130 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels])
131 |
132 | def project_crf_layer(self, embedding_chars, name=None):
133 | """
134 | hidden layer between input layer and logits
135 | :param lstm_outputs: [batch_size, num_steps, emb_size]
136 | :return: [batch_size, num_steps, num_tags]
137 | """
138 | with tf.variable_scope("project" if not name else name):
139 | with tf.variable_scope("logits"):
140 | W = tf.get_variable("W", shape=[self.embedding_dims, self.num_labels],
141 | dtype=tf.float32, initializer=self.initializers.xavier_initializer())
142 |
143 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32,
144 | initializer=tf.zeros_initializer())
145 | output = tf.reshape(self.embedded_chars,
146 | shape=[-1, self.embedding_dims]) # [batch_size, embedding_dims]
147 | pred = tf.tanh(tf.nn.xw_plus_b(output, W, b))
148 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels])
149 |
150 | def crf_layer(self, logits):
151 | """
152 | calculate crf loss
153 | :param project_logits: [1, num_steps, num_tags]
154 | :return: scalar loss
155 | """
156 | with tf.variable_scope("crf_loss"):
157 | trans = tf.get_variable(
158 | "transitions",
159 | shape=[self.num_labels, self.num_labels],
160 | initializer=self.initializers.xavier_initializer())
161 | if self.labels is None:
162 | return None, trans
163 | else:
164 | log_likelihood, trans = tf.contrib.crf.crf_log_likelihood(
165 | inputs=logits,
166 | tag_indices=self.labels,
167 | transition_params=trans,
168 | sequence_lengths=self.lengths)
169 | return tf.reduce_mean(-log_likelihood), trans
170 |
--------------------------------------------------------------------------------
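
Note: a hedged sketch (TF 1.x with tf.contrib available) of plugging BLSTM_CRF on top of precomputed embeddings. The placeholder shapes, the label count and the hyper-parameters are illustrative, not taken from the training code.

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers
from bert_base.train.lstm_crf_layer import BLSTM_CRF

batch_size, seq_length, emb_dim, num_labels = 8, 128, 768, 7
embedded_chars = tf.placeholder(tf.float32, [batch_size, seq_length, emb_dim])
labels = tf.placeholder(tf.int32, [batch_size, seq_length])
lengths = tf.placeholder(tf.int32, [batch_size])

blstm_crf = BLSTM_CRF(embedded_chars=embedded_chars, hidden_unit=128,
                      cell_type='lstm', num_layers=1, dropout_rate=0.5,
                      initializers=initializers, num_labels=num_labels,
                      seq_length=seq_length, labels=labels, lengths=lengths,
                      is_training=True)
# loss is used for training, pred_ids for decoding the best tag path
loss, logits, trans, pred_ids = blstm_crf.add_blstm_crf_layer(crf_only=False)
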
/bert_base/train/tf_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Multiclass
3 | from:
4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
5 |
6 | """
7 |
8 | __author__ = "Guillaume Genthial"
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix
13 |
14 | __all__ = ['precision', 'recall', 'f1', 'fbeta', 'safe_div', 'pr_re_fbeta', 'metrics_from_confusion_matrix']
15 |
16 |
17 | def precision(labels, predictions, num_classes, pos_indices=None,
18 | weights=None, average='micro'):
19 | """Multi-class precision metric for Tensorflow
20 | Parameters
21 | ----------
22 | labels : Tensor of tf.int32 or tf.int64
23 | The true labels
24 | predictions : Tensor of tf.int32 or tf.int64
25 | The predictions, same shape as labels
26 | num_classes : int
27 | The number of classes
28 | pos_indices : list of int, optional
29 | The indices of the positive classes, default is all
30 | weights : Tensor of tf.int32, optional
31 | Mask, must be of compatible shape with labels
32 | average : str, optional
33 | 'micro': counts the total number of true positives, false
34 | positives, and false negatives for the classes in
35 | `pos_indices` and infer the metric from it.
36 | 'macro': will compute the metric separately for each class in
37 | `pos_indices` and average. Will not account for class
38 | imbalance.
39 | 'weighted': will compute the metric separately for each class in
40 | `pos_indices` and perform a weighted average by the total
41 | number of true labels for each class.
42 | Returns
43 | -------
44 | tuple of (scalar float Tensor, update_op)
45 | """
46 | cm, op = _streaming_confusion_matrix(
47 | labels, predictions, num_classes, weights)
48 | pr, _, _ = metrics_from_confusion_matrix(
49 | cm, pos_indices, average=average)
50 | op, _, _ = metrics_from_confusion_matrix(
51 | op, pos_indices, average=average)
52 | return (pr, op)
53 |
54 |
55 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None,
56 | average='micro'):
57 | """Multi-class recall metric for Tensorflow
58 | Parameters
59 | ----------
60 | labels : Tensor of tf.int32 or tf.int64
61 | The true labels
62 | predictions : Tensor of tf.int32 or tf.int64
63 | The predictions, same shape as labels
64 | num_classes : int
65 | The number of classes
66 | pos_indices : list of int, optional
67 | The indices of the positive classes, default is all
68 | weights : Tensor of tf.int32, optional
69 | Mask, must be of compatible shape with labels
70 | average : str, optional
71 | 'micro': counts the total number of true positives, false
72 | positives, and false negatives for the classes in
73 | `pos_indices` and infer the metric from it.
74 | 'macro': will compute the metric separately for each class in
75 | `pos_indices` and average. Will not account for class
76 | imbalance.
77 | 'weighted': will compute the metric separately for each class in
78 | `pos_indices` and perform a weighted average by the total
79 | number of true labels for each class.
80 | Returns
81 | -------
82 | tuple of (scalar float Tensor, update_op)
83 | """
84 | cm, op = _streaming_confusion_matrix(
85 | labels, predictions, num_classes, weights)
86 | _, re, _ = metrics_from_confusion_matrix(
87 | cm, pos_indices, average=average)
88 | _, op, _ = metrics_from_confusion_matrix(
89 | op, pos_indices, average=average)
90 | return (re, op)
91 |
92 |
93 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None,
94 | average='micro'):
95 | return fbeta(labels, predictions, num_classes, pos_indices, weights,
96 | average)
97 |
98 |
99 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
100 | average='micro', beta=1):
101 | """Multi-class fbeta metric for Tensorflow
102 | Parameters
103 | ----------
104 | labels : Tensor of tf.int32 or tf.int64
105 | The true labels
106 | predictions : Tensor of tf.int32 or tf.int64
107 | The predictions, same shape as labels
108 | num_classes : int
109 | The number of classes
110 | pos_indices : list of int, optional
111 | The indices of the positive classes, default is all
112 | weights : Tensor of tf.int32, optional
113 | Mask, must be of compatible shape with labels
114 | average : str, optional
115 | 'micro': counts the total number of true positives, false
116 | positives, and false negatives for the classes in
117 | `pos_indices` and infer the metric from it.
118 | 'macro': will compute the metric separately for each class in
119 | `pos_indices` and average. Will not account for class
120 | imbalance.
121 | 'weighted': will compute the metric separately for each class in
122 | `pos_indices` and perform a weighted average by the total
123 | number of true labels for each class.
124 | beta : int, optional
125 | Weight of precision in harmonic mean
126 | Returns
127 | -------
128 | tuple of (scalar float Tensor, update_op)
129 | """
130 | cm, op = _streaming_confusion_matrix(
131 | labels, predictions, num_classes, weights)
132 | _, _, fbeta = metrics_from_confusion_matrix(
133 | cm, pos_indices, average=average, beta=beta)
134 | _, _, op = metrics_from_confusion_matrix(
135 | op, pos_indices, average=average, beta=beta)
136 | return (fbeta, op)
137 |
138 |
139 | def safe_div(numerator, denominator):
140 | """Safe division, return 0 if denominator is 0"""
141 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator)
142 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype)
143 | denominator_is_zero = tf.equal(denominator, zeros)
144 | return tf.where(denominator_is_zero, zeros, numerator / denominator)
145 |
146 |
147 | def pr_re_fbeta(cm, pos_indices, beta=1):
148 | """Uses a confusion matrix to compute precision, recall and fbeta"""
149 | num_classes = cm.shape[0]
150 | neg_indices = [i for i in range(num_classes) if i not in pos_indices]
151 | cm_mask = np.ones([num_classes, num_classes])
152 | cm_mask[neg_indices, neg_indices] = 0
153 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))
154 |
155 | cm_mask = np.ones([num_classes, num_classes])
156 | cm_mask[:, neg_indices] = 0
157 | tot_pred = tf.reduce_sum(cm * cm_mask)
158 |
159 | cm_mask = np.ones([num_classes, num_classes])
160 | cm_mask[neg_indices, :] = 0
161 | tot_gold = tf.reduce_sum(cm * cm_mask)
162 |
163 | pr = safe_div(diag_sum, tot_pred)
164 | re = safe_div(diag_sum, tot_gold)
165 | fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re)
166 |
167 | return pr, re, fbeta
168 |
169 |
170 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
171 | beta=1):
172 | """Precision, Recall and F1 from the confusion matrix
173 | Parameters
174 | ----------
175 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes)
176 | The streaming confusion matrix.
177 | pos_indices : list of int, optional
178 | The indices of the positive classes
179 | beta : int, optional
180 | Weight of precision in harmonic mean
181 | average : str, optional
182 | 'micro', 'macro' or 'weighted'
183 | """
184 | num_classes = cm.shape[0]
185 | if pos_indices is None:
186 | pos_indices = [i for i in range(num_classes)]
187 |
188 | if average == 'micro':
189 | return pr_re_fbeta(cm, pos_indices, beta)
190 | elif average in {'macro', 'weighted'}:
191 | precisions, recalls, fbetas, n_golds = [], [], [], []
192 | for idx in pos_indices:
193 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta)
194 | precisions.append(pr)
195 | recalls.append(re)
196 | fbetas.append(fbeta)
197 | cm_mask = np.zeros([num_classes, num_classes])
198 | cm_mask[idx, :] = 1
199 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
200 |
201 | if average == 'macro':
202 | pr = tf.reduce_mean(precisions)
203 | re = tf.reduce_mean(recalls)
204 | fbeta = tf.reduce_mean(fbetas)
205 | return pr, re, fbeta
206 | if average == 'weighted':
207 | n_gold = tf.reduce_sum(n_golds)
208 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds))
209 | pr = safe_div(pr_sum, n_gold)
210 | re_sum = sum(r * n for r, n in zip(recalls, n_golds))
211 | re = safe_div(re_sum, n_gold)
212 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds))
213 | fbeta = safe_div(fbeta_sum, n_gold)
214 | return pr, re, fbeta
215 |
216 | else:
217 | raise NotImplementedError()
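
The precision/recall/fbeta helpers above all return a (value, update_op) pair in the usual TF1 streaming-metric style. A minimal usage sketch, not part of the repository (the label ids below are made up), calling the f1 function defined above:

import tensorflow as tf

labels_ph = tf.placeholder(tf.int32, [None])
preds_ph = tf.placeholder(tf.int32, [None])
# average='micro' pools true/false positive counts over all positive classes
value, update_op = f1(labels_ph, preds_ph, num_classes=4, average='micro')

with tf.Session() as sess:
    # the streaming confusion matrix lives in local variables
    sess.run(tf.local_variables_initializer())
    for y_true, y_pred in [([1, 2, 0], [1, 2, 2]), ([3, 3, 1], [3, 1, 1])]:
        sess.run(update_op, {labels_ph: y_true, preds_ph: y_pred})
    print(sess.run(value))  # micro-averaged F1 accumulated over both batches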
--------------------------------------------------------------------------------
/bert_base/server/simple_flask_http_service.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | #@Time : ${DATE} ${TIME}
6 | # @Author : MaCan (ma_cancan@163.com)
7 | # @File : ${NAME}.py
8 | """
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import os
15 | import flask
16 | from flask import request, jsonify
17 | import json
18 | import pickle
19 | from datetime import datetime
20 | import tensorflow as tf
21 | from tensorflow import keras as K
22 | import numpy as np
23 |
24 | import sys
25 | sys.path.append('../..')
26 | from bert_base.train.models import create_model, InputFeatures
27 | from bert_base.bert import tokenization, modeling
28 |
29 |
30 | model_dir = r'../../output'
31 | bert_dir = r'H:\models\chinese_L-12_H-768_A-12'
32 |
33 | is_training=False
34 | use_one_hot_embeddings=False
35 | batch_size=1
36 | max_seq_length = 202
37 |
38 | gpu_config = tf.ConfigProto()
39 | gpu_config.gpu_options.allow_growth = True
40 | sess=tf.Session(config=gpu_config)
41 | model=None
42 |
43 | global graph
44 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None
45 |
46 |
47 | print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
48 | if not os.path.exists(os.path.join(model_dir, "checkpoint")):
49 | raise Exception("failed to find the checkpoint in {}".format(model_dir))
50 |
51 | # load the label->id dictionary
52 | with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
53 | label2id = pickle.load(rf)
54 | id2label = {value: key for key, value in label2id.items()}
55 |
56 | with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
57 | label_list = pickle.load(rf)
58 | num_labels = len(label_list) + 1
59 |
60 |
61 | graph = tf.get_default_graph()
62 | with graph.as_default():
63 | print("going to restore checkpoint")
64 | #sess.run(tf.global_variables_initializer())
65 | input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
66 | input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")
67 |
68 | bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
69 | (total_loss, logits, trans, pred_ids) = create_model(
70 | bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None,
71 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)
72 |
73 | saver = tf.train.Saver()
74 | saver.restore(sess, tf.train.latest_checkpoint(model_dir))
75 |
76 | tokenizer = tokenization.FullTokenizer(
77 | vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)
78 |
79 | app = flask.Flask(__name__)
80 |
81 |
82 | @app.route('/ner_predict_service', methods=['GET'])
83 | def ner_predict_service():
84 | """
85 | do online prediction. each time make prediction for one instance.
86 | you can change to a batch if you want.
87 |
88 | :param line: a list. element is: [dummy_label,text_a,text_b]
89 | :return:
90 | """
91 | def convert(line):
92 | feature = convert_single_example(0, line, label_list, max_seq_length, tokenizer, 'p')
93 | input_ids = np.reshape([feature.input_ids],(batch_size, max_seq_length))
94 | input_mask = np.reshape([feature.input_mask],(batch_size, max_seq_length))
95 | segment_ids = np.reshape([feature.segment_ids],(batch_size, max_seq_length))
96 | label_ids =np.reshape([feature.label_ids],(batch_size, max_seq_length))
97 | return input_ids, input_mask, segment_ids, label_ids
98 |
99 | global graph
100 | with graph.as_default():
101 | result = {}
102 | result['code'] = 0
103 | try:
104 | sentence = request.args['query']
105 | result['query'] = sentence
106 | start = datetime.now()
107 | if len(sentence) < 2:
108 | print(sentence)
109 | result['data'] = ['O'] * len(sentence)
110 | return json.dumps(result)
111 | sentence = tokenizer.tokenize(sentence)
112 | # print('your input is:{}'.format(sentence))
113 | input_ids, input_mask, segment_ids, label_ids = convert(sentence)
114 |
115 |
116 | feed_dict = {input_ids_p: input_ids,
117 | input_mask_p: input_mask}
118 | # run session get current feed_dict result
119 | pred_ids_result = sess.run([pred_ids], feed_dict)
120 | pred_label_result = convert_id_to_label(pred_ids_result, id2label)
121 | print(pred_label_result)
122 | # TODO: entity combination strategy
123 | result['data'] = pred_label_result
124 | print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
125 | return json.dumps(result)
126 | except:
127 | result['code'] = -1
128 | result['data'] = 'error'
129 | return json.dumps(result)
130 |
131 | def online_predict():
132 | """
133 | do online prediction. each time make prediction for one instance.
134 | you can change to a batch if you want.
135 |
136 | :param line: a list. element is: [dummy_label,text_a,text_b]
137 | :return:
138 | """
139 | def convert(line):
140 | feature = convert_single_example(0, line, label_list, max_seq_length, tokenizer, 'p')
141 | input_ids = np.reshape([feature.input_ids],(batch_size, max_seq_length))
142 | input_mask = np.reshape([feature.input_mask],(batch_size, max_seq_length))
143 | segment_ids = np.reshape([feature.segment_ids],(batch_size, max_seq_length))
144 | label_ids =np.reshape([feature.label_ids],(batch_size, max_seq_length))
145 | return input_ids, input_mask, segment_ids, label_ids
146 |
147 | global graph
148 | with graph.as_default():
149 |
150 | sentence = '北京天安门'
151 |
152 | start = datetime.now()
153 | if len(sentence) < 2:
154 | print(sentence)
155 |
156 | sentence = tokenizer.tokenize(sentence)
157 | # print('your input is:{}'.format(sentence))
158 | input_ids, input_mask, segment_ids, label_ids = convert(sentence)
159 |
160 |
161 | feed_dict = {input_ids_p: input_ids,
162 | input_mask_p: input_mask}
163 | # run session get current feed_dict result
164 | pred_ids_result = sess.run([pred_ids], feed_dict)
165 | pred_label_result = convert_id_to_label(pred_ids_result, id2label)
166 | print(pred_label_result)
167 |
168 | print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
169 |
170 |
171 |
172 |
173 |
174 | def convert_id_to_label(pred_ids_result, idx2label):
175 | """
176 | Convert id-based prediction results back to label sequences.
177 | :param pred_ids_result:
178 | :param idx2label:
179 | :return:
180 | """
181 | result = []
182 | for row in range(batch_size):
183 | curr_seq = []
184 | for ids in pred_ids_result[row][0]:
185 | if ids == 0:
186 | break
187 | curr_label = idx2label[ids]
188 | if curr_label in ['[CLS]', '[SEP]']:
189 | continue
190 | curr_seq.append(curr_label)
191 | result.append(curr_seq)
192 | return result
193 |
194 |
195 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode):
196 | """
197 | Process a single example: convert its characters and labels to ids and pack them into an InputFeatures object.
198 | :param ex_index: index
199 | :param example: a single example (a list of tokens here)
200 | :param label_list: list of labels
201 | :param max_seq_length:
202 | :param tokenizer:
203 | :param mode:
204 | :return:
205 | """
206 | label_map = {}
207 | # the second argument 1 makes label indexing start from 1
208 | for (i, label) in enumerate(label_list, 1):
209 | label_map[label] = i
210 | # save the label->index map
211 | if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
212 | with open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
213 | pickle.dump(label_map, w)
214 |
215 | tokens = example
216 | # tokens = tokenizer.tokenize(example.text)
217 | # truncate the sequence
218 | if len(tokens) >= max_seq_length - 1:
219 | tokens = tokens[0:(max_seq_length - 2)] # -2 because [CLS] and [SEP] will be added
220 | ntokens = []
221 | segment_ids = []
222 | label_ids = []
223 | ntokens.append("[CLS]") # 句子开始设置CLS 标志
224 | segment_ids.append(0)
225 | # append("O") or append("[CLS]") not sure!
226 | label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病
227 | for i, token in enumerate(tokens):
228 | ntokens.append(token)
229 | segment_ids.append(0)
230 | label_ids.append(0)
231 | ntokens.append("[SEP]") # 句尾添加[SEP] 标志
232 | segment_ids.append(0)
233 | # append("O") or append("[SEP]") not sure!
234 | label_ids.append(label_map["[SEP]"])
235 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) # convert the tokens (ntokens) to ids
236 | input_mask = [1] * len(input_ids)
237 |
238 | # pad up to max_seq_length
239 | while len(input_ids) < max_seq_length:
240 | input_ids.append(0)
241 | input_mask.append(0)
242 | segment_ids.append(0)
243 | # we are not concerned about it
244 | label_ids.append(0)
245 | ntokens.append("**NULL**")
246 | # label_mask.append(0)
247 | # print(len(input_ids))
248 | assert len(input_ids) == max_seq_length
249 | assert len(input_mask) == max_seq_length
250 | assert len(segment_ids) == max_seq_length
251 | assert len(label_ids) == max_seq_length
252 | # assert len(label_mask) == max_seq_length
253 |
254 | # pack everything into an InputFeatures object
255 | feature = InputFeatures(
256 | input_ids=input_ids,
257 | input_mask=input_mask,
258 | segment_ids=segment_ids,
259 | label_ids=label_ids,
260 | # label_mask = label_mask
261 | )
262 | return feature
263 |
264 |
265 | if __name__ == "__main__":
266 | app.run(host='0.0.0.0', port=12345)
267 | #online_predict()
268 |
269 |
270 |
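
A minimal client-side sketch, not part of the file above, for querying the service once it is running; the host, port and sentence are examples only:

import requests

resp = requests.get('http://127.0.0.1:12345/ner_predict_service',
                    params={'query': '北京天安门'})
print(resp.json())  # e.g. {"code": 0, "query": "...", "data": [[...predicted tags...]]}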
--------------------------------------------------------------------------------
/bert_base/train/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Shared model-building code.
5 | @Time : 2019/1/30 12:46
6 | @Author : MaCan (ma_cancan@163.com)
7 | @File : models.py
8 | """
9 |
10 | from bert_base.train.lstm_crf_layer import BLSTM_CRF
11 | from tensorflow.contrib.layers.python.layers import initializers
12 |
13 |
14 | __all__ = ['InputExample', 'InputFeatures', 'decode_labels', 'create_model', 'convert_id_str',
15 | 'convert_id_to_label', 'result_to_json', 'create_classification_model']
16 |
17 | class Model(object):
18 | def __init__(self, *args, **kwargs):
19 | pass
20 |
21 |
22 | class InputExample(object):
23 | """A single training/test example for simple sequence classification."""
24 |
25 | def __init__(self, guid=None, text=None, label=None):
26 | """Constructs a InputExample.
27 | Args:
28 | guid: Unique id for the example.
29 | text: string. The untokenized text of the sequence. For single
30 | sequence tasks, only this sequence must be specified.
31 | label: (Optional) string. The label of the example. This should be
32 | specified for train and dev examples, but not for test examples.
33 | """
34 | self.guid = guid
35 | self.text = text
36 | self.label = label
37 |
38 | class InputFeatures(object):
39 | """A single set of features of data."""
40 |
41 | def __init__(self, input_ids, input_mask, segment_ids, label_ids, ):
42 | self.input_ids = input_ids
43 | self.input_mask = input_mask
44 | self.segment_ids = segment_ids
45 | self.label_ids = label_ids
46 | # self.label_mask = label_mask
47 |
48 |
49 | class DataProcessor(object):
50 | """Base class for data converters for sequence classification data sets."""
51 |
52 | def get_train_examples(self, data_dir):
53 | """Gets a collection of `InputExample`s for the train set."""
54 | raise NotImplementedError()
55 |
56 | def get_dev_examples(self, data_dir):
57 | """Gets a collection of `InputExample`s for the dev set."""
58 | raise NotImplementedError()
59 |
60 | def get_labels(self):
61 | """Gets the list of labels for this data set."""
62 | raise NotImplementedError()
63 |
64 |
65 | def create_model(bert_config, is_training, input_ids, input_mask,
66 | segment_ids, labels, num_labels, use_one_hot_embeddings,
67 | dropout_rate=1.0, lstm_size=1, cell='lstm', num_layers=1):
68 | """
69 | Build the tagging model: BERT encoder followed by the BLSTM_CRF output layer (CRF only in this setup).
70 | :param bert_config: BERT configuration
71 | :param is_training:
72 | :param input_ids: token ids of the input
73 | :param input_mask:
74 | :param segment_ids:
75 | :param labels: label ids
76 | :param num_labels: number of label classes
77 | :param use_one_hot_embeddings:
78 | :return:
79 | """
80 | # load BertModel on the inputs to obtain the token embeddings
81 | import tensorflow as tf
82 | from bert_base.bert import modeling
83 | model = modeling.BertModel(
84 | config=bert_config,
85 | is_training=is_training,
86 | input_ids=input_ids,
87 | input_mask=input_mask,
88 | token_type_ids=segment_ids,
89 | use_one_hot_embeddings=use_one_hot_embeddings
90 | )
91 | # sequence output embeddings, shape [batch_size, seq_length, embedding_size]
92 | embedding = model.get_sequence_output()
93 | max_seq_length = embedding.shape[1].value
94 | # compute the true sequence lengths
95 | used = tf.sign(tf.abs(input_ids))
96 | lengths = tf.reduce_sum(used, reduction_indices=1) # [batch_size] vector with the true length of each sequence in the batch
97 | # add the CRF output layer
98 | blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=lstm_size, cell_type=cell, num_layers=num_layers,
99 | dropout_rate=dropout_rate, initializers=initializers, num_labels=num_labels,
100 | seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
101 | rst = blstm_crf.add_blstm_crf_layer(crf_only=True)
102 | return rst
103 |
104 |
105 | def create_classification_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels):
106 | """
107 |
108 | :param bert_config:
109 | :param is_training:
110 | :param input_ids:
111 | :param input_mask:
112 | :param segment_ids:
113 | :param labels:
114 | :param num_labels:
115 | :param use_one_hot_embedding:
116 | :return:
117 | """
118 | import tensorflow as tf
119 | from bert_base.bert import modeling
120 | # run the inputs through BERT to obtain their representation
121 | model = modeling.BertModel(
122 | config=bert_config,
123 | is_training=is_training,
124 | input_ids=input_ids,
125 | input_mask=input_mask,
126 | token_type_ids=segment_ids,
127 | )
128 |
129 | embedding_layer = model.get_sequence_output()
130 | output_layer = model.get_pooled_output()
131 | hidden_size = output_layer.shape[-1].value
132 |
133 | # predict = CNN_Classification(embedding_chars=embedding_layer,
134 | # labels=labels,
135 | # num_tags=num_labels,
136 | # sequence_length=FLAGS.max_seq_length,
137 | # embedding_dims=embedding_layer.shape[-1].value,
138 | # vocab_size=0,
139 | # filter_sizes=[3, 4, 5],
140 | # num_filters=3,
141 | # dropout_keep_prob=FLAGS.dropout_keep_prob,
142 | # l2_reg_lambda=0.001)
143 | # loss, predictions, probabilities = predict.add_cnn_layer()
144 |
145 | output_weights = tf.get_variable(
146 | "output_weights", [num_labels, hidden_size],
147 | initializer=tf.truncated_normal_initializer(stddev=0.02))
148 |
149 | output_bias = tf.get_variable(
150 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
151 |
152 | with tf.variable_scope("loss"):
153 | if is_training:
154 | # I.e., 0.1 dropout
155 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
156 |
157 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
158 | logits = tf.nn.bias_add(logits, output_bias)
159 | probabilities = tf.nn.softmax(logits, axis=-1)
160 | log_probs = tf.nn.log_softmax(logits, axis=-1)
161 |
162 | if labels is not None:
163 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
164 |
165 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
166 | loss = tf.reduce_mean(per_example_loss)
167 | else:
168 | loss, per_example_loss = None, None
169 | return (loss, per_example_loss, logits, probabilities)
170 |
171 |
172 | def decode_labels(labels, batch_size):
173 | new_labels = []
174 | for row in range(batch_size):
175 | label = []
176 | for i in labels[row]:
177 | i = i.decode('utf-8')
178 | if i == '**PAD**':
179 | break
180 | if i in ['[CLS]', '[SEP]']:
181 | continue
182 | label.append(i)
183 | new_labels.append(label)
184 | return new_labels
185 |
186 |
187 | def convert_id_str(input_ids, batch_size):
188 | res = []
189 | for row in range(batch_size):
190 | line = []
191 | for i in input_ids[row]:
192 | i = i.decode('utf-8')
193 | if i == '**PAD**':
194 | break
195 | if i in ['[CLS]', '[SEP]']:
196 | continue
197 |
198 | line.append(i)
199 | res.append(line)
200 | return res
201 |
202 |
203 | def convert_id_to_label(pred_ids_result, idx2label, batch_size):
204 | """
205 | Convert id-based prediction results back to label sequences.
206 | :param pred_ids_result:
207 | :param idx2label:
208 | :return:
209 | """
210 | result = []
211 | index_result = []
212 | for row in range(batch_size):
213 | curr_seq = []
214 | curr_idx = []
215 | ids = pred_ids_result[row]
216 | for idx, id in enumerate(ids):
217 | if id == 0:
218 | break
219 | curr_label = idx2label[id]
220 | if curr_label in ['[CLS]', '[SEP]']:
221 | if id == 102 and (idx + 1 == len(ids) or ids[idx + 1] == 0):
222 | break
223 | continue
224 | # elif curr_label == '[SEP]':
225 | # break
226 | curr_seq.append(curr_label)
227 | curr_idx.append(id)
228 | result.append(curr_seq)
229 | index_result.append(curr_idx)
230 | return result, index_result
231 |
232 |
233 | def result_to_json(string, tags):
234 | """
235 | Combine the input sequence with the predicted tag sequence into a structured result.
236 | :param string: input sequence
237 | :param tags: predicted tags
238 | :return:
239 | """
240 | item = {"entities": []}
241 | entity_name = ""
242 | entity_start = 0
243 | idx = 0
244 | last_tag = ''
245 |
246 | for char, tag in zip(string, tags):
247 | if tag[0] == "S":
248 |
249 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]})
250 | elif tag[0] == "B":
251 | if entity_name != '':
252 |
253 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
254 | entity_name = ""
255 | entity_name += char
256 | entity_start = idx
257 | elif tag[0] == "I":
258 | entity_name += char
259 | elif tag[0] == "O":
260 | if entity_name != '':
261 |
262 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
263 | entity_name = ""
264 | else:
265 | entity_name = ""
266 | entity_start = idx
267 | idx += 1
268 | last_tag = tag
269 | if entity_name != '':
270 |
271 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
272 | return item
273 |
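
With the corrected signature above, result_to_json combines a character sequence with its BIO tags into an entity list. A small illustrative sketch, not part of the file (the sentence and tags are made up):

chars = ['马', '云', '在', '杭', '州']
tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
print(result_to_json(chars, tags))
# expected: two entities, 马云 typed PER and 杭州 typed LOC, each with start/end offsets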
--------------------------------------------------------------------------------
/bert_base/bert/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import json
21 | import random
22 | import re
23 |
24 | import modeling
25 | import six
26 | import tensorflow as tf
27 |
28 |
29 | class BertModelTest(tf.test.TestCase):
30 |
31 | class BertModelTester(object):
32 |
33 | def __init__(self,
34 | parent,
35 | batch_size=13,
36 | seq_length=7,
37 | is_training=True,
38 | use_input_mask=True,
39 | use_token_type_ids=True,
40 | vocab_size=99,
41 | hidden_size=32,
42 | num_hidden_layers=5,
43 | num_attention_heads=4,
44 | intermediate_size=37,
45 | hidden_act="gelu",
46 | hidden_dropout_prob=0.1,
47 | attention_probs_dropout_prob=0.1,
48 | max_position_embeddings=512,
49 | type_vocab_size=16,
50 | initializer_range=0.02,
51 | scope=None):
52 | self.parent = parent
53 | self.batch_size = batch_size
54 | self.seq_length = seq_length
55 | self.is_training = is_training
56 | self.use_input_mask = use_input_mask
57 | self.use_token_type_ids = use_token_type_ids
58 | self.vocab_size = vocab_size
59 | self.hidden_size = hidden_size
60 | self.num_hidden_layers = num_hidden_layers
61 | self.num_attention_heads = num_attention_heads
62 | self.intermediate_size = intermediate_size
63 | self.hidden_act = hidden_act
64 | self.hidden_dropout_prob = hidden_dropout_prob
65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
66 | self.max_position_embeddings = max_position_embeddings
67 | self.type_vocab_size = type_vocab_size
68 | self.initializer_range = initializer_range
69 | self.scope = scope
70 |
71 | def create_model(self):
72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
73 | self.vocab_size)
74 |
75 | input_mask = None
76 | if self.use_input_mask:
77 | input_mask = BertModelTest.ids_tensor(
78 | [self.batch_size, self.seq_length], vocab_size=2)
79 |
80 | token_type_ids = None
81 | if self.use_token_type_ids:
82 | token_type_ids = BertModelTest.ids_tensor(
83 | [self.batch_size, self.seq_length], self.type_vocab_size)
84 |
85 | config = modeling.BertConfig(
86 | vocab_size=self.vocab_size,
87 | hidden_size=self.hidden_size,
88 | num_hidden_layers=self.num_hidden_layers,
89 | num_attention_heads=self.num_attention_heads,
90 | intermediate_size=self.intermediate_size,
91 | hidden_act=self.hidden_act,
92 | hidden_dropout_prob=self.hidden_dropout_prob,
93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
94 | max_position_embeddings=self.max_position_embeddings,
95 | type_vocab_size=self.type_vocab_size,
96 | initializer_range=self.initializer_range)
97 |
98 | model = modeling.BertModel(
99 | config=config,
100 | is_training=self.is_training,
101 | input_ids=input_ids,
102 | input_mask=input_mask,
103 | token_type_ids=token_type_ids,
104 | scope=self.scope)
105 |
106 | outputs = {
107 | "embedding_output": model.get_embedding_output(),
108 | "sequence_output": model.get_sequence_output(),
109 | "pooled_output": model.get_pooled_output(),
110 | "all_encoder_layers": model.get_all_encoder_layers(),
111 | }
112 | return outputs
113 |
114 | def check_output(self, result):
115 | self.parent.assertAllEqual(
116 | result["embedding_output"].shape,
117 | [self.batch_size, self.seq_length, self.hidden_size])
118 |
119 | self.parent.assertAllEqual(
120 | result["sequence_output"].shape,
121 | [self.batch_size, self.seq_length, self.hidden_size])
122 |
123 | self.parent.assertAllEqual(result["pooled_output"].shape,
124 | [self.batch_size, self.hidden_size])
125 |
126 | def test_default(self):
127 | self.run_tester(BertModelTest.BertModelTester(self))
128 |
129 | def test_config_to_json_string(self):
130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37)
131 | obj = json.loads(config.to_json_string())
132 | self.assertEqual(obj["vocab_size"], 99)
133 | self.assertEqual(obj["hidden_size"], 37)
134 |
135 | def run_tester(self, tester):
136 | with self.test_session() as sess:
137 | ops = tester.create_model()
138 | init_op = tf.group(tf.global_variables_initializer(),
139 | tf.local_variables_initializer())
140 | sess.run(init_op)
141 | output_result = sess.run(ops)
142 | tester.check_output(output_result)
143 |
144 | self.assert_all_tensors_reachable(sess, [init_op, ops])
145 |
146 | @classmethod
147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
148 | """Creates a random int32 tensor of the shape within the vocab size."""
149 | if rng is None:
150 | rng = random.Random()
151 |
152 | total_dims = 1
153 | for dim in shape:
154 | total_dims *= dim
155 |
156 | values = []
157 | for _ in range(total_dims):
158 | values.append(rng.randint(0, vocab_size - 1))
159 |
160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
161 |
162 | def assert_all_tensors_reachable(self, sess, outputs):
163 | """Checks that all the tensors in the graph are reachable from outputs."""
164 | graph = sess.graph
165 |
166 | ignore_strings = [
167 | "^.*/assert_less_equal/.*$",
168 | "^.*/dilation_rate$",
169 | "^.*/Tensordot/concat$",
170 | "^.*/Tensordot/concat/axis$",
171 | "^testing/.*$",
172 | ]
173 |
174 | ignore_regexes = [re.compile(x) for x in ignore_strings]
175 |
176 | unreachable = self.get_unreachable_ops(graph, outputs)
177 | filtered_unreachable = []
178 | for x in unreachable:
179 | do_ignore = False
180 | for r in ignore_regexes:
181 | m = r.match(x.name)
182 | if m is not None:
183 | do_ignore = True
184 | if do_ignore:
185 | continue
186 | filtered_unreachable.append(x)
187 | unreachable = filtered_unreachable
188 |
189 | self.assertEqual(
190 | len(unreachable), 0, "The following ops are unreachable: %s" %
191 | (" ".join([x.name for x in unreachable])))
192 |
193 | @classmethod
194 | def get_unreachable_ops(cls, graph, outputs):
195 | """Finds all of the tensors in graph that are unreachable from outputs."""
196 | outputs = cls.flatten_recursive(outputs)
197 | output_to_op = collections.defaultdict(list)
198 | op_to_all = collections.defaultdict(list)
199 | assign_out_to_in = collections.defaultdict(list)
200 |
201 | for op in graph.get_operations():
202 | for x in op.inputs:
203 | op_to_all[op.name].append(x.name)
204 | for y in op.outputs:
205 | output_to_op[y.name].append(op.name)
206 | op_to_all[op.name].append(y.name)
207 | if str(op.type) == "Assign":
208 | for y in op.outputs:
209 | for x in op.inputs:
210 | assign_out_to_in[y.name].append(x.name)
211 |
212 | assign_groups = collections.defaultdict(list)
213 | for out_name in assign_out_to_in.keys():
214 | name_group = assign_out_to_in[out_name]
215 | for n1 in name_group:
216 | assign_groups[n1].append(out_name)
217 | for n2 in name_group:
218 | if n1 != n2:
219 | assign_groups[n1].append(n2)
220 |
221 | seen_tensors = {}
222 | stack = [x.name for x in outputs]
223 | while stack:
224 | name = stack.pop()
225 | if name in seen_tensors:
226 | continue
227 | seen_tensors[name] = True
228 |
229 | if name in output_to_op:
230 | for op_name in output_to_op[name]:
231 | if op_name in op_to_all:
232 | for input_name in op_to_all[op_name]:
233 | if input_name not in stack:
234 | stack.append(input_name)
235 |
236 | expanded_names = []
237 | if name in assign_groups:
238 | for assign_name in assign_groups[name]:
239 | expanded_names.append(assign_name)
240 |
241 | for expanded_name in expanded_names:
242 | if expanded_name not in stack:
243 | stack.append(expanded_name)
244 |
245 | unreachable_ops = []
246 | for op in graph.get_operations():
247 | is_unreachable = False
248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
249 | for name in all_names:
250 | if name not in seen_tensors:
251 | is_unreachable = True
252 | if is_unreachable:
253 | unreachable_ops.append(op)
254 | return unreachable_ops
255 |
256 | @classmethod
257 | def flatten_recursive(cls, item):
258 | """Flattens (potentially nested) a tuple/dictionary/list to a list."""
259 | output = []
260 | if isinstance(item, list):
261 | output.extend(item)
262 | elif isinstance(item, tuple):
263 | output.extend(list(item))
264 | elif isinstance(item, dict):
265 | for (_, v) in six.iteritems(item):
266 | output.append(v)
267 | else:
268 | return [item]
269 |
270 | flat_output = []
271 | for x in output:
272 | flat_output.extend(cls.flatten_recursive(x))
273 | return flat_output
274 |
275 |
276 | if __name__ == "__main__":
277 | tf.test.main()
278 |
--------------------------------------------------------------------------------
/bert_base/server/helper.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import sys
5 | import uuid
6 | import pickle
7 | import zmq
8 | from zmq.utils import jsonapi
9 |
10 | __all__ = ['set_logger', 'send_ndarray', 'get_args_parser',
11 | 'check_tf_version', 'auto_bind', 'import_tf']
12 |
13 |
14 | def set_logger(context, verbose=False):
15 | #if os.name == 'nt': # for Windows
16 | # return NTLogger(context, verbose)
17 |
18 | logger = logging.getLogger(context)
19 | logger.setLevel(logging.DEBUG if verbose else logging.INFO)
20 | formatter = logging.Formatter(
21 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt=
22 | '%m-%d %H:%M:%S')
23 | console_handler = logging.StreamHandler()
24 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
25 | console_handler.setFormatter(formatter)
26 | logger.handlers = []
27 | logger.addHandler(console_handler)
28 | return logger
29 |
30 |
31 | class NTLogger:
32 | def __init__(self, context, verbose):
33 | self.context = context
34 | self.verbose = verbose
35 |
36 | def info(self, msg, **kwargs):
37 | print('I:%s:%s' % (self.context, msg), flush=True)
38 |
39 | def debug(self, msg, **kwargs):
40 | if self.verbose:
41 | print('D:%s:%s' % (self.context, msg), flush=True)
42 |
43 | def error(self, msg, **kwargs):
44 | print('E:%s:%s' % (self.context, msg), flush=True)
45 |
46 | def warning(self, msg, **kwargs):
47 | print('W:%s:%s' % (self.context, msg), flush=True)
48 |
49 |
50 | def send_ndarray(src, dest, X, req_id=b'', flags=0, copy=True, track=False):
51 | """send a numpy array with metadata"""
52 | # md = dict(dtype=str(X.dtype), shape=X.shape)
53 | if type(X) == list and type(X[0]) == dict: # classification: message sent by the sink
54 | md = dict(dtype='json', shape=(len(X[0]['pred_label']), 1))
55 | elif type(X) == dict: # classification: message sent by the BERT worker
56 | md = dict(dtype='json', shape=(len(X['pred_label']), 1))
57 | else:
58 | md = dict(dtype='str', shape=(len(X), len(X[0])))
59 | # print('md', md)
60 | return src.send_multipart([dest, jsonapi.dumps(md), pickle.dumps(X), req_id], flags, copy=copy, track=track)
61 |
62 |
63 | def get_args_parser():
64 | from . import __version__
65 | from .graph import PoolingStrategy
66 |
67 | parser = argparse.ArgumentParser()
68 |
69 | group1 = parser.add_argument_group('File Paths',
70 | 'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
71 |
72 | group1.add_argument('-bert_model_dir', type=str, required=True,
73 | help='path to the pre-trained Chinese BERT model released by Google')
74 |
75 | group1.add_argument('-model_dir', type=str, required=True,
76 | help='directory of a pretrained BERT model')
77 | group1.add_argument('-model_pb_dir', type=str, default=None,
78 | help='directory of the frozen model (.pb file)')
79 |
80 | group1.add_argument('-tuned_model_dir', type=str,
81 | help='directory of a fine-tuned BERT model')
82 | group1.add_argument('-ckpt_name', type=str, default='bert_model.ckpt',
83 | help='filename of the checkpoint file. By default it is "bert_model.ckpt", but \
84 | for a fine-tuned model the name could be different.')
85 | group1.add_argument('-config_name', type=str, default='bert_config.json',
86 | help='filename of the JSON config file for BERT model.')
87 |
88 | group2 = parser.add_argument_group('BERT Parameters',
89 | 'config how BERT model and pooling works')
90 | group2.add_argument('-max_seq_len', type=int, default=128,
91 | help='maximum length of a sequence')
92 | group2.add_argument('-pooling_layer', type=int, nargs='+', default=[-2],
93 | help='the encoder layer(s) that receives pooling. \
94 | Give a list in order to concatenate several layers into one')
95 | group2.add_argument('-pooling_strategy', type=PoolingStrategy.from_string,
96 | default=PoolingStrategy.REDUCE_MEAN, choices=list(PoolingStrategy),
97 | help='the pooling strategy for generating encoding vectors')
98 | group2.add_argument('-mask_cls_sep', action='store_true', default=False,
99 | help='masking the embedding on [CLS] and [SEP] with zero. \
100 | When pooling_strategy is in {CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN} \
101 | then the embedding is preserved, otherwise the embedding is masked to zero before pooling')
102 | group2.add_argument('-lstm_size', type=int, default=128,
103 | help='size of lstm units.')
104 |
105 | group3 = parser.add_argument_group('Serving Configs',
106 | 'config how server utilizes GPU/CPU resources')
107 | group3.add_argument('-port', '-port_in', '-port_data', type=int, default=5555,
108 | help='server port for receiving data from client')
109 | group3.add_argument('-port_out', '-port_result', type=int, default=5556,
110 | help='server port for sending result to client')
111 | group3.add_argument('-http_port', type=int, default=None,
112 | help='server port for receiving HTTP requests')
113 | group3.add_argument('-http_max_connect', type=int, default=10,
114 | help='maximum number of concurrent HTTP connections')
115 | group3.add_argument('-cors', type=str, default='*',
116 | help='setting "Access-Control-Allow-Origin" for HTTP requests')
117 | group3.add_argument('-num_worker', type=int, default=1,
118 | help='number of server instances')
119 | group3.add_argument('-max_batch_size', type=int, default=1024,
120 | help='maximum number of sequences handled by each worker')
121 | group3.add_argument('-priority_batch_size', type=int, default=16,
122 | help='batches smaller than this size will be labeled as high priority '
123 | 'and jump forward in the job queue')
124 | group3.add_argument('-cpu', action='store_true', default=False,
125 | help='running on CPU (default on GPU)')
126 | group3.add_argument('-xla', action='store_true', default=False,
127 | help='enable XLA compiler (experimental)')
128 | group3.add_argument('-fp16', action='store_true', default=False,
129 | help='use float16 precision (experimental)')
130 | group3.add_argument('-gpu_memory_fraction', type=float, default=0.5,
131 | help='determine the fraction of the overall amount of memory \
132 | that each visible GPU should be allocated per worker. \
133 | Should be in range [0.0, 1.0]')
134 | group3.add_argument('-device_map', type=int, nargs='+', default=[],
135 | help='specify the list of GPU device ids that will be used (id starts from 0). \
136 | If num_worker > len(device_map), then device will be reused; \
137 | if num_worker < len(device_map), then device_map[:num_worker] will be used')
138 | group3.add_argument('-prefetch_size', type=int, default=10,
139 | help='the number of batches to prefetch on each worker. When running on a CPU-only machine, \
140 | this is set to 0 for comparability')
141 |
142 | parser.add_argument('-verbose', action='store_true', default=False,
143 | help='turn on tensorflow logging for debug')
144 | parser.add_argument('-mode', type=str, default='NER')
145 | parser.add_argument('-version', action='version', version='%(prog)s ' + __version__)
146 | return parser
147 |
148 |
149 | def check_tf_version():
150 | import tensorflow as tf
151 | tf_ver = tf.__version__.split('.')
152 | assert int(tf_ver[0]) >= 1 and int(tf_ver[1]) >= 10, 'Tensorflow >=1.10 is required!'
153 | return tf_ver
154 |
155 |
156 | def import_tf(device_id=-1, verbose=False, use_fp16=False):
157 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id)
158 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3'
159 | os.environ['TF_FP16_MATMUL_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1'
160 | os.environ['TF_FP16_CONV_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1'
161 | import tensorflow as tf
162 | tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR)
163 | return tf
164 |
165 |
166 | def auto_bind(socket):
167 | """
168 | Automatically bind the socket to an available endpoint.
169 | :param socket:
170 | :return:
171 | """
172 | if os.name == 'nt': # for Windows
173 | socket.bind_to_random_port('tcp://127.0.0.1')
174 | else:
175 | # Get the location for tmp file for sockets
176 | try:
177 | tmp_dir = os.environ['ZEROMQ_SOCK_TMP_DIR']
178 | if not os.path.exists(tmp_dir):
179 | raise ValueError('The directory for sockets ({}) does not seem to exist.'.format(tmp_dir))
180 | # append a random sub-directory name
181 | tmp_dir = os.path.join(tmp_dir, str(uuid.uuid1())[:8])
182 | except KeyError:
183 | tmp_dir = '*'
184 |
185 | socket.bind('ipc://{}'.format(tmp_dir))
186 | return socket.getsockopt(zmq.LAST_ENDPOINT).decode('ascii')
187 |
188 |
189 | def get_run_args(parser_fn=get_args_parser, printed=True):
190 | args = parser_fn().parse_args()
191 | if printed:
192 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())])
193 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str))
194 | return args
195 |
196 |
197 | def get_benchmark_parser():
198 | parser = get_args_parser()
199 |
200 | parser.set_defaults(num_client=1, client_batch_size=4096)
201 |
202 | group = parser.add_argument_group('Benchmark parameters', 'config the experiments of the benchmark')
203 |
204 | group.add_argument('-test_client_batch_size', type=int, nargs='*', default=[1, 16, 256, 4096])
205 | group.add_argument('-test_max_batch_size', type=int, nargs='*', default=[8, 32, 128, 512])
206 | group.add_argument('-test_max_seq_len', type=int, nargs='*', default=[32, 64, 128, 256])
207 | group.add_argument('-test_num_client', type=int, nargs='*', default=[1, 4, 16, 64])
208 | group.add_argument('-test_pooling_layer', type=int, nargs='*', default=[[-j] for j in range(1, 13)])
209 |
210 | group.add_argument('-wait_till_ready', type=int, default=30,
211 | help='seconds to wait until server is ready to serve')
212 | group.add_argument('-client_vocab_file', type=str, default='README.md',
213 | help='file path for building client vocabulary')
214 | group.add_argument('-num_repeat', type=int, default=10,
215 | help='number of repeats per experiment (must be >2), '
216 | 'as the first two results are omitted to allow for warm-up')
217 | return parser
218 |
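
A minimal sketch, not part of the file above, showing how the argument parser and logger are typically combined; the paths passed to parse_args are placeholders:

args = get_args_parser().parse_args([
    '-bert_model_dir', '/path/to/chinese_L-12_H-768_A-12',
    '-model_dir', '/path/to/output',
    '-num_worker', '2',
])
logger = set_logger('NER', verbose=args.verbose)
logger.info('num_worker=%d, max_seq_len=%d, port=%d',
            args.num_worker, args.max_seq_len, args.port)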
--------------------------------------------------------------------------------
/bert_base/train/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | # added function evaluate(predicted_label, ori_label), which does not read from a file
13 |
14 | import sys
15 | import re
16 | import codecs
17 | from collections import defaultdict, namedtuple
18 |
19 | ANY_SPACE = ''
20 |
21 |
22 | class FormatError(Exception):
23 | pass
24 |
25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
26 |
27 |
28 | class EvalCounts(object):
29 | def __init__(self):
30 | self.correct_chunk = 0 # number of correctly identified chunks
31 | self.correct_tags = 0 # number of correct chunk tags
32 | self.found_correct = 0 # number of chunks in corpus
33 | self.found_guessed = 0 # number of identified chunks
34 | self.token_counter = 0 # token counter (ignores sentence breaks)
35 |
36 | # counts by type
37 | self.t_correct_chunk = defaultdict(int)
38 | self.t_found_correct = defaultdict(int)
39 | self.t_found_guessed = defaultdict(int)
40 |
41 |
42 | def parse_args(argv):
43 | import argparse
44 | parser = argparse.ArgumentParser(
45 | description='evaluate tagging results using CoNLL criteria',
46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
47 | )
48 | arg = parser.add_argument
49 | arg('-b', '--boundary', metavar='STR', default='-X-',
50 | help='sentence boundary')
51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
52 | help='character delimiting items in input')
53 | arg('-o', '--otag', metavar='CHAR', default='O',
54 | help='alternative outside tag')
55 | arg('file', nargs='?', default=None)
56 | return parser.parse_args(argv)
57 |
58 |
59 | def parse_tag(t):
60 | m = re.match(r'^([^-]*)-(.*)$', t)
61 | return m.groups() if m else (t, '')
62 |
63 |
64 | def evaluate(iterable, options=None):
65 | if options is None:
66 | options = parse_args([]) # use defaults
67 |
68 | counts = EvalCounts()
69 | num_features = None # number of features per line
70 | in_correct = False # currently processed chunk is correct up to now
71 | last_correct = 'O' # previous chunk tag in corpus
72 | last_correct_type = '' # type of previous chunk tag in corpus
73 | last_guessed = 'O' # previously identified chunk tag
74 | last_guessed_type = '' # type of previously identified chunk tag
75 |
76 | for line in iterable:
77 | line = line.rstrip('\r\n')
78 |
79 | if options.delimiter == ANY_SPACE:
80 | features = line.split()
81 | else:
82 | features = line.split(options.delimiter)
83 |
84 | if num_features is None:
85 | num_features = len(features)
86 | elif num_features != len(features) and len(features) != 0:
87 | raise FormatError('unexpected number of features: %d (%d)' %
88 | (len(features), num_features))
89 |
90 | if len(features) == 0 or features[0] == options.boundary:
91 | features = [options.boundary, 'O', 'O']
92 | if len(features) < 3:
93 | raise FormatError('unexpected number of features in line %s' % line)
94 |
95 | guessed, guessed_type = parse_tag(features.pop())
96 | correct, correct_type = parse_tag(features.pop())
97 | first_item = features.pop(0)
98 |
99 | if first_item == options.boundary:
100 | guessed = 'O'
101 |
102 | end_correct = end_of_chunk(last_correct, correct,
103 | last_correct_type, correct_type)
104 | end_guessed = end_of_chunk(last_guessed, guessed,
105 | last_guessed_type, guessed_type)
106 | start_correct = start_of_chunk(last_correct, correct,
107 | last_correct_type, correct_type)
108 | start_guessed = start_of_chunk(last_guessed, guessed,
109 | last_guessed_type, guessed_type)
110 |
111 | if in_correct:
112 | if (end_correct and end_guessed and
113 | last_guessed_type == last_correct_type):
114 | in_correct = False
115 | counts.correct_chunk += 1
116 | counts.t_correct_chunk[last_correct_type] += 1
117 | elif (end_correct != end_guessed or guessed_type != correct_type):
118 | in_correct = False
119 |
120 | if start_correct and start_guessed and guessed_type == correct_type:
121 | in_correct = True
122 |
123 | if start_correct:
124 | counts.found_correct += 1
125 | counts.t_found_correct[correct_type] += 1
126 | if start_guessed:
127 | counts.found_guessed += 1
128 | counts.t_found_guessed[guessed_type] += 1
129 | if first_item != options.boundary:
130 | if correct == guessed and guessed_type == correct_type:
131 | counts.correct_tags += 1
132 | counts.token_counter += 1
133 |
134 | last_guessed = guessed
135 | last_correct = correct
136 | last_guessed_type = guessed_type
137 | last_correct_type = correct_type
138 |
139 | if in_correct:
140 | counts.correct_chunk += 1
141 | counts.t_correct_chunk[last_correct_type] += 1
142 |
143 | return counts
144 |
145 |
146 |
147 | def uniq(iterable):
148 | seen = set()
149 | return [i for i in iterable if not (i in seen or seen.add(i))]
150 |
151 |
152 | def calculate_metrics(correct, guessed, total):
153 | tp, fp, fn = correct, guessed-correct, total-correct
154 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
155 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
156 | f = 0 if p + r == 0 else 2 * p * r / (p + r)
157 | return Metrics(tp, fp, fn, p, r, f)
158 |
159 |
160 | def metrics(counts):
161 | c = counts
162 | overall = calculate_metrics(
163 | c.correct_chunk, c.found_guessed, c.found_correct
164 | )
165 | by_type = {}
166 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
167 | by_type[t] = calculate_metrics(
168 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
169 | )
170 | return overall, by_type
171 |
172 |
173 | def report(counts, out=None):
174 | if out is None:
175 | out = sys.stdout
176 |
177 | overall, by_type = metrics(counts)
178 |
179 | c = counts
180 | out.write('processed %d tokens with %d phrases; ' %
181 | (c.token_counter, c.found_correct))
182 | out.write('found: %d phrases; correct: %d.\n' %
183 | (c.found_guessed, c.correct_chunk))
184 |
185 | if c.token_counter > 0:
186 | out.write('accuracy: %6.2f%%; ' %
187 | (100.*c.correct_tags/c.token_counter))
188 | out.write('precision: %6.2f%%; ' % (100.*overall.prec))
189 | out.write('recall: %6.2f%%; ' % (100.*overall.rec))
190 | out.write('FB1: %6.2f\n' % (100.*overall.fscore))
191 |
192 | for i, m in sorted(by_type.items()):
193 | out.write('%17s: ' % i)
194 | out.write('precision: %6.2f%%; ' % (100.*m.prec))
195 | out.write('recall: %6.2f%%; ' % (100.*m.rec))
196 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
197 |
198 |
199 | def report_notprint(counts, out=None):
200 | if out is None:
201 | out = sys.stdout
202 |
203 | overall, by_type = metrics(counts)
204 |
205 | c = counts
206 | final_report = []
207 | line = []
208 | line.append('processed %d tokens with %d phrases; ' %
209 | (c.token_counter, c.found_correct))
210 | line.append('found: %d phrases; correct: %d.\n' %
211 | (c.found_guessed, c.correct_chunk))
212 | final_report.append("".join(line))
213 |
214 | if c.token_counter > 0:
215 | line = []
216 | line.append('accuracy: %6.2f%%; ' %
217 | (100.*c.correct_tags/c.token_counter))
218 | line.append('precision: %6.2f%%; ' % (100.*overall.prec))
219 | line.append('recall: %6.2f%%; ' % (100.*overall.rec))
220 | line.append('FB1: %6.2f\n' % (100.*overall.fscore))
221 | final_report.append("".join(line))
222 |
223 | for i, m in sorted(by_type.items()):
224 | line = []
225 | line.append('%17s: ' % i)
226 | line.append('precision: %6.2f%%; ' % (100.*m.prec))
227 | line.append('recall: %6.2f%%; ' % (100.*m.rec))
228 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
229 | final_report.append("".join(line))
230 | return final_report
231 |
232 |
233 | def end_of_chunk(prev_tag, tag, prev_type, type_):
234 | # check if a chunk ended between the previous and current word
235 | # arguments: previous and current chunk tags, previous and current types
236 | chunk_end = False
237 |
238 | if prev_tag == 'E': chunk_end = True
239 | if prev_tag == 'S': chunk_end = True
240 |
241 | if prev_tag == 'B' and tag == 'B': chunk_end = True
242 | if prev_tag == 'B' and tag == 'S': chunk_end = True
243 | if prev_tag == 'B' and tag == 'O': chunk_end = True
244 | if prev_tag == 'I' and tag == 'B': chunk_end = True
245 | if prev_tag == 'I' and tag == 'S': chunk_end = True
246 | if prev_tag == 'I' and tag == 'O': chunk_end = True
247 |
248 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
249 | chunk_end = True
250 |
251 | # these chunks are assumed to have length 1
252 | if prev_tag == ']': chunk_end = True
253 | if prev_tag == '[': chunk_end = True
254 |
255 | return chunk_end
256 |
257 |
258 | def start_of_chunk(prev_tag, tag, prev_type, type_):
259 | # check if a chunk started between the previous and current word
260 | # arguments: previous and current chunk tags, previous and current types
261 | chunk_start = False
262 |
263 | if tag == 'B': chunk_start = True
264 | if tag == 'S': chunk_start = True
265 |
266 | if prev_tag == 'E' and tag == 'E': chunk_start = True
267 | if prev_tag == 'E' and tag == 'I': chunk_start = True
268 | if prev_tag == 'S' and tag == 'E': chunk_start = True
269 | if prev_tag == 'S' and tag == 'I': chunk_start = True
270 | if prev_tag == 'O' and tag == 'E': chunk_start = True
271 | if prev_tag == 'O' and tag == 'I': chunk_start = True
272 |
273 | if tag != 'O' and tag != '.' and prev_type != type_:
274 | chunk_start = True
275 |
276 | # these chunks are assumed to have length 1
277 | if tag == '[': chunk_start = True
278 | if tag == ']': chunk_start = True
279 |
280 | return chunk_start
281 |
282 |
283 | def return_report(input_file):
284 | with codecs.open(input_file, "r", "utf8") as f:
285 | counts = evaluate(f)
286 | return report_notprint(counts)
287 |
288 |
289 | def main(argv):
290 | args = parse_args(argv[1:])
291 |
292 | if args.file is None:
293 | counts = evaluate(sys.stdin, args)
294 | else:
295 | with open(args.file) as f:
296 | counts = evaluate(f, args)
297 | report(counts)
298 |
299 | if __name__ == '__main__':
300 | sys.exit(main(sys.argv))
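
Besides the command-line entry point, evaluate() accepts any iterable of lines. A minimal sketch, not part of the file above; each line holds token, gold tag and predicted tag separated by whitespace, and an empty line marks a sentence boundary:

lines = [
    '马 B-PER B-PER',
    '云 I-PER I-PER',
    '在 O O',
    '杭 B-LOC B-LOC',
    '州 I-LOC O',
    '',  # sentence boundary
]
counts = evaluate(lines)
report(counts)  # prints token accuracy plus overall and per-type precision/recall/FB1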
--------------------------------------------------------------------------------
/bert_base/bert/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/bert_base/bert/multilingual.md:
--------------------------------------------------------------------------------
1 | ## Models
2 |
3 | There are two multilingual models currently available. We do not plan to release
4 | more single-language models, but we may release `BERT-Large` versions of these
5 | two in the future:
6 |
7 | * **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
8 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
9 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
10 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
11 | parameters
12 |
13 | See the [list of languages](#list-of-languages) that the Multilingual model
14 | supports. The Multilingual model does include Chinese (and English), but if your
15 | fine-tuning data is Chinese-only, then the Chinese model will likely produce
16 | better results.
17 |
18 | ## Results
19 |
20 | To evaluate these systems, we use the
21 | [XNLI dataset](https://github.com/facebookresearch/XNLI) dataset, which is a
22 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the
23 | dev and test sets have been translated (by humans) into 15 languages. Note that
24 | the training set was *machine* translated (we used the translations provided by
25 | XNLI, not Google NMT). For clarity, we only report on 6 languages below:
26 |
27 |
28 |
29 | | System | English | Chinese | Spanish | German | Arabic | Urdu |
30 | | ------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- |
31 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 |
32 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 |
33 | | BERT - Translate Train          | **81.4** | **74.2** | **77.3** | **75.2** | **70.5** | 61.7     |
34 | | BERT - Translate Test | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** |
35 | | BERT - Zero Shot | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 |
36 |
37 |
38 |
39 | The first two rows are baselines from the XNLI paper and the last three rows are
40 | our results with BERT.
41 |
42 | **Translate Train** means that the MultiNLI training set was machine translated
43 | from English into the foreign language. So training and evaluation were both
44 | done in the foreign language. Unfortunately, training was done on
45 | machine-translated data, so it is impossible to quantify how much of the lower
46 | accuracy (compared to English) is due to the quality of the machine translation
47 | vs. the quality of the pre-trained model.
48 |
49 | **Translate Test** means that the XNLI test set was machine translated from the
50 | foreign language into English. So training and evaluation were both done on
51 | English. However, test evaluation was done on machine-translated English, so the
52 | accuracy depends on the quality of the machine translation system.
53 |
54 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English
55 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case,
56 | machine translation was not involved at all in either the pre-training or
57 | fine-tuning.
58 |
59 | Note that the English result is worse than the 84.2 MultiNLI baseline because
60 | this training used Multilingual BERT rather than English-only BERT. This implies
61 | that for high-resource languages, the Multilingual model is somewhat worse than
62 | a single-language model. However, it is not feasible for us to train and
63 | maintain dozens of single-language models. Therefore, if your goal is to maximize
64 | performance with a language other than English or Chinese, you might find it
65 | beneficial to run pre-training for additional steps starting from our
66 | Multilingual model on data from your language of interest.
67 |
68 | Here is a comparison of training Chinese models with the Multilingual
69 | `BERT-Base` and Chinese-only `BERT-Base`:
70 |
71 | System | Chinese
72 | ----------------------- | -------
73 | XNLI Baseline | 67.0
74 | BERT Multilingual Model | 74.2
75 | BERT Chinese-only Model | 77.2
76 |
77 | Similar to English, the single-language model does 3% better than the
78 | Multilingual model.
79 |
80 | ## Fine-tuning Example
81 |
82 | The multilingual model does **not** require any special consideration or API
83 | changes. We did update the implementation of `BasicTokenizer` in
84 | `tokenization.py` to support Chinese character tokenization, so please update if
85 | you forked it. However, we did not change the tokenization API.
86 |
87 | To test the new models, we did modify `run_classifier.py` to add support for the
88 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language
89 | version of MultiNLI where the dev/test sets have been human-translated, and the
90 | training set has been machine-translated.
91 |
92 | To run the fine-tuning code, please download the
93 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the
94 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip)
95 | and then unpack both .zip files into some directory `$XNLI_DIR`.
96 |
97 | To run fine-tuning on XNLI, note that the language is hard-coded into
98 | `run_classifier.py` (Chinese by default), so please modify `XnliProcessor` if
99 | you want to run on another language.
100 |
101 | This is a large dataset, so training will take a few hours on a GPU
102 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for
103 | debugging, just set `num_train_epochs` to a small value like `0.1`.
104 |
105 | ```shell
106 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12
107 | export XNLI_DIR=/path/to/xnli
108 |
109 | python run_classifier.py \
110 | --task_name=XNLI \
111 | --do_train=true \
112 | --do_eval=true \
113 | --data_dir=$XNLI_DIR \
114 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
115 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
116 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
117 | --max_seq_length=128 \
118 | --train_batch_size=32 \
119 | --learning_rate=5e-5 \
120 | --num_train_epochs=2.0 \
121 | --output_dir=/tmp/xnli_output/
122 | ```
123 |
124 | With the Chinese-only model, the results should look something like this:
125 |
126 | ```
127 | ***** Eval results *****
128 | eval_accuracy = 0.774116
129 | eval_loss = 0.83554
130 | global_step = 24543
131 | loss = 0.74603
132 | ```
133 |
134 | ## Details
135 |
136 | ### Data Source and Sampling
137 |
138 | The languages chosen were the
139 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).
140 | The entire Wikipedia dump for each language (excluding user and talk pages) was
141 | taken as the training data for each language.
142 |
143 | However, the size of the Wikipedia for a given language varies greatly, and
144 | therefore low-resource languages may be "under-represented" in terms of the
145 | neural network model (under the assumption that languages are "competing" for
146 | limited model capacity to some extent).
147 |
148 | However, the size of a Wikipedia also correlates with the number of speakers of
149 | a language, and we also don't want to overfit the model by performing thousands
150 | of epochs over a tiny Wikipedia for a particular language.
151 |
152 | To balance these two factors, we performed exponentially smoothed weighting of
153 | the data during pre-training data creation (and WordPiece vocab creation). In
154 | other words, let's say that the probability of a language is *P(L)*, e.g.,
155 | *P(English) = 0.21* means that after concatenating all of the Wikipedias
156 | together, 21% of our data is English. We exponentiate each probability by some
157 | factor *S* and then re-normalize, and sample from that distribution. In our case
158 | we use *S=0.7*. So, high-resource languages like English will be under-sampled,
159 | and low-resource languages like Icelandic will be over-sampled. E.g., in the
160 | original distribution English would be sampled 1000x more than Icelandic, but
161 | after smoothing it's only sampled 100x more.
162 |
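As a rough sketch of this weighting scheme (illustrative only: the P(L) values below are invented, and only the exponent S=0.7 comes from the description above):

```python
# Exponentially smoothed sampling: exponentiate each P(L) by S and re-normalize.
probs = {"english": 0.21, "icelandic": 0.00021}  # hypothetical P(L) values
S = 0.7

smoothed = {lang: p ** S for lang, p in probs.items()}
total = sum(smoothed.values())
sampling = {lang: w / total for lang, w in smoothed.items()}

# Before smoothing, English is 1000x more likely than Icelandic; after
# exponentiating by S=0.7 the ratio drops to 1000**0.7 (roughly 100x).
print(sampling)
```
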
163 | ### Tokenization
164 |
165 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are
166 | weighted the same way as the data, so low-resource languages are upweighted by
167 | some factor. We intentionally do *not* use any marker to denote the input
168 | language (so that zero-shot training can work).
169 |
170 | Because Chinese does not have whitespace characters, we add spaces around every
171 | character in the
172 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\))
173 | before applying WordPiece. This means that Chinese is effectively
174 | character-tokenized. Note that the CJK Unicode block only includes
175 | Chinese-origin characters and does *not* include Hangul Korean or
176 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like
177 | all other languages.
178 |
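As a simplified sketch of this spacing step (it checks only the main CJK Unified Ideographs block; the full implementation in `tokenization.py` in this repository covers several additional ranges):

```python
def space_cjk(text):
    # add spaces around characters in U+4E00..U+9FFF so WordPiece effectively sees single characters
    return "".join(" %s " % ch if 0x4E00 <= ord(ch) <= 0x9FFF else ch for ch in text)

print(space_cjk("BERT处理中文").split())
# ['BERT', '处', '理', '中', '文']
```
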
179 | For all other languages, we apply the
180 | [same recipe as English](https://github.com/google-research/bert#tokenization):
181 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace
182 | tokenization. We understand that accent markers have substantial meaning in some
183 | languages, but felt that the benefits of reducing the effective vocabulary make
184 | up for this. Generally the strong contextual models of BERT should make up for
185 | any ambiguity introduced by stripping accent markers.
186 |
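As a minimal sketch of steps (a) and (c) of this recipe (step (b), punctuation splitting, is omitted, and the example string is arbitrary):

```python
import unicodedata

def lower_and_strip_accents(text):
    # (a) lower casing + accent removal: NFD-decompose, then drop combining marks (category "Mn")
    text = unicodedata.normalize("NFD", text.lower())
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

# (c) whitespace tokenization
print(lower_and_strip_accents("Université de Montréal").split())
# ['universite', 'de', 'montreal']
```
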
187 | ### List of Languages
188 |
189 | The multilingual model supports the following languages. These languages were
190 | chosen because they are the top 100 languages with the largest Wikipedias:
191 |
192 | * Afrikaans
193 | * Albanian
194 | * Arabic
195 | * Aragonese
196 | * Armenian
197 | * Asturian
198 | * Azerbaijani
199 | * Bashkir
200 | * Basque
201 | * Bavarian
202 | * Belarusian
203 | * Bengali
204 | * Bishnupriya Manipuri
205 | * Bosnian
206 | * Breton
207 | * Bulgarian
208 | * Burmese
209 | * Catalan
210 | * Cebuano
211 | * Chechen
212 | * Chinese (Simplified)
213 | * Chinese (Traditional)
214 | * Chuvash
215 | * Croatian
216 | * Czech
217 | * Danish
218 | * Dutch
219 | * English
220 | * Estonian
221 | * Finnish
222 | * French
223 | * Galician
224 | * Georgian
225 | * German
226 | * Greek
227 | * Gujarati
228 | * Haitian
229 | * Hebrew
230 | * Hindi
231 | * Hungarian
232 | * Icelandic
233 | * Ido
234 | * Indonesian
235 | * Irish
236 | * Italian
237 | * Japanese
238 | * Javanese
239 | * Kannada
240 | * Kazakh
241 | * Kirghiz
242 | * Korean
243 | * Latin
244 | * Latvian
245 | * Lithuanian
246 | * Lombard
247 | * Low Saxon
248 | * Luxembourgish
249 | * Macedonian
250 | * Malagasy
251 | * Malay
252 | * Malayalam
253 | * Marathi
254 | * Minangkabau
255 | * Nepali
256 | * Newar
257 | * Norwegian (Bokmal)
258 | * Norwegian (Nynorsk)
259 | * Occitan
260 | * Persian (Farsi)
261 | * Piedmontese
262 | * Polish
263 | * Portuguese
264 | * Punjabi
265 | * Romanian
266 | * Russian
267 | * Scots
268 | * Serbian
269 | * Serbo-Croatian
270 | * Sicilian
271 | * Slovak
272 | * Slovenian
273 | * South Azerbaijani
274 | * Spanish
275 | * Sundanese
276 | * Swahili
277 | * Swedish
278 | * Tagalog
279 | * Tajik
280 | * Tamil
281 | * Tatar
282 | * Telugu
283 | * Turkish
284 | * Ukrainian
285 | * Urdu
286 | * Uzbek
287 | * Vietnamese
288 | * Volapük
289 | * Waray-Waray
290 | * Welsh
291 | * West Frisian
292 | * Western Punjabi
293 | * Yoruba
294 |
295 | The only language which we had to unfortunately exclude was Thai, since it is
296 | the only language (other than Chinese) that does not use whitespace to delimit
297 | words, and it has too many characters-per-word to use character-based
298 | tokenization. Our WordPiece algorithm is quadratic with respect to the size of
299 | the input token so very long character strings do not work with it.
300 |
--------------------------------------------------------------------------------
/bert_base/bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 | import tensorflow as tf
25 |
26 |
27 | def convert_to_unicode(text):
28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
29 | if six.PY3:
30 | if isinstance(text, str):
31 | return text
32 | elif isinstance(text, bytes):
33 | return text.decode("utf-8", "ignore")
34 | else:
35 | raise ValueError("Unsupported string type: %s" % (type(text)))
36 | elif six.PY2:
37 | if isinstance(text, str):
38 | return text.decode("utf-8", "ignore")
39 | elif isinstance(text, unicode):
40 | return text
41 | else:
42 | raise ValueError("Unsupported string type: %s" % (type(text)))
43 | else:
44 | raise ValueError("Not running on Python2 or Python 3?")
45 |
46 |
47 | def printable_text(text):
48 | """Returns text encoded in a way suitable for print or `tf.logging`."""
49 |
50 | # These functions want `str` for both Python2 and Python3, but in one case
51 | # it's a Unicode string and in the other it's a byte string.
52 | if six.PY3:
53 | if isinstance(text, str):
54 | return text
55 | elif isinstance(text, bytes):
56 | return text.decode("utf-8", "ignore")
57 | else:
58 | raise ValueError("Unsupported string type: %s" % (type(text)))
59 | elif six.PY2:
60 | if isinstance(text, str):
61 | return text
62 | elif isinstance(text, unicode):
63 | return text.encode("utf-8")
64 | else:
65 | raise ValueError("Unsupported string type: %s" % (type(text)))
66 | else:
67 | raise ValueError("Not running on Python2 or Python 3?")
68 |
69 |
70 | def load_vocab(vocab_file):
71 | """Loads a vocabulary file into a dictionary."""
72 | vocab = collections.OrderedDict()
73 | index = 0
74 | with tf.gfile.GFile(vocab_file, "r") as reader:
75 | while True:
76 | token = convert_to_unicode(reader.readline())
77 | if not token:
78 | break
79 | token = token.strip()
80 | vocab[token] = index
81 | index += 1
82 | return vocab
83 |
84 |
85 | def convert_by_vocab(vocab, items):
86 | """Converts a sequence of [tokens|ids] using the vocab."""
87 | output = []
88 | for item in items:
89 |     # TODO: handle OOV tokens by mapping them to [UNK] (id 100 in the standard BERT vocab); if you are using English data you should not need to change this
90 |     # output.append(vocab[item])
91 | output.append(vocab.get(item, 100))
92 | return output
93 |
94 |
95 | def convert_tokens_to_ids(vocab, tokens):
96 | return convert_by_vocab(vocab, tokens)
97 |
98 |
99 | def convert_ids_to_tokens(inv_vocab, ids):
100 | return convert_by_vocab(inv_vocab, ids)
101 |
102 |
103 | def whitespace_tokenize(text):
104 |   """Runs basic whitespace cleaning and splitting on a piece of text."""
105 | text = text.strip()
106 | if not text:
107 | return []
108 | tokens = text.split()
109 | return tokens
110 |
111 |
112 | class FullTokenizer(object):
113 |   """Runs end-to-end tokenization."""
114 |
115 | def __init__(self, vocab_file, do_lower_case=True):
116 | self.vocab = load_vocab(vocab_file)
117 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
118 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
119 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
120 |
121 | def tokenize(self, text):
122 | split_tokens = []
123 | for token in self.basic_tokenizer.tokenize(text):
124 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
125 | split_tokens.append(sub_token)
126 |
127 | return split_tokens
128 |
129 | def convert_tokens_to_ids(self, tokens):
130 | return convert_by_vocab(self.vocab, tokens)
131 |
132 | def convert_ids_to_tokens(self, ids):
133 | return convert_by_vocab(self.inv_vocab, ids)
134 |
135 |
136 | class BasicTokenizer(object):
137 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
138 |
139 | def __init__(self, do_lower_case=True):
140 | """Constructs a BasicTokenizer.
141 |
142 | Args:
143 | do_lower_case: Whether to lower case the input.
144 | """
145 | self.do_lower_case = do_lower_case
146 |
147 | def tokenize(self, text):
148 | """Tokenizes a piece of text."""
149 | text = convert_to_unicode(text)
150 | text = self._clean_text(text)
151 |
152 | # This was added on November 1st, 2018 for the multilingual and Chinese
153 | # models. This is also applied to the English models now, but it doesn't
154 | # matter since the English models were not trained on any Chinese data
155 | # and generally don't have any Chinese data in them (there are Chinese
156 | # characters in the vocabulary because Wikipedia does have some Chinese
157 | # words in the English Wikipedia.).
158 | text = self._tokenize_chinese_chars(text)
159 |
160 | orig_tokens = whitespace_tokenize(text)
161 | split_tokens = []
162 | for token in orig_tokens:
163 | if self.do_lower_case:
164 | token = token.lower()
165 | token = self._run_strip_accents(token)
166 | split_tokens.extend(self._run_split_on_punc(token))
167 |
168 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
169 | return output_tokens
170 |
171 | def _run_strip_accents(self, text):
172 | """Strips accents from a piece of text."""
173 | text = unicodedata.normalize("NFD", text)
174 | output = []
175 | for char in text:
176 | cat = unicodedata.category(char)
177 | if cat == "Mn":
178 | continue
179 | output.append(char)
180 | return "".join(output)
181 |
182 | def _run_split_on_punc(self, text):
183 | """Splits punctuation on a piece of text."""
184 | chars = list(text)
185 | i = 0
186 | start_new_word = True
187 | output = []
188 | while i < len(chars):
189 | char = chars[i]
190 | if _is_punctuation(char):
191 | output.append([char])
192 | start_new_word = True
193 | else:
194 | if start_new_word:
195 | output.append([])
196 | start_new_word = False
197 | output[-1].append(char)
198 | i += 1
199 |
200 | return ["".join(x) for x in output]
201 |
202 | def _tokenize_chinese_chars(self, text):
203 | """Adds whitespace around any CJK character."""
204 | output = []
205 | for char in text:
206 | cp = ord(char)
207 | if self._is_chinese_char(cp):
208 | output.append(" ")
209 | output.append(char)
210 | output.append(" ")
211 | else:
212 | output.append(char)
213 | return "".join(output)
214 |
215 | def _is_chinese_char(self, cp):
216 | """Checks whether CP is the codepoint of a CJK character."""
217 | # This defines a "chinese character" as anything in the CJK Unicode block:
218 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
219 | #
220 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
221 | # despite its name. The modern Korean Hangul alphabet is a different block,
222 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
223 | # space-separated words, so they are not treated specially and handled
224 |     # like all of the other languages.
225 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
226 | (cp >= 0x3400 and cp <= 0x4DBF) or #
227 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
228 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
229 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
230 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
231 | (cp >= 0xF900 and cp <= 0xFAFF) or #
232 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
233 | return True
234 |
235 | return False
236 |
237 | def _clean_text(self, text):
238 | """Performs invalid character removal and whitespace cleanup on text."""
239 | output = []
240 | for char in text:
241 | cp = ord(char)
242 | if cp == 0 or cp == 0xfffd or _is_control(char):
243 | continue
244 | if _is_whitespace(char):
245 | output.append(" ")
246 | else:
247 | output.append(char)
248 | return "".join(output)
249 |
250 |
251 | class WordpieceTokenizer(object):
252 |   """Runs WordPiece tokenization."""
253 |
254 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
255 | self.vocab = vocab
256 | self.unk_token = unk_token
257 | self.max_input_chars_per_word = max_input_chars_per_word
258 |
259 | def tokenize(self, text):
260 | """Tokenizes a piece of text into its word pieces.
261 |
262 | This uses a greedy longest-match-first algorithm to perform tokenization
263 | using the given vocabulary.
264 |
265 | For example:
266 | input = "unaffable"
267 | output = ["un", "##aff", "##able"]
268 |
269 | Args:
270 | text: A single token or whitespace separated tokens. This should have
271 |         already been passed through `BasicTokenizer`.
272 |
273 | Returns:
274 | A list of wordpiece tokens.
275 | """
276 |
277 | text = convert_to_unicode(text)
278 |
279 | output_tokens = []
280 | for token in whitespace_tokenize(text):
281 | chars = list(token)
282 | if len(chars) > self.max_input_chars_per_word:
283 | output_tokens.append(self.unk_token)
284 | continue
285 |
286 | is_bad = False
287 | start = 0
288 | sub_tokens = []
289 | while start < len(chars):
290 | end = len(chars)
291 | cur_substr = None
292 | while start < end:
293 | substr = "".join(chars[start:end])
294 | if start > 0:
295 | substr = "##" + substr
296 | if substr in self.vocab:
297 | cur_substr = substr
298 | break
299 | end -= 1
300 | if cur_substr is None:
301 | is_bad = True
302 | break
303 | sub_tokens.append(cur_substr)
304 | start = end
305 |
306 | if is_bad:
307 | output_tokens.append(self.unk_token)
308 | else:
309 | output_tokens.extend(sub_tokens)
310 | return output_tokens
311 |
312 |
313 | def _is_whitespace(char):
314 | """Checks whether `chars` is a whitespace character."""
315 |   # \t, \n, and \r are technically control characters but we treat them
316 | # as whitespace since they are generally considered as such.
317 | if char == " " or char == "\t" or char == "\n" or char == "\r":
318 | return True
319 | cat = unicodedata.category(char)
320 | if cat == "Zs":
321 | return True
322 | return False
323 |
324 |
325 | def _is_control(char):
326 | """Checks whether `chars` is a control character."""
327 | # These are technically control characters but we count them as whitespace
328 | # characters.
329 | if char == "\t" or char == "\n" or char == "\r":
330 | return False
331 | cat = unicodedata.category(char)
332 | if cat.startswith("C"):
333 | return True
334 | return False
335 |
336 |
337 | def _is_punctuation(char):
338 | """Checks whether `chars` is a punctuation character."""
339 | cp = ord(char)
340 | # We treat all non-letter/number ASCII as punctuation.
341 | # Characters such as "^", "$", and "`" are not in the Unicode
342 | # Punctuation class but we treat them as punctuation anyways, for
343 | # consistency.
344 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
345 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
346 | return True
347 | cat = unicodedata.category(char)
348 | if cat.startswith("P"):
349 | return True
350 | return False
351 |
352 |
--------------------------------------------------------------------------------
/terminal_predict.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | """
4 | Command-line online prediction with the trained NER model.
5 | @Author: Macan (ma_cancan@163.com)
6 | """
7 |
8 | import tensorflow as tf
9 | import numpy as np
10 | import codecs
11 | import pickle
12 | import os
13 | from datetime import datetime
14 |
15 | from bert_base.train.models import create_model, InputFeatures
16 | from bert_base.bert import tokenization, modeling
17 | from bert_base.train.train_helper import get_args_parser
18 | args = get_args_parser()
19 |
20 | model_dir = r'C:\Users\C\Documents\Tencent Files\389631699\FileRecv\semi_corpus_people_2014'
21 | bert_dir = 'F:\chinese_L-12_H-768_A-12'
22 |
23 | is_training=False
24 | use_one_hot_embeddings=False
25 | batch_size=1
26 |
27 | gpu_config = tf.ConfigProto()
28 | gpu_config.gpu_options.allow_growth = True
29 | sess=tf.Session(config=gpu_config)
30 | model=None
31 |
32 | global graph
33 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None
34 |
35 |
36 | print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
37 | if not os.path.exists(os.path.join(model_dir, "checkpoint")):
38 | raise Exception("failed to get checkpoint. going to return ")
39 |
40 | # load the label->id dictionary
41 | with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
42 | label2id = pickle.load(rf)
43 | id2label = {value: key for key, value in label2id.items()}
44 |
45 | with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
46 | label_list = pickle.load(rf)
47 | num_labels = len(label_list) + 1
48 |
49 | graph = tf.get_default_graph()
50 | with graph.as_default():
51 | print("going to restore checkpoint")
52 | #sess.run(tf.global_variables_initializer())
53 | input_ids_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length], name="input_ids")
54 | input_mask_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length], name="input_mask")
55 |
56 | bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
57 | (total_loss, logits, trans, pred_ids) = create_model(
58 | bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None,
59 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)
60 |
61 | saver = tf.train.Saver()
62 | saver.restore(sess, tf.train.latest_checkpoint(model_dir))
63 |
64 |
65 | tokenizer = tokenization.FullTokenizer(
66 | vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)
67 |
68 |
69 | def predict_online():
70 | """
71 |     Do online prediction, one instance at a time, reading sentences from stdin.
72 |     You can change this to batch prediction if you want.
73 |
74 |     :param line: a list whose elements are [dummy_label, text_a, text_b]
75 | :return:
76 | """
77 | def convert(line):
78 | feature = convert_single_example(0, line, label_list, args.max_seq_length, tokenizer, 'p')
79 | input_ids = np.reshape([feature.input_ids],(batch_size, args.max_seq_length))
80 | input_mask = np.reshape([feature.input_mask],(batch_size, args.max_seq_length))
81 | segment_ids = np.reshape([feature.segment_ids],(batch_size, args.max_seq_length))
82 | label_ids =np.reshape([feature.label_ids],(batch_size, args.max_seq_length))
83 | return input_ids, input_mask, segment_ids, label_ids
84 |
85 | global graph
86 | with graph.as_default():
87 | print(id2label)
88 | while True:
89 | print('input the test sentence:')
90 | sentence = str(input())
91 | start = datetime.now()
92 | if len(sentence) < 2:
93 | print(sentence)
94 | continue
95 | sentence = tokenizer.tokenize(sentence)
96 | # print('your input is:{}'.format(sentence))
97 | input_ids, input_mask, segment_ids, label_ids = convert(sentence)
98 |
99 | feed_dict = {input_ids_p: input_ids,
100 | input_mask_p: input_mask}
101 |             # run the session to get predictions for the current feed_dict
102 | pred_ids_result = sess.run([pred_ids], feed_dict)
103 | pred_label_result = convert_id_to_label(pred_ids_result, id2label)
104 | print(pred_label_result)
105 |             # TODO: entity combination strategy
106 | result = strage_combined_link_org_loc(sentence, pred_label_result[0])
107 | print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
108 |
109 | def convert_id_to_label(pred_ids_result, idx2label):
110 | """
111 |     Convert id-based prediction results back to label sequences.
112 | :param pred_ids_result:
113 | :param idx2label:
114 | :return:
115 | """
116 | result = []
117 | for row in range(batch_size):
118 | curr_seq = []
119 | for ids in pred_ids_result[row][0]:
120 | if ids == 0:
121 | break
122 | curr_label = idx2label[ids]
123 | if curr_label in ['[CLS]', '[SEP]']:
124 | continue
125 | curr_seq.append(curr_label)
126 | result.append(curr_seq)
127 | return result
128 |
129 |
130 |
131 | def strage_combined_link_org_loc(tokens, tags):
132 | """
133 |     Combination strategy: collect predicted LOC/PER/ORG entities and print them.
134 | :param pred_label_result:
135 | :param types:
136 | :return:
137 | """
138 | def print_output(data, type):
139 | line = []
140 | line.append(type)
141 | for i in data:
142 | line.append(i.word)
143 | print(', '.join(line))
144 |
145 | params = None
146 | eval = Result(params)
147 | if len(tokens) > len(tags):
148 | tokens = tokens[:len(tags)]
149 | person, loc, org = eval.get_result(tokens, tags)
150 | print_output(loc, 'LOC')
151 | print_output(person, 'PER')
152 | print_output(org, 'ORG')
153 |
154 |
155 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode):
156 | """
157 |     Process one example: convert characters to ids and labels to ids, then pack them into an InputFeatures object.
158 |     :param ex_index: index
159 |     :param example: one example (here, a list of tokens)
160 |     :param label_list: list of labels
161 | :param max_seq_length:
162 | :param tokenizer:
163 | :param mode:
164 | :return:
165 | """
166 | label_map = {}
167 |     # the 1 means label indexing starts from 1
168 | for (i, label) in enumerate(label_list, 1):
169 | label_map[label] = i
170 |     # save the label->index map
171 | if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
172 | with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
173 | pickle.dump(label_map, w)
174 |
175 | tokens = example
176 | # tokens = tokenizer.tokenize(example.text)
177 |     # truncate the sequence
178 | if len(tokens) >= max_seq_length - 1:
179 |         tokens = tokens[0:(max_seq_length - 2)] # -2 because the sequence needs a sentence-start and a sentence-end marker
180 | ntokens = []
181 | segment_ids = []
182 | label_ids = []
183 |     ntokens.append("[CLS]") # add the [CLS] marker at the start of the sentence
184 | segment_ids.append(0)
185 | # append("O") or append("[CLS]") not sure!
186 |     label_ids.append(label_map["[CLS]"]) # using O or [CLS] makes no real difference, though O would reduce the number of labels; marking the sentence start and end with their own tags also works fine
187 | for i, token in enumerate(tokens):
188 | ntokens.append(token)
189 | segment_ids.append(0)
190 | label_ids.append(0)
191 |     ntokens.append("[SEP]") # add the [SEP] marker at the end of the sentence
192 | segment_ids.append(0)
193 | # append("O") or append("[SEP]") not sure!
194 | label_ids.append(label_map["[SEP]"])
195 |     input_ids = tokenizer.convert_tokens_to_ids(ntokens) # convert the characters in ntokens to ids
196 | input_mask = [1] * len(input_ids)
197 |
198 |     # pad up to max_seq_length
199 | while len(input_ids) < max_seq_length:
200 | input_ids.append(0)
201 | input_mask.append(0)
202 | segment_ids.append(0)
203 |         # we are not concerned about the label here
204 | label_ids.append(0)
205 | ntokens.append("**NULL**")
206 | # label_mask.append(0)
207 | # print(len(input_ids))
208 | assert len(input_ids) == max_seq_length
209 | assert len(input_mask) == max_seq_length
210 | assert len(segment_ids) == max_seq_length
211 | assert len(label_ids) == max_seq_length
212 | # assert len(label_mask) == max_seq_length
213 |
214 |     # pack the features into an InputFeatures object
215 | feature = InputFeatures(
216 | input_ids=input_ids,
217 | input_mask=input_mask,
218 | segment_ids=segment_ids,
219 | label_ids=label_ids,
220 | # label_mask = label_mask
221 | )
222 | return feature
223 |
224 |
225 | class Pair(object):
226 | def __init__(self, word, start, end, type, merge=False):
227 | self.__word = word
228 | self.__start = start
229 | self.__end = end
230 | self.__merge = merge
231 | self.__types = type
232 |
233 | @property
234 | def start(self):
235 | return self.__start
236 | @property
237 | def end(self):
238 | return self.__end
239 | @property
240 | def merge(self):
241 | return self.__merge
242 | @property
243 | def word(self):
244 | return self.__word
245 |
246 | @property
247 | def types(self):
248 | return self.__types
249 | @word.setter
250 | def word(self, word):
251 | self.__word = word
252 | @start.setter
253 | def start(self, start):
254 | self.__start = start
255 | @end.setter
256 | def end(self, end):
257 | self.__end = end
258 | @merge.setter
259 | def merge(self, merge):
260 | self.__merge = merge
261 |
262 | @types.setter
263 | def types(self, type):
264 | self.__types = type
265 |
266 | def __str__(self) -> str:
267 | line = []
268 | line.append('entity:{}'.format(self.__word))
269 | line.append('start:{}'.format(self.__start))
270 | line.append('end:{}'.format(self.__end))
271 | line.append('merge:{}'.format(self.__merge))
272 | line.append('types:{}'.format(self.__types))
273 | return '\t'.join(line)
274 |
275 |
276 | class Result(object):
277 | def __init__(self, config):
278 | self.config = config
279 | self.person = []
280 | self.loc = []
281 | self.org = []
282 | self.others = []
283 | def get_result(self, tokens, tags, config=None):
284 |         # first extract the tagged entities
285 | self.result_to_json(tokens, tags)
286 | return self.person, self.loc, self.org
287 |
288 | def result_to_json(self, string, tags):
289 | """
290 |         Combine the model's tag sequence with the input sequence and convert them into entities.
291 |         :param string: input sequence
292 |         :param tags: predicted tags
293 | :return:
294 | """
295 | item = {"entities": []}
296 | entity_name = ""
297 | entity_start = 0
298 | idx = 0
299 | last_tag = ''
300 |
301 | for char, tag in zip(string, tags):
302 | if tag[0] == "S":
303 | self.append(char, idx, idx+1, tag[2:])
304 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]})
305 | elif tag[0] == "B":
306 | if entity_name != '':
307 | self.append(entity_name, entity_start, idx, last_tag[2:])
308 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
309 | entity_name = ""
310 | entity_name += char
311 | entity_start = idx
312 | elif tag[0] == "I":
313 | entity_name += char
314 | elif tag[0] == "O":
315 | if entity_name != '':
316 | self.append(entity_name, entity_start, idx, last_tag[2:])
317 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
318 | entity_name = ""
319 | else:
320 | entity_name = ""
321 | entity_start = idx
322 | idx += 1
323 | last_tag = tag
324 | if entity_name != '':
325 | self.append(entity_name, entity_start, idx, last_tag[2:])
326 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
327 | return item
328 |
329 | def append(self, word, start, end, tag):
330 | if tag == 'LOC':
331 | self.loc.append(Pair(word, start, end, 'LOC'))
332 | elif tag == 'PER':
333 | self.person.append(Pair(word, start, end, 'PER'))
334 | elif tag == 'ORG':
335 | self.org.append(Pair(word, start, end, 'ORG'))
336 | else:
337 | self.others.append(Pair(word, start, end, tag))
338 |
339 |
340 | if __name__ == "__main__":
341 | predict_online()
342 |
343 |
--------------------------------------------------------------------------------
/bert_base/train/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter
33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem; # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0; # number of chunks in corpus
36 | my $foundGuessed = 0; # number of identified chunks
37 | my $guessed; # current guessed chunk tag
38 | my $guessedType; # type of current guessed chunk tag
39 | my $i; # miscellaneous counter
40 | my $inCorrect = $false; # currently processed chunk is correct until now
41 | my $lastCorrect = "O"; # previous chunk tag in corpus
42 | my $latex = 0; # generate LaTeX formatted output
43 | my $lastCorrectType = ""; # type of previously identified chunk tag
44 | my $lastGuessed = "O"; # previously identified chunk tag
45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus
46 | my $lastType; # temporary storage for detecting duplicates
47 | my $line; # line
48 | my $nbrOfFeatures = -1; # number of features per line
49 | my $precision = 0.0; # precision score
50 | my $oTag = "O"; # outside tag, default O
51 | my $raw = 0; # raw input: add B to every token
52 | my $recall = 0.0; # recall score
53 | my $tokenCounter = 0; # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = (); # number of correctly identified chunks per type
56 | my %foundCorrect = (); # number of chunks in corpus per type
57 | my %foundGuessed = (); # number of identified chunks per type
58 |
59 | my @features; # features on line
60 | my @sortedTypes; # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 | elsif ($ARGV[0] eq "-d") {
67 | shift(@ARGV);
68 | if (not defined $ARGV[0]) {
69 | die "conlleval: -d requires delimiter character";
70 | }
71 | $delimiter = shift(@ARGV);
72 | } elsif ($ARGV[0] eq "-o") {
73 | shift(@ARGV);
74 | if (not defined $ARGV[0]) {
75 | die "conlleval: -o requires delimiter character";
76 | }
77 | $oTag = shift(@ARGV);
78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 | chomp($line = $_);
84 | @features = split(/$delimiter/,$line);
85 | # @features = split(/\t/,$line);
86 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
87 | elsif ($nbrOfFeatures != $#features and @features != 0) {
88 | printf STDERR "unexpected number of features: %d (%d)\n",
89 | $#features+1,$nbrOfFeatures+1;
90 | exit(1);
91 | }
92 | if (@features == 0 or
93 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
94 | if (@features < 2) {
95 | printf STDERR "feature length is %d. \n", @features;
96 | die "conlleval: unexpected number of features in line $line\n";
97 | }
98 | if ($raw) {
99 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
100 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
101 | if ($features[$#features] ne "O") {
102 | $features[$#features] = "B-$features[$#features]";
103 | }
104 | if ($features[$#features-1] ne "O") {
105 | $features[$#features-1] = "B-$features[$#features-1]";
106 | }
107 | }
108 | # 20040126 ET code which allows hyphens in the types
109 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
110 | $guessed = $1;
111 | $guessedType = $2;
112 | } else {
113 | $guessed = $features[$#features];
114 | $guessedType = "";
115 | }
116 | pop(@features);
117 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
118 | $correct = $1;
119 | $correctType = $2;
120 | } else {
121 | $correct = $features[$#features];
122 | $correctType = "";
123 | }
124 | pop(@features);
125 | # ($guessed,$guessedType) = split(/-/,pop(@features));
126 | # ($correct,$correctType) = split(/-/,pop(@features));
127 | $guessedType = $guessedType ? $guessedType : "";
128 | $correctType = $correctType ? $correctType : "";
129 | $firstItem = shift(@features);
130 |
131 | # 1999-06-26 sentence breaks should always be counted as out of chunk
132 | if ( $firstItem eq $boundary ) { $guessed = "O"; }
133 |
134 | if ($inCorrect) {
135 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
136 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
137 | $lastGuessedType eq $lastCorrectType) {
138 | $inCorrect=$false;
139 | $correctChunk++;
140 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
141 | $correctChunk{$lastCorrectType}+1 : 1;
142 | } elsif (
143 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
144 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
145 | $guessedType ne $correctType ) {
146 | $inCorrect=$false;
147 | }
148 | }
149 |
150 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
151 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
152 | $guessedType eq $correctType) { $inCorrect = $true; }
153 |
154 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
155 | $foundCorrect++;
156 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
157 | $foundCorrect{$correctType}+1 : 1;
158 | }
159 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
160 | $foundGuessed++;
161 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
162 | $foundGuessed{$guessedType}+1 : 1;
163 | }
164 | if ( $firstItem ne $boundary ) {
165 | if ( $correct eq $guessed and $guessedType eq $correctType ) {
166 | $correctTags++;
167 | }
168 | $tokenCounter++;
169 | }
170 |
171 | $lastGuessed = $guessed;
172 | $lastCorrect = $correct;
173 | $lastGuessedType = $guessedType;
174 | $lastCorrectType = $correctType;
175 | }
176 | if ($inCorrect) {
177 | $correctChunk++;
178 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
179 | $correctChunk{$lastCorrectType}+1 : 1;
180 | }
181 |
182 | if (not $latex) {
183 | # compute overall precision, recall and FB1 (default values are 0.0)
184 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
185 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
186 | $FB1 = 2*$precision*$recall/($precision+$recall)
187 | if ($precision+$recall > 0);
188 |
189 | # print overall performance
190 | printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
191 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
192 | if ($tokenCounter>0) {
193 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
194 | printf "precision: %6.2f%%; ",$precision;
195 | printf "recall: %6.2f%%; ",$recall;
196 | printf "FB1: %6.2f\n",$FB1;
197 | }
198 | }
199 |
200 | # sort chunk type names
201 | undef($lastType);
202 | @sortedTypes = ();
203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
204 | if (not($lastType) or $lastType ne $i) {
205 | push(@sortedTypes,($i));
206 | }
207 | $lastType = $i;
208 | }
209 | # print performance per chunk type
210 | if (not $latex) {
211 | for $i (@sortedTypes) {
212 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
213 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
214 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
215 | if (not($foundCorrect{$i})) { $recall = 0.0; }
216 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
217 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
218 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
219 | printf "%17s: ",$i;
220 | printf "precision: %6.2f%%; ",$precision;
221 | printf "recall: %6.2f%%; ",$recall;
222 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
223 | }
224 | } else {
225 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
226 | for $i (@sortedTypes) {
227 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
228 | if (not($foundGuessed{$i})) { $precision = 0.0; }
229 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
230 | if (not($foundCorrect{$i})) { $recall = 0.0; }
231 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
232 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
233 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
234 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
235 | $i,$precision,$recall,$FB1;
236 | }
237 | print "\\hline\n";
238 | $precision = 0.0;
239 | $recall = 0;
240 | $FB1 = 0.0;
241 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
242 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
243 | $FB1 = 2*$precision*$recall/($precision+$recall)
244 | if ($precision+$recall > 0);
245 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
246 | $precision,$recall,$FB1;
247 | }
248 |
249 | exit 0;
250 |
251 | # endOfChunk: checks if a chunk ended between the previous and current word
252 | # arguments: previous and current chunk tags, previous and current types
253 | # note: this code is capable of handling other chunk representations
254 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
255 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
256 |
257 | sub endOfChunk {
258 | my $prevTag = shift(@_);
259 | my $tag = shift(@_);
260 | my $prevType = shift(@_);
261 | my $type = shift(@_);
262 | my $chunkEnd = $false;
263 |
264 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
265 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
266 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
267 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
268 |
269 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
270 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
271 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
272 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
273 |
274 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
275 | $chunkEnd = $true;
276 | }
277 |
278 | # corrected 1998-12-22: these chunks are assumed to have length 1
279 | if ( $prevTag eq "]" ) { $chunkEnd = $true; }
280 | if ( $prevTag eq "[" ) { $chunkEnd = $true; }
281 |
282 | return($chunkEnd);
283 | }
284 |
285 | # startOfChunk: checks if a chunk started between the previous and current word
286 | # arguments: previous and current chunk tags, previous and current types
287 | # note: this code is capable of handling other chunk representations
288 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
289 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
290 |
291 | sub startOfChunk {
292 | my $prevTag = shift(@_);
293 | my $tag = shift(@_);
294 | my $prevType = shift(@_);
295 | my $type = shift(@_);
296 | my $chunkStart = $false;
297 |
298 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
299 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
300 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
301 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
302 |
303 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
304 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
305 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
306 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
307 |
308 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
309 | $chunkStart = $true;
310 | }
311 |
312 | # corrected 1998-12-22: these chunks are assumed to have length 1
313 | if ( $tag eq "[" ) { $chunkStart = $true; }
314 | if ( $tag eq "]" ) { $chunkStart = $true; }
315 |
316 | return($chunkStart);
317 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BERT-BiLSTM-CRF-NER
2 | Tensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuning
3 |
4 | TensorFlow code for Chinese named entity recognition, using Google's pre-trained BERT model with a BiLSTM-CRF layer.
5 |
6 | For the Chinese documentation, see https://blog.csdn.net/macanv/article/details/85684284 . If this project helps you, please give it a star, thanks!
7 |
8 | Welcome to star this repository!
9 |
10 | The Chinese training data ($PATH/NERdata/) comes from: https://github.com/zjy-ucas/ChineseNER
11 |
12 | The CoNLL-2003 data ($PATH/NERdata/ori/) comes from: https://github.com/kyzhouhzau/BERT-NER
13 |
14 | The evaluation code comes from: https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
15 |
16 |
17 | This project implements NER based on Google's BERT code and a BiLSTM-CRF network!
18 | It is mainly geared toward Chinese data, but other languages only require minor code changes.
19 |
20 | THIS PROJECT ONLY SUPPORTS Python 3.
21 | ###################################################################
22 | ## Download project and install
23 | You can install this project by:
24 | ```
25 | pip install bert-base==0.0.9 -i https://pypi.python.org/simple
26 | ```
27 |
28 | OR
29 | ```angular2html
30 | git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER
31 | cd BERT-BiLSTM-CRF-NER/
32 | python3 setup.py install
33 | ```
34 |
35 | If you do not want to install the package, just clone this project and reference the files directly to train the model or start the service.
36 |
37 | ## UPDATE:
38 | - 2020.2.6: add a simple Flask NER service
39 | - 2019.2.25: fix some bugs in the NER service
40 | - 2019.2.19: add a text classification service
41 | - fix the missing loss error
42 | - add a label_list parameter to the training process, so you can use -label_list xxx to specify labels during training.
43 |
44 |
45 | ## Train model:
46 | You can use -help to view the relevant parameters for training the named entity recognition model; data_dir, bert_config_file, output_dir, init_checkpoint, and vocab_file must be specified.
47 | ```angular2html
48 | bert-base-ner-train -help
49 | ```
50 | 
51 |
52 |
53 | train/dev/test dataset is like this:
54 | ```
55 | 海 O
56 | 钓 O
57 | 比 O
58 | 赛 O
59 | 地 O
60 | 点 O
61 | 在 O
62 | 厦 B-LOC
63 | 门 I-LOC
64 | 与 O
65 | 金 B-LOC
66 | 门 I-LOC
67 | 之 O
68 | 间 O
69 | 的 O
70 | 海 O
71 | 域 O
72 | 。 O
73 | ```
74 | Each line contains a token followed by its label, and sentences are separated by blank lines. The maximum sentence length is controlled by the max_seq_length parameter.
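
A minimal sketch of reading this format into sentences of (token, label) pairs (the file name train.txt is only an example):

```
def read_conll(path):
    """Read a token-per-line file; blank lines separate sentences."""
    sentences, current = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # blank line -> sentence boundary
                if current:
                    sentences.append(current)
                    current = []
                continue
            parts = line.split()
            current.append((parts[0], parts[-1]))  # token, label
    if current:
        sentences.append(current)
    return sentences

# sentences = read_conll("train.txt")
# e.g. sentences[0][:2] == [('海', 'O'), ('钓', 'O')]
```
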
75 | You can get training data from the two repositories listed above.
76 | You can train the NER model by running the command below:
77 | ```angular2html
78 | bert-base-ner-train \
79 | -data_dir {your dataset dir}\
80 | -output_dir {training output dir}\
81 | -init_checkpoint {Google BERT model dir}\
82 | -bert_config_file {bert_config.json under the Google BERT model dir} \
83 | -vocab_file {vocab.txt under the Google BERT model dir}
84 | ```
85 | For example, my init_checkpoint:
86 | ```
87 | init_checkpoint = F:\chinese_L-12_H-768_A-12\bert_model.ckpt
88 | ```
89 | You can specify labels using the -label_list parameter; otherwise the project collects the labels from the training data.
90 | ```angular2html
91 | # using , split
92 | -labels 'B-LOC, I-LOC ...'
93 | OR save label in a file like labels.txt, one line one label
94 | -labels labels.txt
95 | ```
96 |
97 | After training, the NER model will be saved in the {output_dir} you specified on the command line above.
98 | ##### My training environment: Tesla P40, 24G memory
99 |
100 | ## As Service
101 | Much of the server and client code comes from the excellent open source project [bert as service of hanxiao](https://github.com/hanxiao/bert-as-service). If my code violates any license agreement, please let me know and I will correct it right away.
102 | ~~and NER server/client service code can be applied to other tasks with simple modifications, such as text categorization, which I will provide later.~~
103 | This project provides Named Entity Recognition and Text Classification services.
104 | You are welcome to submit requests or to share your models, either on Github or with me directly.
105 |
106 | You can use -help to view the relevant parameters of NER as Service,
107 | where model_dir and bert_model_dir are required:
108 | ```
109 | bert-base-serving-start -help
110 | ```
111 | 
112 |
113 | and then you can start the NER service with the command below:
114 | ```
115 | bert-base-serving-start \
116 | -model_dir C:\workspace\python\BERT_Base\output\ner2 \
117 | -bert_model_dir F:\chinese_L-12_H-768_A-12
118 | -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir
119 | -mode NER
120 | ```
121 | or start the text classification service:
122 | ```
123 | bert-base-serving-start \
124 | -model_dir C:\workspace\python\BERT_Base\output\ner2 \
125 | -bert_model_dir F:\chinese_L-12_H-768_A-12
126 | -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir
127 | -mode CLASS
128 | -max_seq_len 202
129 | ```
130 |
131 | As you can see:
132 | - mode: if mode is NER/CLASS, the Named Entity Recognition / Text Classification service will be started. If it is BERT, the service behaves the same as the [bert as service] project.
133 | - bert_model_dir: the directory of the pre-trained BERT model, which you can download from https://github.com/google-research/bert
134 | - ner_model_dir: your NER model checkpoint dir
135 | - model_pb_dir: the dir where the frozen model is saved; after running the optimize func it will contain a binary file such as ner_model.pb
136 | >You can download my NER model from: https://pan.baidu.com/s/1m9VcueQ5gF-TJc00sFD88w, ex_code: guqq
137 | > Or the text classification model from: https://pan.baidu.com/s/1oFPsOUh1n5AM2HjDIo2XCw, ex_code: bbu8
138 | Put ner_model.pb / classification_model.pb in model_pb_dir and the other files in model_dir. Different models need to be stored separately: for example, put the NER model's label_list.pkl and label2id.pkl in model_dir/ner/ and the text classification files in model_dir/text_classification/. The text classification model can classify 12 categories of Chinese data: '游戏', '娱乐', '财经', '时政', '股票', '教育', '社会', '体育', '家居', '时尚', '房产', '彩票'
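For example, based on the description above, the directories might be laid out like this (paths are only illustrative):
```
model_pb_dir/
├── ner_model.pb                 # frozen NER graph
└── classification_model.pb      # frozen text classification graph
model_dir/
├── ner/                         # label_list.pkl, label2id.pkl and checkpoint files of the NER model
└── text_classification/         # label files and checkpoint files of the classification model
```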
139 |
140 | You should see service startup info like below:
141 | 
142 | 
143 |
144 |
145 | You can test the service with the client code below:
146 | #### 1. NER Client
147 | ```
148 | import time
149 | from bert_base.client import BertClient
150 |
151 | with BertClient(show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
152 | start_t = time.perf_counter()
153 | str = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
154 | rst = bc.encode([str, str])
155 | print('rst:', rst)
156 | print(time.perf_counter() - start_t)
157 | ```
158 | You should see the following after running the code above:
159 | 
160 | If you want to customize the word segmentation method, you only need to make the following simple change in the client-side code.
161 |
162 | ```
163 | rst = bc.encode([list(str), list(str)], is_tokenized=True)
164 | ```
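For example, a small sketch of sending pre-tokenized input, reusing the bc client from the NER example above; my_tokenize is just a placeholder for whatever segmentation you prefer:
```
def my_tokenize(text):
    # character-level split, equivalent to list(text); swap in your own
    # word segmenter (e.g. jieba.lcut) if you want word-level tokens
    return list(text)

tokens = my_tokenize('1月24日,新华社对外发布了中央对雄安新区的指导意见。')
rst = bc.encode([tokens], is_tokenized=True)
```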
165 |
166 | #### 2. Text Classification Client
167 | ```
168 | with BertClient(show_server_config=False, check_version=False, check_length=False, mode='CLASS') as bc:
169 | start_t = time.perf_counter()
170 | str1 = '北京时间2月17日凌晨,第69届柏林国际电影节公布主竞赛单元获奖名单,王景春、咏梅凭借王小帅执导的中国影片《地久天长》连夺最佳男女演员双银熊大奖,这是中国演员首次包揽柏林电影节最佳男女演员奖,为华语影片刷新纪录。与此同时,由青年导演王丽娜执导的影片《第一次的别离》也荣获了本届柏林电影节新生代单元国际评审团最佳影片,可以说,在经历数个获奖小年之后,中国电影在柏林影展再次迎来了高光时刻。'
171 | str2 = '受粤港澳大湾区规划纲要提振,港股周二高开,恒指开盘上涨近百点,涨幅0.33%,报28440.49点,相关概念股亦集体上涨,电子元件、新能源车、保险、基建概念多数上涨。粤泰股份、珠江实业、深天地A等10余股涨停;中兴通讯、丘钛科技、舜宇光学分别高开1.4%、4.3%、1.6%。比亚迪电子、比亚迪股份、光宇国际分别高开1.7%、1.2%、1%。越秀交通基建涨近2%,粤海投资、碧桂园等多股涨超1%。其他方面,日本软银集团股价上涨超0.4%,推动日经225和东证指数齐齐高开,但随后均回吐涨幅转跌东证指数跌0.2%,日经225指数跌0.11%,报21258.4点。受芯片制造商SK海力士股价下跌1.34%拖累,韩国综指下跌0.34%至2203.9点。澳大利亚ASX 200指数早盘上涨0.39%至6089.8点,大多数行业板块均现涨势。在保健品品牌澳佳宝下调下半财年的销售预期后,其股价暴跌超过23%。澳佳宝CEO亨弗里(Richard Henfrey)认为,公司下半年的利润可能会低于上半年,主要是受到销售额疲弱的影响。同时,亚市早盘澳洲联储公布了2月会议纪要,政策委员将继续谨慎评估经济增长前景,因前景充满不确定性的影响,稳定当前的利率水平比贸然调整利率更为合适,而且当前利率水平将有利于趋向通胀目标及改善就业,当前劳动力市场数据表现强势于其他经济数据。另一方面,经济增长前景亦令消费者消费意愿下滑,如果房价出现下滑,消费可能会进一步疲弱。在澳洲联储公布会议纪要后,澳元兑美元下跌近30点,报0.7120 。美元指数在昨日触及96.65附近的低点之后反弹至96.904。日元兑美元报110.56,接近上一交易日的低点。'
172 | str3 = '新京报快讯 据国家市场监管总局消息,针对媒体报道水饺等猪肉制品检出非洲猪瘟病毒核酸阳性问题,市场监管总局、农业农村部已要求企业立即追溯猪肉原料来源并对猪肉制品进行了处置。两部门已派出联合督查组调查核实相关情况,要求猪肉制品生产企业进一步加强对猪肉原料的管控,落实检验检疫票证查验规定,完善非洲猪瘟检测和复核制度,防止染疫猪肉原料进入食品加工环节。市场监管总局、农业农村部等部门要求各地全面落实防控责任,强化防控措施,规范信息报告和发布,对不按要求履行防控责任的企业,一旦发现将严厉查处。专家认为,非洲猪瘟不是人畜共患病,虽然对猪有致命危险,但对人没有任何危害,属于只传猪不传人型病毒,不会影响食品安全。开展猪肉制品病毒核酸检测,可为防控溯源工作提供线索。'
173 | rst = bc.encode([str1, str2, str3])
174 | print('rst:', rst)
175 | print('time used:{}'.format(time.perf_counter() - start_t))
176 | ```
177 | You should see the following after running the code above:
178 | 
179 |
180 | Note that the NER service and the Text Classification service cannot be started together, but you can run the command line twice to start the NER service and the text classification service on different ports.
181 |
182 | ### Flask server service
183 | Sometimes a multi-threaded deep learning model service does not need the C/S setup; you can use a simple HTTP service instead, for example with Flask.
184 | See bert_base/server/simple_flask_http_service.py for a reference implementation of a simple HTTP server service.
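As a rough sketch only (the endpoint name and port below are arbitrary, and the project's own simple_flask_http_service.py may differ), a Flask wrapper around the NER client could look like:
```
from flask import Flask, jsonify, request

from bert_base.client import BertClient

app = Flask(__name__)
# assumes a NER service started with bert-base-serving-start is already running
bc = BertClient(show_server_config=False, check_version=False,
                check_length=False, mode='NER')


@app.route('/ner', methods=['POST'])
def ner():
    # expects JSON like {"texts": ["sentence one", "sentence two"]}
    texts = request.get_json(force=True).get('texts', [])
    rst = bc.encode(texts)
    # rst may need converting to plain Python lists before JSON serialization
    return jsonify({'result': rst})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)
```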
185 |
186 | ## License
187 | MIT.
188 |
189 | # The following tutorial is an old version and will be removed in the future.
190 |
191 | ## How to train
192 | #### 1. Download the BERT Chinese model:
193 | ```
194 | wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
195 | ```
196 | #### 2. Create the output dir
197 | Create the output path in the project path:
198 | ```
199 | mkdir output
200 | ```
201 | #### 3. Train model
202 |
203 | ##### first method
204 | ```
205 | python3 bert_lstm_ner.py \
206 | --task_name="NER" \
207 | --do_train=True \
208 | --do_eval=True \
209 |     --do_predict=True \
210 | --data_dir=NERdata \
211 | --vocab_file=checkpoint/vocab.txt \
212 | --bert_config_file=checkpoint/bert_config.json \
213 | --init_checkpoint=checkpoint/bert_model.ckpt \
214 | --max_seq_length=128 \
215 | --train_batch_size=32 \
216 | --learning_rate=2e-5 \
217 | --num_train_epochs=3.0 \
218 | --output_dir=./output/result_dir/
219 | ```
220 | ##### OR replace the BERT path and project path in bert_lstm_ner.py
221 | ```
222 | if os.name == 'nt': #windows path config
223 | bert_path = '{your BERT model path}'
224 | root_path = '{project path}'
225 | else: # linux path config
226 | bert_path = '{your BERT model path}'
227 | root_path = '{project path}'
228 | ```
229 | Then run:
230 | ```
231 | python3 bert_lstm_ner.py
232 | ```
233 |
234 | ### USING BLSTM-CRF OR ONLY CRF FOR DECODING!
235 | Just edit line 450 of bert_lstm_ner.py and change the crf_only parameter of the add_blstm_crf_layer function to True or False.
236 |
237 | ONLY CRF output layer:
238 | ```
239 | blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=FLAGS.lstm_size, cell_type=FLAGS.cell, num_layers=FLAGS.num_layers,
240 | dropout_rate=FLAGS.droupout_rate, initializers=initializers, num_labels=num_labels,
241 | seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
242 | rst = blstm_crf.add_blstm_crf_layer(crf_only=True)
243 | ```
244 |
245 |
246 | BiLSTM with CRF output layer
247 | ```
248 | blstm_crf = BLSTM_CRF(embedded_chars=embedding, hidden_unit=FLAGS.lstm_size, cell_type=FLAGS.cell, num_layers=FLAGS.num_layers,
249 | dropout_rate=FLAGS.droupout_rate, initializers=initializers, num_labels=num_labels,
250 | seq_length=max_seq_length, labels=labels, lengths=lengths, is_training=is_training)
251 | rst = blstm_crf.add_blstm_crf_layer(crf_only=False)
252 | ```
253 |
254 | ## Result:
255 | All params use their default values.
256 | #### In dev data set:
257 | 
258 |
259 | #### In test data set
260 | 
261 |
262 | #### entity level result:
263 | The last two results are label-level results; the entity-level result is computed in lines 796-798 of the code and is output during the predict process.
264 | My entity-level result:
265 | 
266 | > My model can be downloaded from Baidu cloud:
267 | > Link: https://pan.baidu.com/s/1GfDFleCcTv5393ufBYdgqQ  extraction code: 4cus
268 | NOTE: My model was trained with the crf_only param
269 |
270 | ## ONLINE PREDICT
271 | Once the model has finished training, just run:
272 | ```
273 | python3 terminal_predict.py
274 | ```
275 | 
276 |
277 | ## Using NER as Service
278 |
279 | #### Service
280 | Using NER as Service is simple; you just need to run the Python script below in the project root path:
281 | ```
282 | python3 runs.py \
283 | -mode NER
284 | -bert_model_dir /home/macan/ml/data/chinese_L-12_H-768_A-12 \
285 | -ner_model_dir /home/macan/ml/data/bert_ner \
286 | -model_pd_dir /home/macan/ml/workspace/BERT_Base/output/predict_optimizer \
287 | -num_worker 8
288 | ```
289 |
290 |
291 | You can download my NER model from: https://pan.baidu.com/s/1m9VcueQ5gF-TJc00sFD88w, ex_code: guqq
292 | Put ner_model.pb in model_pd_dir and the other files in ner_model_dir, then run the command above.
293 | 
294 | 
295 |
296 |
297 | #### Client
298 | For client usage you can refer to the client_test.py script:
299 | ```
300 | import time
301 | from client.client import BertClient
302 |
303 | ner_model_dir = 'C:\workspace\python\BERT_Base\output\predict_ner'
304 | with BertClient( ner_model_dir=ner_model_dir, show_server_config=False, check_version=False, check_length=False, mode='NER') as bc:
305 | start_t = time.perf_counter()
306 | str = '1月24日,新华社对外发布了中央对雄安新区的指导意见,洋洋洒洒1.2万多字,17次提到北京,4次提到天津,信息量很大,其实也回答了人们关心的很多问题。'
307 | rst = bc.encode([str])
308 | print('rst:', rst)
309 | print(time.perf_counter() - start_t)
310 | ```
311 | NOTE: for the input format you can also refer to the bert-as-service project.
312 | Contributions of client code in other languages (e.g. Java) are welcome.
313 | ## Using your own data to train
314 | If you want to use your own data to train the NER model, you just need to modify the get_labels func.
315 | ```
316 | def get_labels(self):
317 | return ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
318 | ```
319 | NOTE: "X", “[CLS]”, “[SEP]” These three are necessary, you just replace your data label to this return list.
320 | Or you can use the code below to let the program automatically get the labels from the training data:
321 | ```
322 | def get_labels(self):
323 |         # Getting the labels by reading the train file carries some risk.
324 | if os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
325 | with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'rb') as rf:
326 | self.labels = pickle.load(rf)
327 | else:
328 | if len(self.labels) > 0:
329 | self.labels = self.labels.union(set(["X", "[CLS]", "[SEP]"]))
330 | with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as rf:
331 | pickle.dump(self.labels, rf)
332 | else:
333 | self.labels = ["O", 'B-TIM', 'I-TIM', "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
334 | return self.labels
335 |
336 | ```
337 |
338 |
339 | ## NEW UPDATE
340 | 2019.1.30 Support pip install and command line control
341 |
342 | 2019.1.30 Add Service/Client for NER process
343 |
344 | 2019.1.9: Add code to remove the Adam-related parameters from the model, reducing the model file size from 1.3GB to 400MB.
345 |
346 | 2019.1.3: Add online predict code
347 |
348 |
349 |
350 | ## reference:
351 | + The evaluation code comes from: https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
352 |
353 | + [https://github.com/google-research/bert](https://github.com/google-research/bert)
354 |
355 | + [https://github.com/kyzhouhzau/BERT-NER](https://github.com/kyzhouhzau/BERT-NER)
356 |
357 | + [https://github.com/zjy-ucas/ChineseNER](https://github.com/zjy-ucas/ChineseNER)
358 |
359 | + [https://github.com/hanxiao/bert-as-service](https://github.com/hanxiao/bert-as-service)
360 | > For any problem, please open an issue or email me (ma_cancan@163.com)
361 |
--------------------------------------------------------------------------------
/bert_base/bert/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 |
24 | import tokenization
25 | import tensorflow as tf
26 |
27 | flags = tf.flags
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string("input_file", None,
32 | "Input raw text file (or comma-separated list of files).")
33 |
34 | flags.DEFINE_string(
35 | "output_file", None,
36 | "Output TF example file (or comma-separated list of files).")
37 |
38 | flags.DEFINE_string("vocab_file", None,
39 | "The vocabulary file that the BERT model was trained on.")
40 |
41 | flags.DEFINE_bool(
42 | "do_lower_case", True,
43 | "Whether to lower case the input text. Should be True for uncased "
44 | "models and False for cased models.")
45 |
46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
47 |
48 | flags.DEFINE_integer("max_predictions_per_seq", 20,
49 | "Maximum number of masked LM predictions per sequence.")
50 |
51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
52 |
53 | flags.DEFINE_integer(
54 | "dupe_factor", 10,
55 | "Number of times to duplicate the input data (with different masks).")
56 |
57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
58 |
59 | flags.DEFINE_float(
60 | "short_seq_prob", 0.1,
61 | "Probability of creating sequences which are shorter than the "
62 | "maximum length.")
63 |
64 |
65 | class TrainingInstance(object):
66 | """A single training instance (sentence pair)."""
67 |
68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
69 | is_random_next):
70 | self.tokens = tokens
71 | self.segment_ids = segment_ids
72 | self.is_random_next = is_random_next
73 | self.masked_lm_positions = masked_lm_positions
74 | self.masked_lm_labels = masked_lm_labels
75 |
76 | def __str__(self):
77 | s = ""
78 | s += "tokens: %s\n" % (" ".join(
79 | [tokenization.printable_text(x) for x in self.tokens]))
80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
81 | s += "is_random_next: %s\n" % self.is_random_next
82 | s += "masked_lm_positions: %s\n" % (" ".join(
83 | [str(x) for x in self.masked_lm_positions]))
84 | s += "masked_lm_labels: %s\n" % (" ".join(
85 | [tokenization.printable_text(x) for x in self.masked_lm_labels]))
86 | s += "\n"
87 | return s
88 |
89 | def __repr__(self):
90 | return self.__str__()
91 |
92 |
93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length,
94 | max_predictions_per_seq, output_files):
95 | """Create TF example files from `TrainingInstance`s."""
96 | writers = []
97 | for output_file in output_files:
98 | writers.append(tf.python_io.TFRecordWriter(output_file))
99 |
100 | writer_index = 0
101 |
102 | total_written = 0
103 | for (inst_index, instance) in enumerate(instances):
104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
105 | input_mask = [1] * len(input_ids)
106 | segment_ids = list(instance.segment_ids)
107 | assert len(input_ids) <= max_seq_length
108 |
109 | while len(input_ids) < max_seq_length:
110 | input_ids.append(0)
111 | input_mask.append(0)
112 | segment_ids.append(0)
113 |
114 | assert len(input_ids) == max_seq_length
115 | assert len(input_mask) == max_seq_length
116 | assert len(segment_ids) == max_seq_length
117 |
118 | masked_lm_positions = list(instance.masked_lm_positions)
119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
120 | masked_lm_weights = [1.0] * len(masked_lm_ids)
121 |
122 | while len(masked_lm_positions) < max_predictions_per_seq:
123 | masked_lm_positions.append(0)
124 | masked_lm_ids.append(0)
125 | masked_lm_weights.append(0.0)
126 |
127 | next_sentence_label = 1 if instance.is_random_next else 0
128 |
129 | features = collections.OrderedDict()
130 | features["input_ids"] = create_int_feature(input_ids)
131 | features["input_mask"] = create_int_feature(input_mask)
132 | features["segment_ids"] = create_int_feature(segment_ids)
133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
136 | features["next_sentence_labels"] = create_int_feature([next_sentence_label])
137 |
138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
139 |
140 | writers[writer_index].write(tf_example.SerializeToString())
141 | writer_index = (writer_index + 1) % len(writers)
142 |
143 | total_written += 1
144 |
145 | if inst_index < 20:
146 | tf.logging.info("*** Example ***")
147 | tf.logging.info("tokens: %s" % " ".join(
148 | [tokenization.printable_text(x) for x in instance.tokens]))
149 |
150 | for feature_name in features.keys():
151 | feature = features[feature_name]
152 | values = []
153 | if feature.int64_list.value:
154 | values = feature.int64_list.value
155 | elif feature.float_list.value:
156 | values = feature.float_list.value
157 | tf.logging.info(
158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
159 |
160 | for writer in writers:
161 | writer.close()
162 |
163 | tf.logging.info("Wrote %d total instances", total_written)
164 |
165 |
166 | def create_int_feature(values):
167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
168 | return feature
169 |
170 |
171 | def create_float_feature(values):
172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
173 | return feature
174 |
175 |
176 | def create_training_instances(input_files, tokenizer, max_seq_length,
177 | dupe_factor, short_seq_prob, masked_lm_prob,
178 | max_predictions_per_seq, rng):
179 | """Create `TrainingInstance`s from raw text."""
180 | all_documents = [[]]
181 |
182 | # Input file format:
183 | # (1) One sentence per line. These should ideally be actual sentences, not
184 | # entire paragraphs or arbitrary spans of text. (Because we use the
185 | # sentence boundaries for the "next sentence prediction" task).
186 | # (2) Blank lines between documents. Document boundaries are needed so
187 | # that the "next sentence prediction" task doesn't span between documents.
188 | for input_file in input_files:
189 | with tf.gfile.GFile(input_file, "r") as reader:
190 | while True:
191 | line = tokenization.convert_to_unicode(reader.readline())
192 | if not line:
193 | break
194 | line = line.strip()
195 |
196 | # Empty lines are used as document delimiters
197 | if not line:
198 | all_documents.append([])
199 | tokens = tokenizer.tokenize(line)
200 | if tokens:
201 | all_documents[-1].append(tokens)
202 |
203 | # Remove empty documents
204 | all_documents = [x for x in all_documents if x]
205 | rng.shuffle(all_documents)
206 |
207 | vocab_words = list(tokenizer.vocab.keys())
208 | instances = []
209 | for _ in range(dupe_factor):
210 | for document_index in range(len(all_documents)):
211 | instances.extend(
212 | create_instances_from_document(
213 | all_documents, document_index, max_seq_length, short_seq_prob,
214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
215 |
216 | rng.shuffle(instances)
217 | return instances
218 |
219 |
220 | def create_instances_from_document(
221 | all_documents, document_index, max_seq_length, short_seq_prob,
222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
223 | """Creates `TrainingInstance`s for a single document."""
224 | document = all_documents[document_index]
225 |
226 | # Account for [CLS], [SEP], [SEP]
227 | max_num_tokens = max_seq_length - 3
228 |
229 | # We *usually* want to fill up the entire sequence since we are padding
230 | # to `max_seq_length` anyways, so short sequences are generally wasted
231 | # computation. However, we *sometimes*
232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
233 | # sequences to minimize the mismatch between pre-training and fine-tuning.
234 | # The `target_seq_length` is just a rough target however, whereas
235 | # `max_seq_length` is a hard limit.
236 | target_seq_length = max_num_tokens
237 | if rng.random() < short_seq_prob:
238 | target_seq_length = rng.randint(2, max_num_tokens)
239 |
240 | # We DON'T just concatenate all of the tokens from a document into a long
241 | # sequence and choose an arbitrary split point because this would make the
242 | # next sentence prediction task too easy. Instead, we split the input into
243 | # segments "A" and "B" based on the actual "sentences" provided by the user
244 | # input.
245 | instances = []
246 | current_chunk = []
247 | current_length = 0
248 | i = 0
249 | while i < len(document):
250 | segment = document[i]
251 | current_chunk.append(segment)
252 | current_length += len(segment)
253 | if i == len(document) - 1 or current_length >= target_seq_length:
254 | if current_chunk:
255 | # `a_end` is how many segments from `current_chunk` go into the `A`
256 | # (first) sentence.
257 | a_end = 1
258 | if len(current_chunk) >= 2:
259 | a_end = rng.randint(1, len(current_chunk) - 1)
260 |
261 | tokens_a = []
262 | for j in range(a_end):
263 | tokens_a.extend(current_chunk[j])
264 |
265 | tokens_b = []
266 | # Random next
267 | is_random_next = False
268 | if len(current_chunk) == 1 or rng.random() < 0.5:
269 | is_random_next = True
270 | target_b_length = target_seq_length - len(tokens_a)
271 |
272 | # This should rarely go for more than one iteration for large
273 | # corpora. However, just to be careful, we try to make sure that
274 | # the random document is not the same as the document
275 | # we're processing.
276 | for _ in range(10):
277 | random_document_index = rng.randint(0, len(all_documents) - 1)
278 | if random_document_index != document_index:
279 | break
280 |
281 | random_document = all_documents[random_document_index]
282 | random_start = rng.randint(0, len(random_document) - 1)
283 | for j in range(random_start, len(random_document)):
284 | tokens_b.extend(random_document[j])
285 | if len(tokens_b) >= target_b_length:
286 | break
287 | # We didn't actually use these segments so we "put them back" so
288 | # they don't go to waste.
289 | num_unused_segments = len(current_chunk) - a_end
290 | i -= num_unused_segments
291 | # Actual next
292 | else:
293 | is_random_next = False
294 | for j in range(a_end, len(current_chunk)):
295 | tokens_b.extend(current_chunk[j])
296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
297 |
298 | assert len(tokens_a) >= 1
299 | assert len(tokens_b) >= 1
300 |
301 | tokens = []
302 | segment_ids = []
303 | tokens.append("[CLS]")
304 | segment_ids.append(0)
305 | for token in tokens_a:
306 | tokens.append(token)
307 | segment_ids.append(0)
308 |
309 | tokens.append("[SEP]")
310 | segment_ids.append(0)
311 |
312 | for token in tokens_b:
313 | tokens.append(token)
314 | segment_ids.append(1)
315 | tokens.append("[SEP]")
316 | segment_ids.append(1)
317 |
318 | (tokens, masked_lm_positions,
319 | masked_lm_labels) = create_masked_lm_predictions(
320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
321 | instance = TrainingInstance(
322 | tokens=tokens,
323 | segment_ids=segment_ids,
324 | is_random_next=is_random_next,
325 | masked_lm_positions=masked_lm_positions,
326 | masked_lm_labels=masked_lm_labels)
327 | instances.append(instance)
328 | current_chunk = []
329 | current_length = 0
330 | i += 1
331 |
332 | return instances
333 |
334 |
335 | def create_masked_lm_predictions(tokens, masked_lm_prob,
336 | max_predictions_per_seq, vocab_words, rng):
337 | """Creates the predictions for the masked LM objective."""
338 |
339 | cand_indexes = []
340 | for (i, token) in enumerate(tokens):
341 | if token == "[CLS]" or token == "[SEP]":
342 | continue
343 | cand_indexes.append(i)
344 |
345 | rng.shuffle(cand_indexes)
346 |
347 | output_tokens = list(tokens)
348 |
349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
350 |
351 | num_to_predict = min(max_predictions_per_seq,
352 | max(1, int(round(len(tokens) * masked_lm_prob))))
353 |
354 | masked_lms = []
355 | covered_indexes = set()
356 | for index in cand_indexes:
357 | if len(masked_lms) >= num_to_predict:
358 | break
359 | if index in covered_indexes:
360 | continue
361 | covered_indexes.add(index)
362 |
363 | masked_token = None
364 | # 80% of the time, replace with [MASK]
365 | if rng.random() < 0.8:
366 | masked_token = "[MASK]"
367 | else:
368 | # 10% of the time, keep original
369 | if rng.random() < 0.5:
370 | masked_token = tokens[index]
371 | # 10% of the time, replace with random word
372 | else:
373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
374 |
375 | output_tokens[index] = masked_token
376 |
377 | masked_lms.append(masked_lm(index=index, label=tokens[index]))
378 |
379 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
380 |
381 | masked_lm_positions = []
382 | masked_lm_labels = []
383 | for p in masked_lms:
384 | masked_lm_positions.append(p.index)
385 | masked_lm_labels.append(p.label)
386 |
387 | return (output_tokens, masked_lm_positions, masked_lm_labels)
388 |
389 |
390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
391 | """Truncates a pair of sequences to a maximum sequence length."""
392 | while True:
393 | total_length = len(tokens_a) + len(tokens_b)
394 | if total_length <= max_num_tokens:
395 | break
396 |
397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
398 | assert len(trunc_tokens) >= 1
399 |
400 | # We want to sometimes truncate from the front and sometimes from the
401 | # back to add more randomness and avoid biases.
402 | if rng.random() < 0.5:
403 | del trunc_tokens[0]
404 | else:
405 | trunc_tokens.pop()
406 |
407 |
408 | def main(_):
409 | tf.logging.set_verbosity(tf.logging.INFO)
410 |
411 | tokenizer = tokenization.FullTokenizer(
412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
413 |
414 | input_files = []
415 | for input_pattern in FLAGS.input_file.split(","):
416 | input_files.extend(tf.gfile.Glob(input_pattern))
417 |
418 | tf.logging.info("*** Reading from input files ***")
419 | for input_file in input_files:
420 | tf.logging.info(" %s", input_file)
421 |
422 | rng = random.Random(FLAGS.random_seed)
423 | instances = create_training_instances(
424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
426 | rng)
427 |
428 | output_files = FLAGS.output_file.split(",")
429 | tf.logging.info("*** Writing to output files ***")
430 | for output_file in output_files:
431 | tf.logging.info(" %s", output_file)
432 |
433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
434 | FLAGS.max_predictions_per_seq, output_files)
435 |
436 |
437 | if __name__ == "__main__":
438 | flags.mark_flag_as_required("input_file")
439 | flags.mark_flag_as_required("output_file")
440 | flags.mark_flag_as_required("vocab_file")
441 | tf.app.run()
442 |
--------------------------------------------------------------------------------
/bert_base/server/graph.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import json
3 | import os
4 | from enum import Enum
5 |
6 | from termcolor import colored
7 |
8 | from .helper import import_tf, set_logger
9 |
10 | import sys
11 | sys.path.append('..')
12 | from bert_base.bert import modeling
13 |
14 | __all__ = ['PoolingStrategy', 'optimize_bert_graph', 'optimize_ner_model', 'optimize_class_model']
15 |
16 |
17 | class PoolingStrategy(Enum):
18 | NONE = 0
19 | REDUCE_MAX = 1
20 | REDUCE_MEAN = 2
21 | REDUCE_MEAN_MAX = 3
22 | FIRST_TOKEN = 4 # corresponds to [CLS] for single sequences
23 | LAST_TOKEN = 5 # corresponds to [SEP] for single sequences
24 | CLS_TOKEN = 4 # corresponds to the first token for single seq.
25 | SEP_TOKEN = 5 # corresponds to the last token for single seq.
26 |
27 | def __str__(self):
28 | return self.name
29 |
30 | @staticmethod
31 | def from_string(s):
32 | try:
33 | return PoolingStrategy[s]
34 | except KeyError:
35 | raise ValueError()
36 |
37 |
38 | def optimize_bert_graph(args, logger=None):
39 | if not logger:
40 | logger = set_logger(colored('GRAPHOPT', 'cyan'), args.verbose)
41 | try:
42 | if not os.path.exists(args.model_pb_dir):
43 | os.mkdir(args.model_pb_dir)
44 | pb_file = os.path.join(args.model_pb_dir, 'bert_model.pb')
45 | if os.path.exists(pb_file):
46 | return pb_file
47 | # we don't need GPU for optimizing the graph
48 | tf = import_tf(verbose=args.verbose)
49 | from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
50 |
51 | config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)
52 |
53 | config_fp = os.path.join(args.model_dir, args.config_name)
54 | init_checkpoint = os.path.join(args.tuned_model_dir or args.bert_model_dir, args.ckpt_name)
55 | if args.fp16:
56 | logger.warning('fp16 is turned on! '
57 | 'Note that not all CPU GPU support fast fp16 instructions, '
58 | 'worst case you will have degraded performance!')
59 | logger.info('model config: %s' % config_fp)
60 | logger.info(
61 | 'checkpoint%s: %s' % (
62 | ' (override by the fine-tuned model)' if args.tuned_model_dir else '', init_checkpoint))
63 | with tf.gfile.GFile(config_fp, 'r') as f:
64 | bert_config = modeling.BertConfig.from_dict(json.load(f))
65 |
66 | logger.info('build graph...')
67 | # input placeholders, not sure if they are friendly to XLA
68 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
69 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
70 | input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_type_ids')
71 |
72 | jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress
73 |
74 | with jit_scope():
75 | input_tensors = [input_ids, input_mask, input_type_ids]
76 |
77 | model = modeling.BertModel(
78 | config=bert_config,
79 | is_training=False,
80 | input_ids=input_ids,
81 | input_mask=input_mask,
82 | token_type_ids=input_type_ids,
83 | use_one_hot_embeddings=False)
84 |
85 | tvars = tf.trainable_variables()
86 |
87 | (assignment_map, initialized_variable_names
88 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
89 |
90 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
91 |
92 | minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30
93 | mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
94 | masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1)
95 | masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
96 | tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
97 |
98 | with tf.variable_scope("pooling"):
99 | if len(args.pooling_layer) == 1:
100 | encoder_layer = model.all_encoder_layers[args.pooling_layer[0]]
101 | else:
102 | all_layers = [model.all_encoder_layers[l] for l in args.pooling_layer]
103 | encoder_layer = tf.concat(all_layers, -1)
104 |
105 | input_mask = tf.cast(input_mask, tf.float32)
106 | if args.pooling_strategy == PoolingStrategy.REDUCE_MEAN:
107 | pooled = masked_reduce_mean(encoder_layer, input_mask)
108 | elif args.pooling_strategy == PoolingStrategy.REDUCE_MAX:
109 | pooled = masked_reduce_max(encoder_layer, input_mask)
110 | elif args.pooling_strategy == PoolingStrategy.REDUCE_MEAN_MAX:
111 | pooled = tf.concat([masked_reduce_mean(encoder_layer, input_mask),
112 | masked_reduce_max(encoder_layer, input_mask)], axis=1)
113 | elif args.pooling_strategy == PoolingStrategy.FIRST_TOKEN or \
114 | args.pooling_strategy == PoolingStrategy.CLS_TOKEN:
115 | pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1)
116 | elif args.pooling_strategy == PoolingStrategy.LAST_TOKEN or \
117 | args.pooling_strategy == PoolingStrategy.SEP_TOKEN:
118 | seq_len = tf.cast(tf.reduce_sum(input_mask, axis=1), tf.int32)
119 | rng = tf.range(0, tf.shape(seq_len)[0])
120 | indexes = tf.stack([rng, seq_len - 1], 1)
121 | pooled = tf.gather_nd(encoder_layer, indexes)
122 | elif args.pooling_strategy == PoolingStrategy.NONE:
123 | pooled = mul_mask(encoder_layer, input_mask)
124 | else:
125 | raise NotImplementedError()
126 |
127 | if args.fp16:
128 | pooled = tf.cast(pooled, tf.float16)
129 |
130 | pooled = tf.identity(pooled, 'final_encodes')
131 | output_tensors = [pooled]
132 | tmp_g = tf.get_default_graph().as_graph_def()
133 |
134 | with tf.Session(config=config) as sess:
135 | logger.info('load parameters from checkpoint...')
136 |
137 | sess.run(tf.global_variables_initializer())
138 | dtypes = [n.dtype for n in input_tensors]
139 | logger.info('optimize...')
140 | tmp_g = optimize_for_inference(
141 | tmp_g,
142 | [n.name[:-2] for n in input_tensors],
143 | [n.name[:-2] for n in output_tensors],
144 | [dtype.as_datatype_enum for dtype in dtypes],
145 | False)
146 |
147 | logger.info('freeze...')
148 | tmp_g = convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors],
149 | use_fp16=args.fp16)
150 |
151 | logger.info('write graph to a tmp file: %s' % args.model_pb_dir)
152 | with tf.gfile.GFile(pb_file, 'wb') as f:
153 | f.write(tmp_g.SerializeToString())
154 | except Exception:
155 | logger.error('fail to optimize the graph!', exc_info=True)
156 |
157 |
158 | def convert_variables_to_constants(sess,
159 | input_graph_def,
160 | output_node_names,
161 | variable_names_whitelist=None,
162 | variable_names_blacklist=None,
163 | use_fp16=False):
164 | from tensorflow.python.framework.graph_util_impl import extract_sub_graph
165 | from tensorflow.core.framework import graph_pb2
166 | from tensorflow.core.framework import node_def_pb2
167 | from tensorflow.core.framework import attr_value_pb2
168 | from tensorflow.core.framework import types_pb2
169 | from tensorflow.python.framework import tensor_util
170 |
171 | def patch_dtype(input_node, field_name, output_node):
172 | if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT):
173 | output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))
174 |
175 | inference_graph = extract_sub_graph(input_graph_def, output_node_names)
176 |
177 | variable_names = []
178 | variable_dict_names = []
179 | for node in inference_graph.node:
180 | if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
181 | variable_name = node.name
182 | if ((variable_names_whitelist is not None and
183 | variable_name not in variable_names_whitelist) or
184 | (variable_names_blacklist is not None and
185 | variable_name in variable_names_blacklist)):
186 | continue
187 | variable_dict_names.append(variable_name)
188 | if node.op == "VarHandleOp":
189 | variable_names.append(variable_name + "/Read/ReadVariableOp:0")
190 | else:
191 | variable_names.append(variable_name + ":0")
192 | if variable_names:
193 | returned_variables = sess.run(variable_names)
194 | else:
195 | returned_variables = []
196 | found_variables = dict(zip(variable_dict_names, returned_variables))
197 |
198 | output_graph_def = graph_pb2.GraphDef()
199 | how_many_converted = 0
200 | for input_node in inference_graph.node:
201 | output_node = node_def_pb2.NodeDef()
202 | if input_node.name in found_variables:
203 | output_node.op = "Const"
204 | output_node.name = input_node.name
205 | dtype = input_node.attr["dtype"]
206 | data = found_variables[input_node.name]
207 |
208 | if use_fp16 and dtype.type == types_pb2.DT_FLOAT:
209 | output_node.attr["value"].CopyFrom(
210 | attr_value_pb2.AttrValue(
211 | tensor=tensor_util.make_tensor_proto(data.astype('float16'),
212 | dtype=types_pb2.DT_HALF,
213 | shape=data.shape)))
214 | else:
215 | output_node.attr["dtype"].CopyFrom(dtype)
216 | output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
217 | tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type,
218 | shape=data.shape)))
219 | how_many_converted += 1
220 | elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables):
221 | # placeholder nodes
222 | # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"]))
223 | output_node.op = "Identity"
224 | output_node.name = input_node.name
225 | output_node.input.extend([input_node.input[0]])
226 | output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
227 | if "_class" in input_node.attr:
228 | output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
229 | else:
230 | # mostly op nodes
231 | output_node.CopyFrom(input_node)
232 |
233 | patch_dtype(input_node, 'dtype', output_node)
234 | patch_dtype(input_node, 'T', output_node)
235 | patch_dtype(input_node, 'DstT', output_node)
236 | patch_dtype(input_node, 'SrcT', output_node)
237 | patch_dtype(input_node, 'Tparams', output_node)
238 |
239 | if use_fp16 and ('value' in output_node.attr) and (
240 | output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT):
241 | # hard-coded value need to be converted as well
242 | output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
243 | tensor=tensor_util.make_tensor_proto(
244 | output_node.attr['value'].tensor.float_val[0],
245 | dtype=types_pb2.DT_HALF)))
246 |
247 | output_graph_def.node.extend([output_node])
248 |
249 | output_graph_def.library.CopyFrom(inference_graph.library)
250 | return output_graph_def
251 |
252 |
253 | def optimize_ner_model(args, num_labels, logger=None):
254 | """
255 |     Load the Chinese NER model
256 | :param args:
257 | :param num_labels:
258 | :param logger:
259 | :return:
260 | """
261 | if not logger:
262 |         logger = set_logger(colored('NER_MODEL, Loading...', 'cyan'), args.verbose)
263 | try:
264 |         # If the PB file already exists, return its path; otherwise convert the model to a PB file and return the path where it is stored
265 | if args.model_pb_dir is None:
266 |             # get the current working directory
267 | tmp_file = os.path.join(os.getcwd(), 'predict_optimizer')
268 | if not os.path.exists(tmp_file):
269 | os.mkdir(tmp_file)
270 | else:
271 | tmp_file = args.model_pb_dir
272 | pb_file = os.path.join(tmp_file, 'ner_model.pb')
273 | if os.path.exists(pb_file):
274 |             print('pb_file exists', pb_file)
275 | return pb_file
276 |
277 | import tensorflow as tf
278 |
279 | graph = tf.Graph()
280 | with graph.as_default():
281 | with tf.Session() as sess:
282 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
283 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
284 |
285 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json'))
286 | from bert_base.train.models import create_model
287 | (total_loss, logits, trans, pred_ids) = create_model(
288 | bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=None,
289 | labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0, lstm_size=args.lstm_size)
290 | pred_ids = tf.identity(pred_ids, 'pred_ids')
291 | saver = tf.train.Saver()
292 |
293 | with tf.Session() as sess:
294 | sess.run(tf.global_variables_initializer())
295 | saver.restore(sess, tf.train.latest_checkpoint(args.model_dir))
296 | logger.info('freeze...')
297 | from tensorflow.python.framework import graph_util
298 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_ids'])
299 | logger.info('model cut finished !!!')
300 |                 # write the frozen binary model to a file
301 | logger.info('write graph to a tmp file: %s' % pb_file)
302 | with tf.gfile.GFile(pb_file, 'wb') as f:
303 | f.write(tmp_g.SerializeToString())
304 | return pb_file
305 | except Exception as e:
306 | logger.error('fail to optimize the graph! %s' % e, exc_info=True)
307 |
308 |
309 | def optimize_class_model(args, num_labels, logger=None):
310 | """
311 |     Load the Chinese text classification model
312 | :param args:
313 | :param num_labels:
314 | :param logger:
315 | :return:
316 | """
317 | if not logger:
318 |         logger = set_logger(colored('CLASSIFICATION_MODEL, Loading...', 'cyan'), args.verbose)
319 | try:
320 |         # If the PB file already exists, return its path; otherwise convert the model to a PB file and return the path where it is stored
321 | if args.model_pb_dir is None:
322 |             # get the current working directory
323 | tmp_file = os.path.join(os.getcwd(), 'predict_optimizer')
324 | if not os.path.exists(tmp_file):
325 | os.mkdir(tmp_file)
326 | else:
327 | tmp_file = args.model_pb_dir
328 | pb_file = os.path.join(tmp_file, 'classification_model.pb')
329 | if os.path.exists(pb_file):
330 |             print('pb_file exists', pb_file)
331 | return pb_file
332 | import tensorflow as tf
333 |
334 | graph = tf.Graph()
335 | with graph.as_default():
336 | with tf.Session() as sess:
337 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
338 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
339 |
340 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json'))
341 | from bert_base.train.models import create_classification_model
342 |                 # To support multiple inputs, add the segment_ids feature (the input_type_ids feature in the training code).
343 | #loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False,
344 | #input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
345 | segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids')
346 | loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels)
347 | # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids')
348 | # pred_ids = tf.identity(pred_ids, 'pred_ids')
349 | probabilities = tf.identity(probabilities, 'pred_prob')
350 | saver = tf.train.Saver()
351 |
352 | with tf.Session() as sess:
353 | sess.run(tf.global_variables_initializer())
354 | saver.restore(sess, tf.train.latest_checkpoint(args.model_dir))
355 | logger.info('freeze...')
356 | from tensorflow.python.framework import graph_util
357 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
358 | logger.info('predict cut finished !!!')
359 |                 # write the frozen binary model to a file
360 | logger.info('write graph to a tmp file: %s' % pb_file)
361 | with tf.gfile.GFile(pb_file, 'wb') as f:
362 | f.write(tmp_g.SerializeToString())
363 | return pb_file
364 | except Exception as e:
365 | logger.error('fail to optimize the graph! %s' % e, exc_info=True)
366 |
367 |
368 |
369 |
--------------------------------------------------------------------------------