├── qiznlp ├── run │ ├── __init__.py │ ├── run_mch.py │ ├── run_base.py │ ├── run_s2l.py │ ├── run_multi_s2s.py │ └── run_cls.py ├── common │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── DAM │ │ │ ├── __init__.py │ │ │ ├── operations.py │ │ │ └── layers.py │ │ ├── bert │ │ │ ├── requirements.txt │ │ │ ├── chinese_L-12_H-768_A-12 │ │ │ │ └── bert_config.json │ │ │ ├── __init__.py │ │ │ ├── optimization.py │ │ │ ├── LICENSE │ │ │ └── tokenization.py │ │ ├── embedding.py │ │ ├── idcnn.py │ │ ├── bert_model.py │ │ ├── birnn.py │ │ └── rerank.py │ ├── train_helper.py │ └── tfrecord_utils.py ├── deploy │ ├── __init__.py │ ├── web_API.py │ └── example.py ├── model │ ├── __init__.py │ └── mch_model.py └── __init__.py ├── MANIFEST.in ├── run_demo.gif ├── main_class_diagram.png ├── setup.py ├── .gitignore └── README.md /qiznlp/run/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qiznlp/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qiznlp/deploy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qiznlp/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qiznlp/common/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qiznlp/common/modules/DAM/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /run_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qznan/QizNLP/HEAD/run_demo.gif -------------------------------------------------------------------------------- /main_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qznan/QizNLP/HEAD/main_class_diagram.png -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 
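# [Editor's note] pick exactly one of the two TensorFlow lines above, e.g.
# `pip install tensorflow==1.14.0` (CPU) or `pip install tensorflow-gpu==1.14.0` (GPU);
# 1.14 is only a suggested pin that stays inside the `tensorflow>=1.8, <=1.14` range declared in setup.py.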
3 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/chinese_L-12_H-768_A-12/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /qiznlp/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding=utf-8 3 | 4 | import os 5 | import shutil 6 | 7 | def qiznlp_init(): 8 | cwd = os.getcwd() 9 | curr_dir = os.path.dirname(__file__) 10 | 11 | print('copying ...') 12 | 13 | shutil.copytree(f'{curr_dir}/model', f'{cwd}/model') 14 | print('copy model_dir finish') 15 | 16 | shutil.copytree(f'{curr_dir}/run', f'{cwd}/run') 17 | os.remove(f'{cwd}/run/run_base.py') 18 | print('copy run_dir finish') 19 | 20 | shutil.copytree(f'{curr_dir}/deploy', f'{cwd}/deploy') 21 | print('copy deploy_dir finish') 22 | 23 | shutil.copytree(f'{curr_dir}/data', f'{cwd}/data') 24 | print('copy data_dir finish') 25 | 26 | shutil.copytree(f'{curr_dir}/common/modules/bert/chinese_L-12_H-768_A-12', 27 | f'{cwd}/common/modules/bert/chinese_L-12_H-768_A-12') 28 | print('copy bert_model_dir finish') 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from setuptools import setup, find_packages 4 | 5 | with open("README.md", "r") as fh: 6 | long_description = fh.read() 7 | 8 | setup( 9 | name='QizNLP', 10 | version='0.1.4', 11 | author='Qznan', 12 | author_email='summerzynqz@gmail.com', 13 | description='Quick run NLP in many task', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | license='MPLv2.0', 17 | url='https://github.com/Qznan/QizNLP', 18 | packages=find_packages(), 19 | package_data={'qiznlp': 20 | ['data/*.txt', 21 | 'common/modules/bert/chinese_L-12_H-768_A-12/*.json', 22 | 
'common/modules/bert/chinese_L-12_H-768_A-12/*.txt', 23 | ], 24 | }, 25 | install_requires=[ 26 | 'jieba', 27 | 'tensorflow>=1.8, <=1.14' 28 | ], 29 | python_requires='>=3.6', 30 | entry_points={ 31 | 'console_scripts': [ 32 | 'qiznlp_init=qiznlp:qiznlp_init', 33 | ] 34 | }, 35 | keywords='NLP Classification Match Sequence_Label Senquence_to_Senquence Neural_Network', 36 | classifiers=[ 37 | 'Development Status :: 2 - Pre-Alpha', 38 | 'Environment :: Console', 39 | 'Environment :: MacOS X', 40 | 'Intended Audience :: Developers', 41 | 'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)', 42 | 'Programming Language :: Python :: 3', 43 | 'Programming Language :: Python :: 3.6', 44 | 'Programming Language :: Python :: 3.7', 45 | 'Programming Language :: Python :: 3.8', 46 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 47 | ], 48 | ) 49 | -------------------------------------------------------------------------------- /qiznlp/deploy/web_API.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os, time, json, sys 4 | from run.run_cls import Run_Model_Cls 5 | import tornado.ioloop 6 | import tornado.web 7 | from urllib import parse # parse.quote 文本->url编码 ' '->'%20' parse.unquote url编码->文本 '%20'->' ' 8 | 9 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用CPU设为'-1' 11 | 12 | rm_cls = Run_Model_Cls('trans_mhattnpool') 13 | rm_cls.restore('cls_ckpt_toutiao1') # restore for infer 14 | 15 | 16 | def web_predict(sent, need_cut=True): 17 | time0 = time.time() 18 | pred = rm_cls.predict([sent], need_cut=need_cut)[0] 19 | print('elapsed:', time.time() - time0) 20 | return pred 21 | 22 | 23 | class MainHandler(tornado.web.RequestHandler): 24 | def get(self): 25 | sent = self.get_argument('sent') 26 | sent = parse.unquote(sent) # 你好%20嗯嗯 -> 你好 嗯嗯 27 | 28 | print('sent:', sent) 29 | ret = web_predict(sent) 30 | ret = {'result': ret} 31 | ret_str = json.dumps(ret, ensure_ascii=False) 32 | print(ret_str) 33 | self.write(ret_str) 34 | sys.stdout.flush() 35 | 36 | def post(self): 37 | body_data = self.request.body 38 | if isinstance(body_data, bytes): 39 | body_data = body_data.decode('utf-8') 40 | args_data = json.loads(body_data) 41 | sent = args_data.get('sent', None) 42 | print('sent:', sent) 43 | if sent is None: 44 | return 45 | ret = web_predict(sent) 46 | ret = {'result': ret} 47 | ret_str = json.dumps(ret, ensure_ascii=False) 48 | print(ret_str) 49 | self.write(ret_str) 50 | sys.stdout.flush() 51 | 52 | 53 | def make_app(): 54 | return tornado.web.Application([ 55 | (r"/QizNLP/predict", MainHandler), 56 | ]) 57 | 58 | 59 | if __name__ == '__main__': 60 | DEFAULT_PORT = 8090 61 | app = make_app() 62 | app.listen(DEFAULT_PORT) 63 | tornado.ioloop.IOLoop.current().start() 64 | 65 | # test with follow: 66 | # curl localhost:8090/QizNLP/predict?sent=去日本的邮轮游需要5万的资产证明吗? 
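# A matching POST example for the handler above (editor's sketch; assumes the third-party
# `requests` package is installed and this server is running on DEFAULT_PORT=8090):
# import requests
# resp = requests.post('http://localhost:8090/QizNLP/predict',
#                      json={'sent': '去日本的邮轮游需要5万的资产证明吗?'})
# print(resp.json()['result'])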
67 | -------------------------------------------------------------------------------- /qiznlp/common/modules/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from .common_layers import * 4 | 5 | 6 | def embedding(ids, vocab_size, embedding_size, name='embedding', reuse=tf.AUTO_REUSE, pad_id=0, scale_sqrt_depth=True, 7 | pretrain_embedding=None, pretrain_trainable=True, word_dropout_rate=0.): 8 | """ embedding """ 9 | # ids 3-D Tensor [batch, length, 1] 10 | if pretrain_embedding is None: 11 | with tf.variable_scope(name, reuse=reuse): 12 | var = tf.get_variable('weights', [vocab_size, embedding_size], # [vocab,embed] 13 | initializer=tf.random_normal_initializer(0.0, embedding_size ** -0.5)) 14 | else: 15 | with tf.variable_scope(name, reuse=reuse): 16 | var = tf.get_variable('weights', [vocab_size, embedding_size], # [vocab,embed] 17 | trainable=pretrain_trainable, 18 | initializer=tf.constant_initializer(pretrain_embedding, dtype=tf.float32)) 19 | 20 | # word level drop out 21 | if word_dropout_rate: 22 | ids = dropout_no_scaling(ids, 1.0 - word_dropout_rate) # 随机将部分id变为0,相当于将单词变为pad 23 | 24 | # lookup table 25 | embedding = tf.gather(var, ids) # [batch,length,1,hidden] 26 | embedding = tf.squeeze(embedding, axis=-2) # [batch,length,hidden] 27 | if scale_sqrt_depth: 28 | embedding *= embedding_size ** 0.5 29 | embedding = embedding * tf.to_float(tf.not_equal(ids, pad_id)) # 将pad(id=0)的emb变为[0,0,...] 30 | return embedding, var 31 | 32 | 33 | def proj_logits(outputs, hidden_size, logit_size, name='proj_logits', reuse=tf.AUTO_REUSE): 34 | """ if name = 'embedding' 复用embed矩阵 35 | outputs [batch, length, hidden] or [batch, hidden] 36 | """ 37 | 38 | with tf.variable_scope(name, reuse=reuse): 39 | var = tf.get_variable('weights', [logit_size, hidden_size], # [vocab,hidden] 40 | initializer=tf.random_normal_initializer(0.0, hidden_size ** -0.5)) 41 | 42 | outputs_shape = shape_list(outputs) # [batch, length, hidden] 43 | outputs = tf.reshape(outputs, [-1, outputs_shape[-1]]) # [batch*length,hidden] 44 | logits = tf.matmul(outputs, var, transpose_b=True) # x,h * h,l -> x,l 45 | logits = tf.reshape(logits, outputs_shape[:-1] + [logit_size]) # [batch,length,vocab] 46 | 47 | return logits 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.swp 6 | 7 | # IDE 8 | .idea/ 9 | 10 | # mac 11 | .DS_Store/ 12 | 13 | # C extensions 14 | *.so 15 | 16 | # 一些辅助脚本 17 | scripts/ 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ -------------------------------------------------------------------------------- /qiznlp/deploy/example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import jieba 4 | import numpy as np 5 | import tensorflow as tf 6 | import os, time, re, sys 7 | 8 | """ 9 | deploy cls_model example(for toutiao dataset) 10 | """ 11 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 12 | sys.path.append(curr_dir) 13 | sys.path.append(curr_dir + '/..') 14 | 15 | import qiznlp.common.utils as utils 16 | from qiznlp.model.cls_model import Model 17 | 18 | 19 | class Deplpy_CLS_Model(): 20 | def __init__(self, ckpt_name=None, pbmodel_dir=None): 21 | assert ckpt_name or pbmodel_dir, 'ues at least one way' 22 | self.graph = tf.Graph() 23 | self.config = tf.ConfigProto(allow_soft_placement=True, 24 | gpu_options=tf.GPUOptions(allow_growth=True), 25 | ) 26 | self.sess = tf.Session(config=self.config, graph=self.graph) 27 | 28 | self.token2id_dct = { 29 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_word2id.dct', use_line_no=True), 30 | 'label2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_label2id.dct', use_line_no=True), 31 | } 32 | self.jieba = jieba.Tokenizer() 33 | self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t)) 34 | self.cut = lambda t: ' '.join(self.tokenize(t)) 35 | if ckpt_name: 36 | self.load_from_ckpt_meta(ckpt_name) 37 | else: 38 | 
self.load_from_pbmodel(pbmodel_dir) 39 | 40 | self.id2label = self.token2id_dct['label2id'].get_reverse() 41 | 42 | def load_from_ckpt_meta(self, ckpt_name): 43 | self.model, self.saver = Model.from_ckpt_meta(ckpt_name, self.sess, self.graph) 44 | 45 | def load_from_pbmodel(self, pbmodel_dir): 46 | self.model = Model.from_pbmodel(pbmodel_dir, self.sess) 47 | 48 | def predict(self, sent, need_cut=True): 49 | if need_cut: 50 | sent = self.cut(sent) 51 | feed_dict = self.model.create_feed_dict_from_raw([sent], [], self.token2id_dct) 52 | prob = self.sess.run(self.model.y_prob, feed_dict)[0] 53 | pred = np.argmax(prob) # [batch] 54 | return self.id2label[pred] 55 | 56 | 57 | if __name__ == '__main__': 58 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用CPU设为'-1' 59 | 60 | # option 1. from ckpt 61 | # assume ckpt file is 'cls_ckpt_toutiao1/trans_mhattnpool-7-0.58-0.999.ckpt-854' 62 | dcm = Deplpy_CLS_Model(ckpt_name=f'{curr_dir}/../run/cls_ckpt_toutiao1/trans_mhattnpool-7-0.58-0.999.ckpt-854') 63 | print(dcm.predict('去日本的邮轮游需要5万的资产证明吗?')) 64 | 65 | 66 | # option 2. from pbmodel 67 | # firstly export pbmodel in [cls_pbmodel_dir] with: 68 | """ 69 | rm_cls = Run_Model_Cls('trans_mhattnpool') 70 | rm.export_model(f'{curr_dir}/../run/cls_pbmodel_dir') 71 | """ 72 | # then exec follow: 73 | dcm = Deplpy_CLS_Model(pbmodel_dir=f'{curr_dir}/../run/cls_pbmodel_dir') 74 | print(dcm.predict('去日本的邮轮游需要5万的资产证明吗?')) 75 | -------------------------------------------------------------------------------- /qiznlp/common/modules/idcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import tensorflow as tf 4 | 5 | 6 | class IDCNN(): 7 | def __init__(self, **kwargs): 8 | """ 9 | """ 10 | self.kernel_size = 3 # kernel_size 11 | self.num_filters = 100 # out_channel 12 | self.repeat_times = 4 # 共4个block 13 | self.layers = [ # 每个block结构 14 | { 15 | 'dilation': 1 16 | }, 17 | { 18 | 'dilation': 3 19 | }, 20 | { 21 | 'dilation': 5 22 | }, 23 | ] 24 | 25 | def __call__(self, embedding, name='idcnn', reuse=tf.AUTO_REUSE, **kwargs): 26 | """ 27 | :param idcnn_inputs embedding: [batch, len, embed] 28 | :return: [batch, len, num_filter * repeat_times] 29 | """ 30 | with tf.variable_scope(name, reuse=reuse): 31 | inputs = tf.expand_dims(embedding, 1) # [batch, 1, length, embed] 32 | 33 | # shape of input = [batch, in_height, in_width, in_channels] 34 | layerinput = tf.layers.conv2d(inputs, self.num_filters, [1, self.kernel_size], 35 | strides=[1, 1], padding='SAME', 36 | use_bias=False, 37 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 38 | name='init_layer') 39 | 40 | final_out_from_layers = [] 41 | total_width_for_last_dim = 0 42 | for j in range(self.repeat_times): # 多个block共享参数 43 | for i in range(len(self.layers)): 44 | dilation = self.layers[i]['dilation'] 45 | with tf.variable_scope("atrous-conv-layer-%d" % i, reuse=tf.AUTO_REUSE): 46 | w = tf.get_variable("filterW", shape=[1, self.kernel_size, self.num_filters, self.num_filters], 47 | initializer=tf.contrib.layers.xavier_initializer()) 48 | b = tf.get_variable("filterB", shape=[self.num_filters]) 49 | layeroutput = tf.nn.atrous_conv2d(layerinput, 50 | w, 51 | rate=dilation, 52 | padding="SAME") 53 | 54 | layeroutput = tf.nn.bias_add(layeroutput, b) 55 | layeroutput = tf.nn.relu(layeroutput) 56 | if i == (len(self.layers) - 1): 57 | final_out_from_layers.append(layeroutput) 58 | total_width_for_last_dim += self.num_filters 59 | layerinput = layeroutput 60 | idcnn_output = 
tf.concat(axis=3, values=final_out_from_layers) # 【batch,1,len,hid]】 61 | idcnn_output = tf.squeeze(idcnn_output, 1) 62 | 63 | return idcnn_output # [batch,len,hid] 64 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | import tensorflow as tf 5 | from .bert import modeling 6 | from .bert import tokenization 7 | 8 | class BERT(object): 9 | def __init__(self, 10 | bert_model_dir, 11 | is_training, 12 | input_ids, # [batch,len] 13 | input_mask=None, # [batch,len] 14 | segment_ids=None, # [batch,len] 15 | verbose=False 16 | ): 17 | self.bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_model_dir, 'bert_config.json')) 18 | self.is_training = is_training 19 | self.input_ids = input_ids 20 | self.input_mask = input_mask 21 | self.segment_ids = segment_ids 22 | self.bm = None 23 | self.build_model(self.input_ids, self.input_mask, self.segment_ids) 24 | if self.is_training: 25 | self.restore_vars(os.path.join(bert_model_dir, 'bert_model.ckpt'), verbose=verbose) 26 | 27 | def build_model(self, input_ids, input_mask=None, segment_ids=None): 28 | self.bm = modeling.BertModel( 29 | config=self.bert_config, 30 | is_training=self.is_training, 31 | input_ids=input_ids, 32 | input_mask=input_mask, 33 | token_type_ids=segment_ids, 34 | use_one_hot_embeddings=False 35 | ) 36 | 37 | def get_pooled_output(self): 38 | """ 获取对应的bert的[CLS]输出[batch_size, hidden_size] """ 39 | if not self.bm: 40 | print('bert model has not build yet, please call build_model()') 41 | return self.bm.get_pooled_output() 42 | 43 | def get_sequence_output(self): 44 | """ 获取对应的bert输出[batch_size, seq_length, hidden_size] """ 45 | if not self.bm: 46 | print('bert model has not build yet, please call build_model()') 47 | return self.bm.get_sequence_output() 48 | 49 | def get_seqlen(self): 50 | """ 获取input_ids真实长度 """ 51 | seqlen = tf.cast(tf.reduce_sum(tf.sign(tf.abs(self.input_ids)), axis=1), tf.int32) # [batch] 52 | return seqlen 53 | 54 | def get_nonpad_mask(self): 55 | """ 获取input_ids的mask: 1 for nonpad(!=0) - 0 for pad(=0) """ 56 | mask = tf.sign(tf.abs(self.input_ids)) # [batch,len] 57 | return mask 58 | 59 | def restore_vars(self, init_checkpoint, verbose=False): 60 | """ 加载预训练的BERT模型参数 """ 61 | print(f'<<<<<< restoring bert vars from ckpt: {init_checkpoint}') 62 | tvars = tf.trainable_variables() 63 | assignment_map, initialized_variable_names = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 64 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 65 | if verbose: 66 | print("**** Trainable Variables ****") 67 | for var in tvars: 68 | init_string = "" 69 | if var.name in initialized_variable_names: 70 | init_string = ", *INIT_FROM_CKPT*" 71 | print(f' name = {var.name}, shape = {var.shape}{init_string}') 72 | 73 | def get_tokenizer(vocab_file): 74 | return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) -------------------------------------------------------------------------------- /qiznlp/common/modules/birnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import tensorflow as tf 4 | 5 | 6 | class Bi_RNN(): 7 | def __init__(self, **kwargs): 8 | """ 9 | """ 10 | self.cell_name = kwargs['cell_name'] # 'GRUCell'/'LSTMCell' 11 | self.dropout = kwargs['dropout_rate'] # 
default 0. (0,1) 12 | self.hidden_size = kwargs['hidden_size'] 13 | 14 | Cell = getattr(tf.nn.rnn_cell, self.cell_name) # class GRUCell/LSTMCell 15 | self.fw_cell = Cell(self.hidden_size, name='fw') 16 | self.bw_cell = Cell(self.hidden_size, name='bw') 17 | 18 | if isinstance(self.dropout, tf.Tensor): # 是placeholder则一定要先wrapper了 19 | self.fw_cell = tf.nn.rnn_cell.DropoutWrapper(self.fw_cell, input_keep_prob=1. - self.dropout, output_keep_prob=1.0) 20 | self.bw_cell = tf.nn.rnn_cell.DropoutWrapper(self.bw_cell, input_keep_prob=1. - self.dropout, output_keep_prob=1.0) 21 | else: 22 | if self.dropout: # float not 0. 23 | self.fw_cell = tf.nn.rnn_cell.DropoutWrapper(self.fw_cell, input_keep_prob=1. - self.dropout, output_keep_prob=1.0) 24 | self.bw_cell = tf.nn.rnn_cell.DropoutWrapper(self.bw_cell, input_keep_prob=1. - self.dropout, output_keep_prob=1.0) 25 | 26 | def __call__(self, embedding, seq_length, name=None, reuse=tf.AUTO_REUSE, **kwargs): 27 | """ 28 | :param inputs embedding: [batch, length, embed] 29 | seq_length: [batch] 30 | :return: outputs: [batch, length, 2*hidden], state: [batch, 2*hidden] 31 | """ 32 | scope_name = name + '/' + f'bi_{self.cell_name}_encoder' if name else f'bi_{self.cell_name}_encoder' 33 | with tf.variable_scope(scope_name, reuse=reuse): 34 | (fw_outputs, bw_outputs), (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(self.fw_cell, self.bw_cell, 35 | embedding, 36 | sequence_length=seq_length, 37 | dtype=tf.float32) 38 | if self.cell_name == 'LSTMCell': 39 | # LSTMStateTuple 40 | fw_state = fw_state.h 41 | bw_state = bw_state.h 42 | 43 | outputs = tf.concat([fw_outputs, bw_outputs], axis=-1) # [batch,length,2*hidden] 44 | state = tf.concat([fw_state, bw_state], axis=-1) # [batch,2*hidden] 45 | 46 | return outputs, state 47 | 48 | 49 | class RNN(): 50 | def __init__(self, **kwargs): 51 | """ 52 | """ 53 | self.cell_name = kwargs['cell_name'] # 'GRUCell'/'LSTMCell' 54 | self.dropout = kwargs['dropout_rate'] # default 0. (0,1) 55 | self.hidden_size = kwargs['hidden_size'] 56 | self.name = kwargs['name'] 57 | 58 | Cell = getattr(tf.nn.rnn_cell, self.cell_name) # class GRUCell/LSTMCell 59 | self.cell = Cell(self.hidden_size, name='rnn_cell') 60 | 61 | if isinstance(self.dropout, tf.Tensor): # 是placeholder则一定要先wrapper了 62 | self.cell = tf.nn.rnn_cell.DropoutWrapper(self.cell, input_keep_prob=1. - self.dropout, output_keep_prob=1.0) 63 | else: 64 | if self.dropout: # float not 0. 65 | self.cell = tf.nn.rnn_cell.DropoutWrapper(self.cell, input_keep_prob=1. 
- self.dropout, output_keep_prob=1.0) 66 | 67 | def __call__(self, embedding, seq_length, reuse=tf.AUTO_REUSE, one_step=False, **kwargs): 68 | """ 69 | :param inputs embedding: [batch, length, embed] 70 | :return: outputs: [batch, length, 2*hidden], state: [batch, 2*hidden] 71 | """ 72 | with tf.variable_scope(f'{self.name}_{self.cell_name}', reuse=reuse): 73 | outputs, state = tf.nn.dynamic_rnn(self.cell, 74 | embedding, 75 | sequence_length=seq_length, 76 | dtype=tf.float32) 77 | return outputs, state # [batch,length,hidden] [batch,hidden] 78 | 79 | def one_step(self, one_step_input, state): 80 | with tf.variable_scope(f'{self.name}_{self.cell_name}/rnn', reuse=tf.AUTO_REUSE): 81 | output, state = self.cell(one_step_input, state) 82 | return output, state # [batch,hidden] [batch,hidden] 83 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
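# [Editor's note] worked example of the warmup/decay schedule computed above, with hypothetical
# values init_lr=5e-5, num_train_steps=10000, num_warmup_steps=1000:
#   step 100 (warmup):  lr = 100/1000 * 5e-5 = 5e-6
#   step 5500 (decay):  lr = 5e-5 * (1 - 5500/10000) = 2.25e-5  (linear, since power=1.0)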
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
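# [Editor's note] i.e. for parameters not excluded below, the effective update is
#   param <- param - learning_rate * (next_m / (sqrt(next_v) + epsilon) + weight_decay_rate * param)
# which decays the weights directly (decoupled weight decay, as in AdamW) rather than
# adding an L2 penalty to the loss.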
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /qiznlp/common/train_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os, pickle 4 | import numpy as np 5 | from . import utils 6 | 7 | 8 | def prepare_tfrecord(raw_data_file, 9 | model, 10 | token2id_dct, 11 | tokenize, 12 | preprocess_raw_data_fn, 13 | save_data_prefix, 14 | save_data_dir='../data', 15 | update_txt=False, 16 | update_tfrecord=False, 17 | **kwargs): 18 | # 1. read raw_data_file 19 | # 2. preprocess (seg/split...) -> .train .dev .test 20 | # 3. generate tfrecord 21 | # 4. load tfrecord 22 | if not os.path.exists(save_data_dir): 23 | os.makedirs(save_data_dir) 24 | # define data file name 25 | train_txt_file = f'{save_data_dir}/{save_data_prefix}_train_data.txt' 26 | dev_txt_file = f'{save_data_dir}/{save_data_prefix}_dev_data.txt' 27 | test_txt_file = f'{save_data_dir}/{save_data_prefix}_test_data.txt' 28 | train_tfrecord_file = f'{save_data_dir}/{save_data_prefix}_train_data.tfrecord' 29 | dev_tfrecord_file = f'{save_data_dir}/{save_data_prefix}_dev_data.tfrecord' 30 | test_tfrecord_file = f'{save_data_dir}/{save_data_prefix}_test_data.tfrecord' 31 | 32 | if update_txt: 33 | update_tfrecord = True 34 | if os.path.exists(train_txt_file): os.remove(train_txt_file) 35 | if os.path.exists(dev_txt_file): os.remove(dev_txt_file) 36 | if os.path.exists(test_txt_file): os.remove(test_txt_file) 37 | 38 | if not all([os.path.exists(train_tfrecord_file), os.path.exists(dev_tfrecord_file)]) or update_tfrecord: # 没有或需要更新tfrecord 39 | # 首先检查txt file 40 | if not all([os.path.exists(train_txt_file), os.path.exists(dev_txt_file)]): 41 | train_data, dev_data, test_data = preprocess_raw_data_fn(raw_data_file, tokenize=tokenize, token2id_dct=token2id_dct, **kwargs) 42 | if train_data: 43 | print(f'generating train txt file... filename: {train_txt_file}') 44 | utils.list2file(train_txt_file, train_data) 45 | if dev_data: 46 | print(f'generating dev txt file... filename: {dev_txt_file}') 47 | utils.list2file(dev_txt_file, dev_data) 48 | if test_data: 49 | print(f'generating test txt file... 
filename: {test_txt_file}') 50 | utils.list2file(test_txt_file, test_data) 51 | 52 | if os.path.exists(train_txt_file): 53 | model.generate_tfrecord(train_txt_file, token2id_dct, train_tfrecord_file) 54 | if os.path.exists(dev_txt_file): 55 | model.generate_tfrecord(dev_txt_file, token2id_dct, dev_tfrecord_file) 56 | if os.path.exists(test_txt_file): 57 | model.generate_tfrecord(test_txt_file, token2id_dct, test_tfrecord_file) 58 | 59 | return train_tfrecord_file, dev_tfrecord_file, test_tfrecord_file 60 | 61 | 62 | def prepare_pkldata(raw_data_file, 63 | model, 64 | token2id_dct, 65 | tokenize, 66 | preprocess_raw_data_fn, 67 | save_data_prefix, 68 | save_data_dir='../data', 69 | update_txt=False, 70 | update_pkl=False, 71 | **kwargs): 72 | # 返回data: 各个字段有全量数据 73 | # 同时返回训练及验证的数据数量 74 | if not os.path.exists(save_data_dir): 75 | os.makedirs(save_data_dir) 76 | # define data file name 77 | train_txt_file = f'{save_data_dir}/{save_data_prefix}_train_data.txt' 78 | dev_txt_file = f'{save_data_dir}/{save_data_prefix}_dev_data.txt' 79 | test_txt_file = f'{save_data_dir}/{save_data_prefix}_test_data.txt' 80 | train_pkl_file = f'{save_data_dir}/{save_data_prefix}_train_data.pkl' 81 | dev_pkl_file = f'{save_data_dir}/{save_data_prefix}_dev_data.pkl' 82 | test_pkl_file = f'{save_data_dir}/{save_data_prefix}_test_data.pkl' 83 | 84 | if update_txt: 85 | update_pkl = True 86 | if os.path.exists(train_txt_file): os.remove(train_txt_file) 87 | if os.path.exists(dev_txt_file): os.remove(dev_txt_file) 88 | if os.path.exists(test_txt_file): os.remove(test_txt_file) 89 | 90 | if not all([os.path.exists(train_pkl_file), os.path.exists(dev_pkl_file)]) or update_pkl: # 没有或需要更新pkldata 91 | # 首先检查txt file 92 | if not all([os.path.exists(train_txt_file), os.path.exists(dev_txt_file)]): 93 | train_data, dev_data, test_data = preprocess_raw_data_fn(raw_data_file, tokenize=tokenize, token2id_dct=token2id_dct, **kwargs) 94 | if train_data: 95 | utils.list2file(train_txt_file, train_data) 96 | print(f'generate train txt file ok! {train_txt_file}') 97 | if dev_data: 98 | utils.list2file(dev_txt_file, dev_data) 99 | print(f'generate dev txt file ok! {dev_txt_file}') 100 | if test_data: 101 | utils.list2file(test_txt_file, test_data) 102 | print(f'generate test txt file ok! {test_txt_file}') 103 | 104 | if os.path.exists(train_txt_file): 105 | train_pkldata = model.generate_data(train_txt_file, token2id_dct) 106 | pickle.dump(train_pkldata, open(train_pkl_file, 'wb')) 107 | print(f'generate and save train pkl file ok! {train_pkl_file}') 108 | if os.path.exists(dev_txt_file): 109 | dev_pkldata = model.generate_data(dev_txt_file, token2id_dct) 110 | pickle.dump(dev_pkldata, open(dev_pkl_file, 'wb')) 111 | print(f'generate and save train pkl file ok! {dev_pkl_file}') 112 | if os.path.exists(test_txt_file): 113 | test_pkldata = model.generate_data(test_txt_file, token2id_dct) 114 | pickle.dump(test_pkldata, open(test_pkl_file, 'wb')) 115 | print(f'generate and save train pkl file ok! 
{test_pkl_file}') 116 | 117 | return train_pkl_file, dev_pkl_file, test_pkl_file 118 | 119 | 120 | # 得到当前图中所有变量的名称 121 | # tensor_names = [tensor.name for tensor in graph.as_graph_def().node] # 得到当前图中所有变量的名称 122 | # for tensor_name in tensor_names: 123 | # if not tensor_name.startswith('save') and not tensor_name.startswith('gradients'): 124 | # if 'Adam' not in tensor_name and 'Initializer' not in tensor_name: 125 | # print(tensor_name) 126 | 127 | 128 | def gen_pos_neg_sample(items, sample_idx, num_neg_exm=9, seed=1234): 129 | import random 130 | import copy 131 | total_exm = [] 132 | random.seed(seed) # 采样随机负样本的随机性 133 | cands_ids = list(range(len(items))) 134 | 135 | for i, exm in enumerate(items): 136 | exm.append(1) # add pos label 137 | total_exm.append(exm) # add pos exm 138 | 139 | neg_exm_lst = [] 140 | neg_sample_set = set() 141 | 142 | while len(neg_exm_lst) < min(num_neg_exm, len(items) - 1): # data数量比负采样样本数还少的情况 143 | neg_id = random.choice(cands_ids) 144 | if neg_id == i: # 如果采到自己 145 | continue 146 | if items[neg_id][sample_idx] == exm[sample_idx]: # 如果sample的文本相同 147 | continue 148 | if items[neg_id][sample_idx] in neg_sample_set: # 如果sample已经采样了 149 | continue 150 | neg_sample_set.add(items[neg_id][sample_idx]) 151 | neg_exm = copy.deepcopy(exm) 152 | neg_exm[sample_idx] = items[neg_id][sample_idx] # change to neg 153 | neg_exm[-1] = 0 # add neg label 154 | neg_exm_lst.append(neg_exm) # add neg exm 155 | 156 | total_exm.extend(neg_exm_lst) # add neg exms 157 | 158 | return total_exm 159 | 160 | 161 | def calc_recall(epo_s1, epo_prob, epo_y, topk=5, strip_pad=False): 162 | comb = list(zip(epo_s1, epo_prob, epo_y)) 163 | s1_dct = {} 164 | for s1, prob, y in comb: 165 | if 'ndarray' in str(type(s1)): 166 | s1 = s1.flatten().tolist() 167 | if strip_pad: 168 | while s1[-1] == 0: 169 | s1.pop() 170 | s1 = tuple(s1) 171 | if s1 not in s1_dct: 172 | s1_dct[s1] = [] 173 | s1_dct[s1].append([prob, y]) 174 | for s1 in s1_dct: 175 | s1_dct[s1].sort(key=lambda e: e[0], reverse=True) 176 | total_recall = [] 177 | for s1 in s1_dct: 178 | y_lst = [e[1] for e in s1_dct[s1]] 179 | # print(len(y_lst)) 180 | if len(y_lst) < 2: # 没有正负样本 181 | continue 182 | # print(sum(y_lst)) 183 | if sum(y_lst) != 1: # 真实标签中没有正样本1 或超过1个正样本 184 | continue 185 | recall_item = [sum(y_lst[:topk]) for topk in range(1, topk+1)] 186 | total_recall.append(recall_item) 187 | total_recall = np.mean(total_recall, axis=0).tolist() 188 | return total_recall -------------------------------------------------------------------------------- /qiznlp/common/modules/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import re 5 | import numpy as np 6 | 7 | INF = 1e7 8 | zh_char = re.compile('[\u4e00-\u9fa5]') 9 | en_char = re.compile('[a-zA-Z!~?.\s]') 10 | punc_char = re.compile('[,,.。!!??]]') 11 | 12 | bert_embed = None 13 | w2i = None 14 | word_count_prob = None 15 | 16 | """ bad starts """ 17 | bad_starts = ['我也是', '不是吧', '是不是', '其实', '不能', '但是', '反正', ' 你不是', '哪里有', 18 | '因为', '难道', '不要', '可是我', '那就不用', '你是不是', 19 | '不是', '为什么'] 20 | # bad_starts += ['你不会','你不是'] 21 | 22 | """ bad ends """ 23 | bad_ends = ['我是你的', '是什么?', '是什么', '不喜欢,喜欢'] 24 | bad_ends += ['?', '吗', '么', '什么东西'] 25 | # bad_ends += [''] # 在限定的长度中未解码完成 26 | # bad_ends += ['的','不会'] 27 | 28 | """ bad words """ 29 | bad_words = [',但是', '小三', '看到你的评论我笑了', '大家都是这样', '我们都是好孩子', 30 | '下次见面就是你了', '现在玩游戏都是一样的'] 31 | 32 | """ 33 | re notes 34 | 匹配以abc开头: ^(?=abc).*$ 35 | 
不以abc开头: ^(?!abc).*$ 36 | 以abc结尾: ^.*?(?<=abc)$ 37 | 不以abc结尾: ^.*?(?= 14 \ 88 | or is_bad1(output) \ 89 | or is_conflict(input_sent, output, conflict_items_2): 90 | results[i][1] = -INF 91 | continue 92 | # 如果问题是英文而恢复全是英文则直接分数设为极小值并跳过 93 | if not is_english(input_sent) and is_english(output): 94 | results[i][1] = -INF 95 | continue 96 | # 长度归一化 97 | if is_english(input_sent): # 如果全是英文正常归一化 98 | results[i][1] = results[i][1] / len(results[i][0]) 99 | else: 100 | results[i][1] = results[i][1] / count_length_score(results[i][0]) 101 | 102 | # 不同字比率 103 | # results[i][1] = results[i][1] + distinct_ratio(output) * 0.5 104 | 105 | # # 鼓励不同首字 106 | # head_freqs = {} 107 | # for i, [resp, score] in enumerate(results): 108 | # if score == -INF: 109 | # continue 110 | # head = resp[0] 111 | # results[i][1] = score - (head_freqs.get(head, 0.) * 0.5) 112 | # head_freqs[head] = head_freqs.get(head, 0.) + 1 113 | 114 | # 排序 115 | rarank_results = sorted(results, key=lambda item: item[1], reverse=True) 116 | # print('after sorted:', results) 117 | return rarank_results 118 | 119 | 120 | def is_bad(sent): 121 | """ 处理单词包含/起始/结尾的句式 122 | 处理正则匹配命中的句式 123 | """ 124 | for begin in bad_starts: 125 | if sent.startswith(begin): 126 | return True 127 | for end in bad_ends: 128 | if sent.endswith(end): 129 | return True 130 | for word in bad_words: 131 | if word in sent: 132 | return True 133 | for bad_re_pattern in bad_re_patterns: 134 | if re.search(bad_re_pattern, sent): 135 | return True 136 | return False 137 | 138 | 139 | def is_bad1(sent): 140 | """ 处理逗号分隔的句式 """ 141 | if ',' not in sent: 142 | return False 143 | seg = sent.split(',') 144 | if len(seg) > 2: # 有两个以上逗号 145 | return False 146 | s1, s2 = seg 147 | if not len(s1) or not len(s2): # 以逗号开头和结尾 148 | return True 149 | set1, set2 = set(s1), set(s2) 150 | if len(s1) >= 4 and len(set1) / len(s1) < 0.6: 151 | return True 152 | if len(s2) >= 4 and len(set2) / len(s2) < 0.6: 153 | return True 154 | union_set = set1 & set2 155 | if len(union_set) >= 3: # 相同字数3个及以上 156 | return True 157 | if len(union_set) / len(s1) > 0.5 or len(union_set) / len(s2) > 0.5: # 前(后)半句几乎仅仅是相同的字 158 | return True 159 | if is_conflict(s1, s1, conflict_items_1): 160 | return True 161 | return False 162 | 163 | 164 | def is_conflict(s1, s2, conflict_items): 165 | """ 两句话前后出现互相矛盾 """ 166 | for i1, i2 in conflict_items: 167 | if not i1 and i2: # ('','问') 168 | if i2 not in s1 and i2 in s2: 169 | return True 170 | elif i1 and not i2: # ('好','') 171 | if i1 in s1 and i1 not in s2: 172 | return True 173 | elif i1 in i2: # 包含情况,长的是i2 ('喜欢','不喜欢') 174 | if i1 in s1 and i2 not in s1 and i2 in s2: # 我喜欢你,我不喜欢你 175 | return True 176 | if i1 in s2 and i2 not in s2 and i2 in s1: # 我不喜欢你,我喜欢你 177 | return True 178 | else: # 不包含的情况 (早,晚) (我们,你们) 179 | if (i1 in s1 and i2 in s1) or (i1 in s2 and i2 in s1): # 如果矛盾的两字同时出现在前半句或后半句中,不在此处理 180 | return False 181 | if (i1 in s1 and i2 in s2) or (i1 in s2 and i2 in s1): 182 | return True 183 | return False 184 | 185 | 186 | def distinct_ratio(sent): 187 | sent = re.sub(punc_char, '', sent) 188 | if len(sent) == 0: 189 | return 1. 190 | return round(len(set(sent)) / len(sent), 4) 191 | 192 | 193 | def is_english(sent): 194 | if not sent: 195 | return False 196 | if len(en_char.sub('', sent)) == 0: 197 | return True 198 | else: 199 | return False 200 | 201 | 202 | def count_length_score(sent): 203 | punc = ",。.。!!??" 
# 标点符号 204 | special_zh_char = '哈呵' # 特殊字符 205 | length_score = 0 206 | for s in sent: 207 | if re.match(zh_char, s): # 是中文 208 | length_score += 1 209 | elif s in punc: # 标点符号 210 | length_score += 0.5 211 | elif s in special_zh_char: # 中文特殊符号 212 | length_score += 0.7 213 | elif s.isdigit(): # 是数字 214 | length_score += 0.8 215 | else: 216 | pass 217 | return length_score 218 | 219 | 220 | def cos_sim(q, a): 221 | q, a = np.mat(q), np.mat(a) 222 | num = float(q * a.T) 223 | denom = np.linalg.norm(q) * np.linalg.norm(a) 224 | cos = num / denom 225 | sim = 0.5 + 0.5 * cos # 归一化为0-1 226 | return sim 227 | 228 | 229 | def calc_cos_sim(q, a, embed, w2i): 230 | vec_q = np.zeros(embed.shape[1]) 231 | vec_a = np.zeros(embed.shape[1]) 232 | for c in q: 233 | if c in w2i: 234 | vec_q += embed[w2i[c]] # bow 235 | for c in a: 236 | if c in w2i: 237 | vec_a += embed[w2i[c]] # bow 238 | return cos_sim(vec_q, vec_a) 239 | 240 | 241 | def levenshtein_distance(s1, s2): 242 | """ 编辑距离 """ 243 | rows = len(s1) + 1 244 | columns = len(s2) + 1 245 | # 创建矩阵 246 | matrix = [[0 for j in range(columns)] for i in range(rows)] # row * column 247 | for j in range(columns): # 矩阵第一行 248 | matrix[0][j] = j 249 | for i in range(rows): # 矩阵第一列 250 | matrix[i][0] = i 251 | # 根据状态转移方程逐步得到编辑距离 252 | for i in range(1, rows): 253 | for j in range(1, columns): 254 | if s1[i - 1] == s2[j - 1]: 255 | cost = 0 # 不需更改 256 | else: 257 | cost = 1 # 替换操作 258 | matrix[i][j] = min(matrix[i - 1][j - 1] + cost, 259 | matrix[i - 1][j] + 1, 260 | matrix[i][j - 1] + 1) 261 | return matrix[rows - 1][columns - 1] 262 | 263 | 264 | def print_distinct_ratio(output_sents): 265 | distinct_ratio_list = [[sent, distinct_ratio(sent)] for sent, _ in output_sents] 266 | distinct_ratio_list.sort(key=lambda x: x[1], reverse=True) 267 | print('-' * 10, 'distinct_ratio', sep='') 268 | for item in distinct_ratio_list: 269 | print(item) 270 | print('-' * 10, 'distinct_ratio', sep='') 271 | 272 | 273 | if __name__ == '__main__': 274 | print(rerank([['我也想吃', -2.1], ['我也想要', -2.5]])) 275 | pass 276 | -------------------------------------------------------------------------------- /qiznlp/common/tfrecord_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os, re, glob 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | TF_VERSION = int(tf.__version__.split('.')[1]) 9 | 10 | 11 | def get_num(tfrecord_file): 12 | gen = tf.python_io.tf_record_iterator(tfrecord_file) 13 | num = sum(1 for _ in gen) 14 | gen.close() 15 | return num 16 | 17 | 18 | def tf_sparse_to_dense_new(v): 19 | if isinstance(v, tf.sparse.SparseTensor): 20 | return tf.sparse.to_dense(v) 21 | return v 22 | 23 | 24 | def tf_sparse_to_dense_old(v): 25 | if isinstance(v, tf.SparseTensor): 26 | return tf.sparse_to_dense(v.indices, v.dense_shape, v.values) 27 | return v 28 | 29 | 30 | tf_sparse_to_densor = tf_sparse_to_dense_new if TF_VERSION >= 12 else tf_sparse_to_dense_old 31 | 32 | 33 | def flat(l): 34 | # 平摊flatten 35 | for k in l: 36 | if not isinstance(k, (list, tuple)): 37 | yield k 38 | else: 39 | yield from flat(k) 40 | 41 | 42 | def delete_exist_tfrecord_file(tfrecord_file): 43 | files = glob.glob(f'{tfrecord_file}_*') 44 | files = list(filter(lambda f: re.match(r'^.*_\d+$', f), files)) # filter invalid 45 | for f in files: 46 | if os.path.isfile(f): 47 | os.remove(f) 48 | 49 | 50 | def add_num(file, num): 51 | return f'{file}_{num}' 52 | 53 | 54 | def items2tfrecord(items, 
tfrecord_file):
55 |     if os.path.exists(tfrecord_file):  # remove the existing tfrecord file first
56 |         os.remove(tfrecord_file)
57 | 
58 |     def int_feat(value):
59 |         if not isinstance(value, (list, np.ndarray)):
60 |             value = [value]
61 |         return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
62 | 
63 |     def float_feat(value):
64 |         if not isinstance(value, (list, np.ndarray)):
65 |             value = [value]
66 |         return tf.train.Feature(float_list=tf.train.FloatList(value=value))
67 | 
68 |     def byte_feat(value):
69 |         if not isinstance(value, (list, np.ndarray)):
70 |             value = [value]
71 |         return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
72 | 
73 |     print(f'generating tfrecord... filename: {tfrecord_file}')
74 |     writer = tf.python_io.TFRecordWriter(tfrecord_file)
75 |     count = 0
76 |     for item in items:
77 |         if count and not count % 100000:
78 |             print(f'generating tfrecord... count: {count}')
79 |         features = {}
80 |         try:
81 |             for k, v in item.items():
82 |                 # maybe should flatten
83 |                 if isinstance(v, np.ndarray) and v.ndim > 1:
84 |                     v = v.flatten()
85 |                 elif isinstance(v, list) and isinstance(v[0], list):
86 |                     v = list(flat(v))
87 | 
88 |                 ele = v[0] if isinstance(v, (list, np.ndarray)) else v  # check the element type
89 | 
90 |                 if isinstance(ele, (int, np.int, np.int32, np.int64)):
91 |                     features[k] = int_feat(v)
92 |                 elif isinstance(ele, (float, np.float, np.float16, np.float32, np.float64)):
93 |                     features[k] = float_feat(v)
94 |                 else:
95 |                     features[k] = byte_feat(v)
96 |             example = tf.train.Example(features=tf.train.Features(feature=features))
97 |         except:
98 |             print('error item:', item)
99 |             continue
100 |         writer.write(example.SerializeToString())
101 |         count += 1
102 |     writer.close()
103 |     if count == 0:
104 |         raise Exception(f'error! count = {count} no example to save')
105 |     print(f'save tfrecord file ok! {tfrecord_file} total count: {count}')
106 |     return count
107 | 
108 | 
109 | def tfrecord2dataset(tfrecord_files, feat_dct, shape_dct=None, batch_size=100, auto_pad=False, index=None, shard=None):
110 |     """
111 |     tf.VarLenFeature only accepts flattened 1-D data, so if the raw data is 2-D, e.g. [None, 4], pass shape_dct to reshape it back
112 |     feat_dct = {
113 |         'target': tf.FixedLenFeature([], tf.int64),
114 |         's1': tf.VarLenFeature(tf.int64),
115 |         's1_char': tf.VarLenFeature(tf.int64),
116 |         'others': tf.FixedLenFeature([3], tf.string),
117 |     }
118 |     shape_dct = {
119 |         's1_char': [-1, 4]  # if s1_char is a 2-D tensor and needs padding on the first dimension when batching
120 |     }
121 |     """
122 |     # [Note] building the dataset does not need to happen inside the graph, but the later iteration does, e.g.:
123 |     # with self.graph.as_default():
124 |     #     iterator = dataset.make_one_shot_iterator()
125 |     #     dataset_features = iterator.get_next()
126 |     #     for i in train_step:
127 |     #         features = sess.run(dataset_features)
128 | 
129 |     if not isinstance(tfrecord_files, list):
130 |         tfrecord_files = [tfrecord_files]
131 |     tfrecord_files = [f for f in tfrecord_files if os.path.exists(f)]
132 |     total_count = sum(get_num(f) for f in tfrecord_files)
133 |     print(f'load exist tfrecord file ok!
{" & ".join(tfrecord_files)} total count: {total_count}') 134 | 135 | def exm_parse(serialized_example): 136 | parsed_features = tf.parse_single_example(serialized_example, features=feat_dct) 137 | # VarLenFeature will return sparse tensor 138 | for k, v in parsed_features.items(): 139 | parsed_features[k] = tf_sparse_to_densor(v) # Convert sparse to dense when need 140 | # maybe need to reshape 141 | if shape_dct is not None: 142 | for name, shape in shape_dct.items(): 143 | parsed_features[name] = tf.reshape(parsed_features[name], shape) 144 | return parsed_features 145 | 146 | def ints2int32(features): 147 | for k, v in features.items(): 148 | if v.dtype in [tf.int64, tf.uint8]: 149 | features[k] = tf.cast(v, tf.int32) 150 | return features 151 | 152 | def padded_batch(dataset, batch_size, padded_shapes=None): 153 | if padded_shapes is None: 154 | padded_shapes = {k: [None] * len(shape) for k, shape in dataset.output_shapes.items()} 155 | return dataset.padded_batch(batch_size, padded_shapes) 156 | 157 | # tf.contrib.keras.preprocessing.sequence.pad_sequences 158 | 159 | dataset = tf.data.TFRecordDataset(tf.constant(tfrecord_files)) 160 | # dataset = tf.data.TFRecordDataset(tf.data.Dataset.from_tensor_slices(tf.constant(data_files))) 161 | 162 | # 是否分布式训练数据分片, 上层传来的local_rank设置shard的数据,保证各个gpu采样的数据不重叠。 163 | if shard is not None and index is not None and index < shard: 164 | dataset = dataset.shard(shard, index) 165 | batch_size = int(np.ceil(batch_size / shard)) # batch再均分 166 | 167 | dataset = dataset.map(exm_parse) 168 | # dataset = dataset.map(exm_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) # need tf >= 1.14 169 | 170 | dataset = dataset.shuffle(buffer_size=5000, reshuffle_each_iteration=True) 171 | 172 | dataset = padded_batch(dataset, batch_size) if auto_pad else dataset.batch(batch_size) # batch时是否要自动补齐 173 | 174 | dataset = dataset.map(ints2int32) # 可向量化的操作放在batch后面以提高效率 175 | # dataset = dataset.map(ints2int32, num_parallel_calls=tf.data.experimental.AUTOTUNE) # 可向量化的操作放在batch后面以提高效率 176 | 177 | dataset = dataset.repeat() # repeat放在batch后面 178 | 179 | dataset = dataset.prefetch(2) # pipline先异步准备好n个batch数据,训练step时数据同时处理 180 | # dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) # pipline先异步准备好n个batch数据,训练step时数据同时处理 181 | 182 | return dataset, total_count 183 | 184 | 185 | def tfrecord2queue(tfrecord_files, feat_dct): 186 | # 返回后还需使用 187 | # coord = tf.train.Coordinator() 188 | # threads = tf.train.start_queue_runners(coord=coord) 189 | # sess.run(tfrecord2queue) 190 | # coord.request_stop() 191 | # coord.join(threads) 192 | filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(tfrecord_files), 193 | shuffle=None, 194 | num_epochs=None) 195 | reader = tf.TFRecordReader() 196 | _, serialized_example = reader.read(filename_queue) 197 | 198 | features = tf.parse_single_example(serialized_example, features=feat_dct) 199 | 200 | for k in features: 201 | if features[k].dtype in [tf.int64, tf.uint8]: 202 | features[k] = tf.cast(features[k], tf.int32) 203 | 204 | # reshpae 205 | # features[k] = tf.reshape(features[k], [-1,1]) 206 | sorted_keys = list(sorted(features)) 207 | 208 | input_queue = tf.train.slice_input_producer([features[k] for k in sorted_keys], shuffle=False) 209 | # tf.data.Dataset.from_tensor_slices 210 | data = tf.train.batch(input_queue, batch_size=128, allow_smaller_final_batch=True, num_threads=8) 211 | ret = dict(zip(sorted_keys, data)) 212 | return ret # 返回的tensor供sess.run 213 | 214 | 215 | def 
check_tfrecord(tfrecord_file, feat_dct): 216 | true_count = 0 217 | dataset, total_count = tfrecord2dataset(tfrecord_file, feat_dct) 218 | iterator = dataset.make_one_shot_iterator() 219 | features = iterator.get_next() 220 | sess = tf.Session() 221 | while True: 222 | try: 223 | feat = sess.run(features) 224 | true_count += len(list(feat.values())[0]) 225 | except Exception as e: 226 | print(e) 227 | print(true_count) 228 | raise e 229 | 230 | 231 | if __name__ == '__main__': 232 | # os.environ['CUDA_VISIBLE_DEVICES'] = '6' 233 | # 234 | # feat_dct = { 235 | # 'contents': tf.FixedLenFeature([4, 50], tf.int64), 236 | # 'content_masks': tf.FixedLenFeature([4, 50], tf.int64), 237 | # 'content_lens': tf.FixedLenFeature([4], tf.int64), 238 | # 'char_contents': tf.FixedLenFeature([4, 50, 4], tf.int64), 239 | # 'char_content_masks': tf.FixedLenFeature([4, 50, 4], tf.int64), 240 | # 'char_content_lens': tf.FixedLenFeature([4, 50], tf.int64), 241 | # 'responses': tf.FixedLenFeature([50], tf.int64), 242 | # 'response_masks': tf.FixedLenFeature([50], tf.int64), 243 | # 'response_lens': tf.FixedLenFeature([], tf.int64), 244 | # 'char_responses': tf.FixedLenFeature([50, 4], tf.int64), 245 | # 'char_response_masks': tf.FixedLenFeature([50, 4], tf.int64), 246 | # 'char_response_lens': tf.FixedLenFeature([50], tf.int64), 247 | # 'targets': tf.FixedLenFeature([], tf.int64), 248 | # 'intents': tf.FixedLenFeature([4], tf.int64), 249 | # } 250 | # check_tfrecord('/dockerdata/yonaszhang/1223/data/match_ckpt_4_train_data.tfrecord_1', feat_dct) 251 | # exit(0) 252 | 253 | """ test """ 254 | import time 255 | 256 | # items = [ 257 | # {'s1':[1,2,3,4],'s2':[5]}, 258 | # {'s1':[2,3,4],'s2':[6]}, 259 | # ] 260 | # items2tfrecord(items, './test.tfrecord') 261 | feat_dct = { 262 | 's1': tf.VarLenFeature(tf.int64), 263 | 's2': tf.FixedLenFeature([], tf.int64), 264 | } 265 | dataset, _ = tfrecord2dataset('./test.tfrecord', feat_dct, batch_size=2, auto_pad=True) 266 | feature = dataset.make_one_shot_iterator().get_next() 267 | with tf.Session() as sess: 268 | while True: 269 | print(sess.run(feature)) 270 | time.sleep(1) 271 | # input('输入任意键以继续下一个') 272 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /qiznlp/common/modules/DAM/operations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from scipy.stats import multivariate_normal 4 | import tensorflow as tf 5 | 6 | def learning_rate(step_num, d_model=512, warmup_steps=4000): 7 | a = step_num**(-0.5) 8 | b = step_num*warmup_steps**(-1.5) 9 | return a, b, d_model**(-0.5) * min(step_num**(-0.5), step_num*(warmup_steps**(-1.5))) 10 | 11 | def selu(x): 12 | alpha = 1.6732632423543772848170429916717 13 | scale = 1.0507009873554804934193349852946 14 | print('use selu') 15 | return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x)) 16 | 17 | def bilinear_sim_4d(x, y, is_nor=True): 18 | '''calulate bilinear similarity with two 4d tensor. 19 | 20 | Args: 21 | x: a tensor with shape [batch, time_x, dimension_x, num_stacks] 22 | y: a tensor with shape [batch, time_y, dimension_y, num_stacks] 23 | 24 | Returns: 25 | a tensor with shape [batch, time_x, time_y, num_stacks] 26 | 27 | Raises: 28 | ValueError: if 29 | the shapes of x and y are not match; 30 | bilinear matrix reuse error. 31 | ''' 32 | M = tf.get_variable( 33 | name="bilinear_matrix", 34 | shape=[x.shape[2], y.shape[2], x.shape[3]], 35 | dtype=tf.float32, 36 | initializer=tf.orthogonal_initializer()) 37 | sim = tf.einsum('biks,kls,bjls->bijs', x, M, y) 38 | 39 | if is_nor: 40 | scale = tf.sqrt(tf.cast(x.shape[2] * y.shape[2], tf.float32)) 41 | scale = tf.maximum(1.0, scale) 42 | return sim / scale 43 | else: 44 | return sim 45 | 46 | 47 | def bilinear_sim(x, y, is_nor=True): 48 | '''calculate bilinear similarity with two tensor. 49 | Args: 50 | x: a tensor with shape [batch, time_x, dimension_x] 51 | y: a tensor with shape [batch, time_y, dimension_y] 52 | 53 | Returns: 54 | a tensor with shape [batch, time_x, time_y] 55 | Raises: 56 | ValueError: if 57 | the shapes of x and y are not match; 58 | bilinear matrix reuse error. 59 | ''' 60 | M = tf.get_variable( 61 | name="bilinear_matrix", 62 | shape=[x.shape[-1], y.shape[-1]], 63 | dtype=tf.float32, 64 | initializer=tf.orthogonal_initializer()) 65 | sim = tf.einsum('bik,kl,bjl->bij', x, M, y) 66 | 67 | if is_nor: 68 | scale = tf.sqrt(tf.cast(x.shape[-1] * y.shape[-1], tf.float32)) 69 | scale = tf.maximum(1.0, scale) 70 | return sim / scale 71 | else: 72 | return sim 73 | 74 | def dot_sim(x, y, is_nor=True): 75 | '''calculate dot similarity with two tensor. 76 | 77 | Args: 78 | x: a tensor with shape [batch, time_x, dimension] 79 | y: a tensor with shape [batch, time_y, dimension] 80 | 81 | Returns: 82 | a tensor with shape [batch, time_x, time_y] 83 | Raises: 84 | AssertionError: if 85 | the shapes of x and y are not match. 86 | ''' 87 | assert x.shape[-1] == y.shape[-1] 88 | 89 | sim = tf.einsum('bik,bjk->bij', x, y) 90 | 91 | if is_nor: 92 | scale = tf.sqrt(tf.cast(x.shape[-1], tf.float32)) 93 | scale = tf.maximum(1.0, scale) 94 | return sim / scale 95 | else: 96 | return sim 97 | 98 | def layer_norm(x, axis=None, epsilon=1e-6): 99 | '''Add layer normalization. 100 | 101 | Args: 102 | x: a tensor 103 | axis: the dimensions to normalize 104 | 105 | Returns: 106 | a tensor the same shape as x. 
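        Note: scale and bias here are single scalars (shape [1]), hence the 'wrong version of layer_norm' print below; layer_norm_debug learns per-dimension parameters instead.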
107 | 108 | Raises: 109 | ''' 110 | print('wrong version of layer_norm') 111 | scale = tf.get_variable( 112 | name='scale', 113 | shape=[1], 114 | dtype=tf.float32, 115 | initializer=tf.ones_initializer()) 116 | bias = tf.get_variable( 117 | name='bias', 118 | shape=[1], 119 | dtype=tf.float32, 120 | initializer=tf.zeros_initializer()) 121 | 122 | if axis is None: 123 | axis = [-1] 124 | 125 | mean = tf.reduce_mean(x, axis=axis, keep_dims=True) 126 | variance = tf.reduce_mean(tf.square(x - mean), axis=axis, keep_dims=True) 127 | norm = (x-mean) * tf.rsqrt(variance + epsilon) 128 | return scale * norm + bias 129 | 130 | def layer_norm_debug(x, axis = None, epsilon=1e-6): 131 | '''Add layer normalization. 132 | 133 | Args: 134 | x: a tensor 135 | axis: the dimensions to normalize 136 | 137 | Returns: 138 | a tensor the same shape as x. 139 | 140 | Raises: 141 | ''' 142 | if axis is None: 143 | axis = [-1] 144 | shape = [x.shape[i] for i in axis] 145 | 146 | scale = tf.get_variable( 147 | name='scale', 148 | shape=shape, 149 | dtype=tf.float32, 150 | initializer=tf.ones_initializer()) 151 | bias = tf.get_variable( 152 | name='bias', 153 | shape=shape, 154 | dtype=tf.float32, 155 | initializer=tf.zeros_initializer()) 156 | 157 | mean = tf.reduce_mean(x, axis=axis, keep_dims=True) 158 | variance = tf.reduce_mean(tf.square(x - mean), axis=axis, keep_dims=True) 159 | norm = (x-mean) * tf.rsqrt(variance + epsilon) 160 | return scale * norm + bias 161 | 162 | def dense(x, out_dimension=None, add_bias=True): 163 | '''Add dense connected layer, Wx + b. 164 | 165 | Args: 166 | x: a tensor with shape [batch, time, dimension] 167 | out_dimension: a number which is the output dimension 168 | 169 | Return: 170 | a tensor with shape [batch, time, out_dimension] 171 | 172 | Raises: 173 | ''' 174 | if out_dimension is None: 175 | out_dimension = x.shape[-1] 176 | 177 | W = tf.get_variable( 178 | name='weights', 179 | shape=[x.shape[-1], out_dimension], 180 | dtype=tf.float32, 181 | # initializer=tf.orthogonal_initializer()) 182 | initializer=tf.contrib.layers.xavier_initializer()) 183 | if add_bias: 184 | bias = tf.get_variable( 185 | name='bias', 186 | shape=[1], 187 | dtype=tf.float32, 188 | initializer=tf.zeros_initializer()) 189 | return tf.einsum('bik,kj->bij', x, W) + bias 190 | else: 191 | return tf.einsum('bik,kj->bij', x, W) 192 | 193 | def matmul_2d(x, out_dimension, drop_prob=None): 194 | '''Multiplies 2-d tensor by weights. 
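    Note: drop_prob, when given, is passed directly to tf.nn.dropout, whose second positional argument in TF1 is keep_prob, so it should be a keep probability rather than a drop rate.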
195 | 196 | Args: 197 | x: a tensor with shape [batch, dimension] 198 | out_dimension: a number 199 | 200 | Returns: 201 | a tensor with shape [batch, out_dimension] 202 | 203 | Raises: 204 | ''' 205 | W = tf.get_variable( 206 | name='weights', 207 | shape=[x.shape[1], out_dimension], 208 | dtype=tf.float32, 209 | initializer=tf.orthogonal_initializer()) 210 | if drop_prob is not None: 211 | W = tf.nn.dropout(W, drop_prob) 212 | print('W is dropout') 213 | 214 | return tf.matmul(x, W) 215 | 216 | def gauss_positional_encoding_vector(x, role=0, value=0): 217 | position = int(x.shape[1]) 218 | dimension = int(x.shape[2]) 219 | print('position: %s' %position) 220 | print('dimension: %s' %dimension) 221 | 222 | _lambda = tf.get_variable( 223 | name='lambda', 224 | shape=[position], 225 | dtype=tf.float32, 226 | initializer=tf.constant_initializer(value)) 227 | _lambda = tf.expand_dims(_lambda, axis=-1) 228 | 229 | mean = [position/2.0, dimension/2.0] 230 | 231 | #cov = [[position/3.0, 0], [0, dimension/3.0]] 232 | sigma_x = position/math.sqrt(4.0*dimension) 233 | sigma_y = math.sqrt(dimension/4.0) 234 | cov = [[sigma_x*sigma_x, role*sigma_x*sigma_y], 235 | [role*sigma_x*sigma_y, sigma_y*sigma_y]] 236 | 237 | pos = np.dstack(np.mgrid[0:position, 0:dimension]) 238 | 239 | 240 | rv = multivariate_normal(mean, cov) 241 | signal = rv.pdf(pos) 242 | signal = signal - np.max(signal)/2.0 243 | 244 | signal = tf.multiply(_lambda, signal) 245 | signal = tf.expand_dims(signal, axis=0) 246 | 247 | print('gauss positional encoding') 248 | 249 | return x + _lambda * signal 250 | 251 | def positional_encoding(x, min_timescale=1.0, max_timescale=1.0e4, value=0): 252 | '''Adds a bunch of sinusoids of different frequencies to a tensor. 253 | 254 | Args: 255 | x: a tensor with shape [batch, length, channels] 256 | min_timescale: a float 257 | max_timescale: a float 258 | 259 | Returns: 260 | a tensor the same shape as x. 261 | 262 | Raises: 263 | ''' 264 | length = x.shape[1] 265 | channels = x.shape[2] 266 | _lambda = tf.get_variable( 267 | name='lambda', 268 | shape=[1], 269 | dtype=tf.float32, 270 | initializer=tf.constant_initializer(value)) 271 | 272 | position = tf.to_float(tf.range(length)) 273 | num_timescales = channels // 2 274 | log_timescale_increment = ( 275 | math.log(float(max_timescale) / float(min_timescale)) / 276 | (tf.to_float(num_timescales) - 1)) 277 | inv_timescales = min_timescale * tf.exp( 278 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 279 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 280 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 281 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 282 | #signal = tf.reshape(signal, [1, length, channels]) 283 | signal = tf.expand_dims(signal, axis=0) 284 | 285 | return x + _lambda * signal 286 | 287 | 288 | def positional_encoding_vector(x, min_timescale=1.0, max_timescale=1.0e4, value=0): 289 | '''Adds a bunch of sinusoids of different frequencies to a tensor. 290 | 291 | Args: 292 | x: a tensor with shape [batch, length, channels] 293 | min_timescale: a float 294 | max_timescale: a float 295 | 296 | Returns: 297 | a tensor the same shape as x. 
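        Note: unlike positional_encoding above, the learned lambda here has one entry per position (shape [length]) and scales the sinusoidal signal element-wise before it is added to x.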
298 | 299 | Raises: 300 | ''' 301 | length = x.shape[1] 302 | channels = x.shape[2] 303 | _lambda = tf.get_variable( 304 | name='lambda', 305 | shape=[length], 306 | dtype=tf.float32, 307 | initializer=tf.constant_initializer(value)) 308 | _lambda = tf.expand_dims(_lambda, axis=-1) 309 | 310 | position = tf.to_float(tf.range(length)) 311 | num_timescales = channels // 2 312 | log_timescale_increment = ( 313 | math.log(float(max_timescale) / float(min_timescale)) / 314 | (tf.to_float(num_timescales) - 1)) 315 | inv_timescales = min_timescale * tf.exp( 316 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 317 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 318 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 319 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 320 | 321 | signal = tf.multiply(_lambda, signal) 322 | signal = tf.expand_dims(signal, axis=0) 323 | 324 | return x + signal 325 | 326 | def mask(row_lengths, col_lengths, max_row_length, max_col_length): 327 | '''Return a mask tensor representing the first N positions of each row and each column. 328 | 329 | Args: 330 | row_lengths: a tensor with shape [batch] 331 | col_lengths: a tensor with shape [batch] 332 | 333 | Returns: 334 | a mask tensor with shape [batch, max_row_length, max_col_length] 335 | 336 | Raises: 337 | ''' 338 | row_mask = tf.sequence_mask(row_lengths, max_row_length) #bool, [batch, max_row_len] 339 | col_mask = tf.sequence_mask(col_lengths, max_col_length) #bool, [batch, max_col_len] 340 | 341 | row_mask = tf.cast(tf.expand_dims(row_mask, -1), tf.float32) 342 | col_mask = tf.cast(tf.expand_dims(col_mask, -1), tf.float32) 343 | 344 | return tf.einsum('bik,bjk->bij', row_mask, col_mask) 345 | 346 | def weighted_sum(weight, values): 347 | '''Calcualte the weighted sum. 
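    (A batched matrix product, equivalent to tf.matmul(weight, values), typically used to apply attention weights to value vectors.)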
348 | 349 | Args: 350 | weight: a tensor with shape [batch, time, dimension] 351 | values: a tensor with shape [batch, dimension, values_dimension] 352 | 353 | Return: 354 | a tensor with shape [batch, time, values_dimension] 355 | 356 | Raises: 357 | ''' 358 | return tf.einsum('bij,bjk->bik', weight, values) 359 | 360 | 361 | 362 | 363 | -------------------------------------------------------------------------------- /qiznlp/run/run_mch.py: -------------------------------------------------------------------------------- 1 | import os, sys, re 2 | import time 3 | import jieba 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import qiznlp.common.utils as utils 8 | import qiznlp.common.train_helper as train_helper 9 | from qiznlp.run.run_base import Run_Model_Base, check_and_update_param_of_model_pyfile 10 | 11 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 12 | sys.path.append(curr_dir + '/..') # 添加上级目录即默认qiznlp根目录 13 | from model.mch_model import Model as MCH_Model 14 | 15 | try: 16 | import horovod.tensorflow as hvd 17 | # 示例:horovodrun -np 2 -H localhost:2 python run_mch.py 18 | # 可根据local_rank设置shard的数据,保证各个gpu采样的数据不重叠。 19 | except: 20 | HVD_ENABLE = False 21 | else: 22 | HVD_ENABLE = True 23 | 24 | conf = utils.dict2obj({ 25 | 'early_stop_patience': None, 26 | 'just_save_best': False, 27 | 'n_epochs': 20, 28 | 'data_type': 'tfrecord', 29 | # 'data_type': 'pkldata', 30 | }) 31 | 32 | 33 | class Run_Model_Mch(Run_Model_Base): 34 | def __init__(self, model_name, tokenize=None, pbmodel_dir=None, use_hvd=False): 35 | # 维护sess graph config saver 36 | self.model_name = model_name 37 | if tokenize is None: 38 | self.jieba = jieba.Tokenizer() 39 | # self.jieba.load_userdict(f'{curr_dir}/segword.dct') 40 | self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t)) 41 | else: 42 | self.tokenize = tokenize 43 | self.cut = lambda t: ' '.join(self.tokenize(t)) 44 | self.token2id_dct = { 45 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/mch_word2id.dct', use_line_no=True), # 自有数据 46 | } 47 | self.config = tf.ConfigProto(allow_soft_placement=True, 48 | gpu_options=tf.GPUOptions(allow_growth=True), 49 | ) 50 | self.use_hvd = use_hvd if HVD_ENABLE else False 51 | if self.use_hvd: 52 | hvd.init() 53 | self.hvd_rank = hvd.rank() 54 | self.hvd_size = hvd.size() 55 | self.config.gpu_options.visible_device_list = str(hvd.local_rank()) 56 | self.graph = tf.Graph() 57 | self.sess = tf.Session(config=self.config, graph=self.graph) 58 | 59 | if pbmodel_dir is not None: # 只能做predict 60 | self.model = MCH_Model.from_pbmodel(pbmodel_dir, self.sess) 61 | else: 62 | with self.graph.as_default(): 63 | self.model = MCH_Model(model_name=self.model_name, run_model=self) 64 | if self.use_hvd: 65 | self.model.optimizer._lr = self.model.optimizer._lr * self.hvd_size # 分布式训练大batch增大学习率 66 | self.model.hvd_optimizer = hvd.DistributedOptimizer(self.model.optimizer) 67 | self.model.train_op = self.model.hvd_optimizer.minimize(self.model.loss, global_step=self.model.global_step) 68 | self.sess.run(tf.global_variables_initializer()) 69 | if self.use_hvd: 70 | self.sess.run(hvd.broadcast_global_variables(0)) 71 | 72 | with self.graph.as_default(): 73 | self.saver = tf.train.Saver(max_to_keep=100) # must in the graph context 74 | 75 | def train_step(self, feed_dict): 76 | _, step, loss, accuracy = self.sess.run([self.model.train_op, 77 | self.model.global_step, 78 | self.model.loss, 79 | self.model.accuracy, 80 | ], 81 | feed_dict=feed_dict) 82 | return step, loss, accuracy 83 | 84 | def 
eval_step(self, feed_dict): 85 | loss, accuracy, y_prob = self.sess.run([self.model.loss, 86 | self.model.accuracy, 87 | self.model.y_prob, 88 | ], 89 | feed_dict=feed_dict) 90 | return loss, accuracy, y_prob 91 | 92 | def train(self, ckpt_dir, raw_data_file, preprocess_raw_data, batch_size = 100, save_data_prefix = None): 93 | save_data_prefix = os.path.basename(ckpt_dir) if save_data_prefix is None else save_data_prefix 94 | train_epo_steps, dev_epo_steps, test_epo_steps, gen_feed_dict = self.prepare_data(conf.data_type, 95 | raw_data_file, 96 | preprocess_raw_data, 97 | batch_size, 98 | save_data_prefix=save_data_prefix, 99 | ) 100 | self.is_master = True 101 | if hasattr(self, 'hvd_rank') and self.hvd_rank != 0: # 分布式训练且非master 102 | dev_epo_steps, test_epo_steps = None, None # 不进行验证和测试 103 | self.is_master = False 104 | 105 | # 字典大小自动对齐 106 | check_and_update_param_of_model_pyfile({ 107 | 'vocab_size': (self.model.conf.vocab_size, len(self.token2id_dct['word2id'])), 108 | }, self.model) 109 | 110 | train_info = {} 111 | for epo in range(1, 1 + conf.n_epochs): 112 | train_info[epo] = {} 113 | 114 | # train 115 | time0 = time.time() 116 | epo_num_example = 0 117 | trn_epo_loss = [] 118 | trn_epo_acc = [] 119 | for i in range(train_epo_steps): 120 | feed_dict = gen_feed_dict(i, epo, 'train') 121 | epo_num_example += feed_dict.pop('num') 122 | 123 | step_start_time = time.time() 124 | step, loss, acc = self.train_step(feed_dict) 125 | trn_epo_loss.append(loss) 126 | trn_epo_acc.append(acc) 127 | 128 | if self.is_master: 129 | print(f'\repo:{epo} step:{i + 1}/{train_epo_steps} num:{epo_num_example} ' 130 | f'cur_loss:{loss:.3f} epo_loss:{np.mean(trn_epo_loss):.3f} ' 131 | f'epo_acc:{np.mean(trn_epo_acc):.3f} ' 132 | f'sec/step:{time.time() - step_start_time:.2f}', 133 | end=f'{os.linesep if i == train_epo_steps - 1 else ""}', 134 | ) 135 | 136 | trn_loss = np.mean(trn_epo_loss) 137 | trn_acc = np.mean(trn_epo_acc) 138 | if self.is_master: 139 | print(f'epo:{epo} trn loss {trn_loss:.3f} ' 140 | f'trn acc {trn_acc:.3f} ' 141 | f'elapsed {(time.time() - time0) / 60:.2f} min') 142 | train_info[epo]['trn_loss'] = trn_loss 143 | train_info[epo]['trn_acc'] = trn_acc 144 | 145 | if not self.is_master: 146 | continue 147 | 148 | # dev or test 149 | for mode in ['dev', 'test']: 150 | epo_steps = {'dev': dev_epo_steps, 'test': test_epo_steps}[mode] 151 | if epo_steps is None: 152 | continue 153 | time0 = time.time() 154 | epo_loss = [] 155 | epo_acc = [] 156 | # to calc recall 157 | epo_s1 = [] 158 | epo_prob = [] 159 | epo_y = [] 160 | for i in range(epo_steps): 161 | feed_dict = gen_feed_dict(i, epo, mode) 162 | loss, acc, y_prob = self.eval_step(feed_dict) 163 | 164 | epo_loss.append(loss) 165 | epo_acc.append(acc) 166 | 167 | # to calc recall 168 | epo_s1.extend(feed_dict[self.model.s1]) 169 | epo_prob.extend(y_prob) 170 | epo_y.extend(feed_dict[self.model.target]) 171 | 172 | loss = np.mean(epo_loss) 173 | acc = np.mean(epo_acc) 174 | recall = train_helper.calc_recall(epo_s1, epo_prob, epo_y, strip_pad=True) # calc recall 175 | recall = list(map(lambda e: round(e, 3), recall)) 176 | 177 | print(f'epo:{epo} {mode} loss {loss:.3f} ' 178 | f'{mode} acc {acc:.3f} ' 179 | f'{mode} recall@n {recall}' # print recall 180 | f'elapsed {(time.time() - time0) / 60:.2f} min') 181 | train_info[epo][f'{mode}_loss'] = loss 182 | train_info[epo][f'{mode}_acc'] = acc 183 | train_info[epo][f'{mode}_recall@n'] = recall 184 | 185 | info_str = f'{trn_loss:.2f}-{train_info[epo]["dev_loss"]:.2f}' 186 | info_str += 
f'-{trn_acc:.3f}-{train_info[epo]["dev_acc"]:.3f}' 187 | info_str += f'-{train_info[epo]["dev_recall@n"][0]:.3f}' 188 | 189 | if conf.just_save_best: 190 | if self.should_save(epo, train_info, 'dev_recall@n', greater_is_better=True): 191 | self.delete_ckpt(ckpt_dir=ckpt_dir) # 删掉已存在的 192 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 193 | else: 194 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 195 | 196 | utils.obj2json(train_info, f'{ckpt_dir}/metrics.json') 197 | print('=' * 40, end='\n') 198 | if conf.early_stop_patience: 199 | if self.stop_training(conf.early_stop_patience, train_info, 'dev_acc'): 200 | print('early stop training!') 201 | print('train_info', train_info) 202 | break 203 | 204 | def predict(self, s1_lst, s2_lst, need_cut=True, batch_size=100): 205 | assert len(s1_lst) == len(s2_lst) 206 | if need_cut: 207 | s1_lst = [self.cut(s1) for s1 in s1_lst] 208 | s2_lst = [self.cut(s2) for s2 in s2_lst] 209 | 210 | pred_lst = [] 211 | for i in range(0, len(s1_lst), batch_size): 212 | batch_s1 = s1_lst[i:i + batch_size] 213 | batch_s2 = s2_lst[i:i + batch_size] 214 | feed_dict = self.model.create_feed_dict_from_raw(batch_s1, batch_s2, [], self.token2id_dct, mode='infer') 215 | probs = self.sess.run(self.model.y_prob, feed_dict) # [batch] 216 | preds = [1 if prob >= 0.5 else 0 for prob in probs] 217 | pred_lst.extend(preds) 218 | 219 | return pred_lst 220 | 221 | 222 | def preprocess_raw_data(file, tokenize, token2id_dct, **kwargs): 223 | """ 224 | # 处理自有数据函数模板 225 | # file文件数据格式: 句子1\t句子2 (正样本对,框架自行负采样) 226 | # [filter] 过滤 227 | # [segment] 分词 228 | # [build vocab] 构造词典 229 | # [split] train-dev-test 230 | """ 231 | seg_file = file.rsplit('.', 1)[0] + '_seg.txt' 232 | if not os.path.exists(seg_file): 233 | items = utils.file2items(file) 234 | # 过滤 235 | # filter here 236 | 237 | print('过滤后数据量', len(items)) 238 | 239 | # 分词 240 | for i, item in enumerate(items): 241 | item[0] = ' '.join(tokenize(item[0])) 242 | item[1] = ' '.join(tokenize(item[1])) 243 | utils.list2file(seg_file, items) 244 | print('保存分词后数据成功', '数据量', len(items), seg_file) 245 | else: 246 | # 读取分词好的数据 247 | items = utils.file2items(seg_file) 248 | 249 | # 划分 不分测试集 250 | train_items, dev_items = utils.split_file(items, ratio='19:1', shuffle=True, seed=1234) 251 | 252 | # 构造词典(option) 253 | need_to_rebuild = [] 254 | for token2id_name in token2id_dct: 255 | if not token2id_dct[token2id_name]: 256 | print(f'字典{token2id_name} 载入不成功, 将生成并保存') 257 | need_to_rebuild.append(token2id_name) 258 | 259 | if need_to_rebuild: 260 | print(f'生成缺失词表文件...{need_to_rebuild}') 261 | for items in [train_items, dev_items]: # 字典只统计train和dev 262 | for item in items: 263 | if 'word2id' in need_to_rebuild: 264 | token2id_dct['word2id'].to_count(item[0].split(' ')) 265 | token2id_dct['word2id'].to_count(item[1].split(' ')) 266 | if 'word2id' in need_to_rebuild: 267 | token2id_dct['word2id'].rebuild_by_counter(restrict=['', '', ''], min_freq=5, max_vocab_size=30000) 268 | token2id_dct['word2id'].save(f'{curr_dir}/../data/mch_word2id.dct') 269 | else: 270 | print('使用已有词表文件...') 271 | 272 | # 负采样 273 | train_items = train_helper.gen_pos_neg_sample(train_items, sample_idx=1, num_neg_exm=4) 274 | dev_items = train_helper.gen_pos_neg_sample(dev_items, sample_idx=1, num_neg_exm=4) 275 | return train_items, dev_items, None 276 | 277 | 278 | if __name__ == '__main__': 279 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用CPU设为'-1' 280 | 281 | rm_mch = Run_Model_Mch('esim') # use ESIM match model 282 | 283 | # 训练自有数据 284 | 
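    # Raw data format expected by preprocess_raw_data (passed to train() below): one positive pair per line, "sentence1\tsentence2";
    # negatives are generated automatically via train_helper.gen_pos_neg_sample (4 negatives per positive here).
    # When only the model changes and the data stays the same, a later run can reuse the already-processed data, e.g.:
    # rm_mch.train('mch_ckpt_2', '../data/mch_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512, save_data_prefix='mch_ckpt_1')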
rm_mch.train('mch_ckpt_1', '../data/mch_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512) # train 285 | 286 | # demo自有数据语义匹配模型 287 | rm_mch.restore('mch_ckpt_1') # for infer 288 | import readline 289 | while True: 290 | inp = input('(use ||| to split) enter:') 291 | sent1, sent2 = inp.split('|||') 292 | need_cut = False if ' ' in sent1 else True 293 | time0 = time.time() 294 | ret = rm_mch.predict([sent1], [sent2], need_cut=need_cut) 295 | print(ret[0]) 296 | print('elapsed:', time.time() - time0) 297 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://img.shields.io/pypi/v/QizNLP?logo=pypi)](https://pypi.org/project/QizNLP/) 2 | ![](https://img.shields.io/pypi/dm/QizNLP.svg?label=pypi%20downloads&logo=PyPI) 3 | ![](https://img.shields.io/pypi/l/QizNLP?color=green&logo=pypi) 4 | ![](https://img.shields.io/badge/platform-windows%20%7C%20macos%20%7C%20linux-lightgrey) 5 | ![](https://img.shields.io/pypi/pyversions/QizNLP?logo=pypi) 6 | ![](https://img.shields.io/badge/tensorflow-1.8%20%7C%201.9%20%7C%201.10%20%7C%201.11%20%7C%201.12%20%7C%201.13%20%7C%201.14-blue) 7 | 8 | 如对您有帮助,欢迎star本项目~~感谢! 9 | [文章系列](https://www.zhihu.com/column/c_1310303923157647360) 10 | ##### 一键运行分类 Demo (用cpu演示所以训练较慢) 11 | ![run_demo](run_demo.gif) 12 | 13 | ##### 任务/模型支持概览 14 | 15 | |任务|支持模型
(*表示默认)|相关代码<br>(训练/模型)|默认数据|数据文件|来源|
16 | |:------:|:---:|:---|---|:---|---|
17 | |分类|TransEncoder+MeanPooling<br>*TransEncoder+MHAttPooling<br>BERT|run_cls.py<br>model_cls.py|头条新闻分类|train/valid/test.toutiao.cls.txt|https://github.com/luopeixiang/textclf|
18 | |序列标注|*BiLSTM+CRF<br>IDCNN+CRF<br>BERT+CRF|run_s2l.py<br>model_s2l.py|ResumeNER简历数据|train/dev/test.char.bmes.txt|https://github.com/jiesutd/LatticeLSTM|
19 | |匹配|*ESIM|run_mch.py<br>model_mch.py|ChineseSTS相似语义文本|mch_example_data.txt|https://github.com/IAdmireu/ChineseSTS|
20 | |生成|LSTM_Seq2Seq+Attn<br>*Transformer|run_s2s.py<br>model_s2s.py|小黄鸡闲聊5万|XHJ_5w.txt|https://github.com/candlewill/Dialog_Corpus|
21 | ||
22 | |多轮匹配|DAM<br>*MRFN|run_multi_mch.py<br>multi_mch_model.py|豆瓣多轮会话600+|Douban_Sess662.txt|https://github.com/MarkWuNLP/MultiTurnResponseSelection
23 | |多轮生成|HRED<br>HRAN<br>*ReCoSa|run_multi_s2s.py<br>multi_s2s_model.py|小黄鸡闲聊5万<br>豆瓣多轮会话600+|XHJ_5w.txt<br>+Douban_Sess662.txt|https://github.com/candlewill/Dialog_Corpus<br>
https://github.com/MarkWuNLP/MultiTurnResponseSelection 24 | 25 | ##### 目录 26 | * [QizNLP简介](#QizNLP简介) 27 | * [安装流程](#安装流程) 28 | * [使用示例](#使用示例) 29 | * [快速运行(使用默认数据训练)](#1快速运行使用默认数据训练) 30 | * [使用自有数据](#2使用自有数据) 31 | * [加载预训练模型](#3加载预训练模型) 32 | * [框架设计思路](#框架设计思路) 33 | * [公共模块](#公共模块) 34 | * [修改适配需关注点](#修改适配需关注点) 35 | * [生成词表字典](#1生成词表字典) 36 | * [数据处理相关](#2数据处理相关) 37 | * [run和model的conf参数](#3run和model的conf参数) 38 | * [使用分布式](#4使用分布式) 39 | * [类图](#类图) 40 | * [TODO](#todo) 41 | * [参考](#参考) 42 | * [后记](#后记) 43 | 44 | 45 | ## QizNLP简介 46 | QizNLP(Quick NLP)是一个面向NLP四大常见范式(分类、序列标注、匹配、生成),提供基本全流程(数据处理、模型训练、部署推断),基于TensorFlow的一套NLP框架。 47 | 48 | 设计动机是为了在各场景下(实验/比赛/工作),可快速灌入数据到模型,验证效果。从而在原型试验阶段可更快地了解到数据特点、学习难度及对比不同模型。 49 | 50 | QizNLP的特点如下: 51 | 52 | * 封装了训练数据处理函数(TFRecord或Python原生数据两种方式)及词表生成函数。 53 | * 针对分类、序列标注、匹配、生成这四种NLP任务提供了使用该框架进行模型训练的参考代码,可一键运行(提供了默认数据及默认模型) 54 | * 封装了模型导出装载等函数,用以支持推断部署。提供部署参考代码。 55 | * 封装了少量常用模型。封装了多种常用的TF神经网络操作函数。 56 | * 封装了并提供了分布式训练方式(利用horovod) 57 | 58 | 设计原则: 59 | 60 | 框架设计并非追求面面俱到,因为每个人在实践过程中需求是不一样的(如特殊的输入数据处理、特殊的训练评估打印保存等过程)。故框架仅是尽量将可复用功能封装为公共模块,然后为四大范式(分类/序列标注/匹配/生成)提供一个简单的训练使用示例,供使用者根据自己的情况进行参考修改。框架原则上是重在灵活性,故不可避免地牺牲了部分易用性。虽然这可能给零基础的初学者带来一定困难,但框架的设计初衷也是希望能作为NLP不同实践场景中的一个编码起点(相当于初始弹药库),并且能在个人需求不断变化时也能灵活进行适配及持续使用。 61 | 62 | ## 安装流程 63 | 项目依赖 64 | ``` 65 | python>=3.6 66 | 1.8<=tensorflow<=1.14 67 | ``` 68 | 已发布pypi包,可直接通过```pip```安装(推荐) 69 | ```shell script 70 | pip install QizNLP 71 | ``` 72 | 或通过本项目github安装 73 | ```shell script 74 | pip install git+https://github.com/Qznan/QizNLP.git 75 | ``` 76 | 安装完毕后,进入到你自己创建的工作目录,输入以下命令: 77 | ```shell script 78 | qiznlp_init 79 | ``` 80 | 回车后,会在当前工作目录生成主要文件: 81 | ```bash 82 | . 83 | ├── model # 各个任务的模型代码示例 84 | │ ├── cls_model.py 85 | │ ├── mch_model.py 86 | │ ├── s2l_model.py 87 | │ ├── s2s_model.py 88 | │ ├── multi_mch_model.py 89 | │ └── multi_s2s_model.py 90 | ├── run # 各个任务的模型训练代码示例 91 | │ ├── run_cls.py 92 | │ ├── run_mch.py 93 | │ ├── run_s2l.py 94 | │ ├── run_s2s.py 95 | │ ├── run_multi_mch.py 96 | │ └── run_multi_s2s.py 97 | ├── deploy # 模型载入及部署的代码示例 98 | │ ├── example.py 99 | │ └── web_API.py 100 | ├── data # 各个任务的默认数据 101 | │ ├── train.toutiao.cls.txt 102 | │ ├── valid.toutiao.cls.txt 103 | │ ├── test.toutiao.cls.txt 104 | │ ├── train.char.bmes.txt 105 | │ ├── dev.char.bmes.txt 106 | │ ├── test.char.bmes.txt 107 | │ ├── mch_example_data.txt 108 | │ ├── XHJ_5w.txt 109 | │ └── Douban_Sess662.txt 110 | └── common # 存放预训练bert模型等 111 | └── modules 112 | └── bert 113 | └── chinese_L-12_H-768_A-12 114 | ``` 115 | 注意:如果不是通过pip安装此项目而是直接从github上克隆项目源码,则进行后续操作前需将包显式加入python路径中: 116 | ``` 117 | # Linux & Mac 118 | export PYTHONPATH=$PYTHONPATH:<克隆的qiznlp所在dir> 119 | # Windows 120 | set PYTHONPATH=<克隆的qiznlp所在dir> 121 | ``` 122 | ## 使用示例 123 | #### 1.快速运行(使用默认数据训练) 124 | ```shell script 125 | cd run 126 | 127 | # 运行分类任务 128 | python run_cls.py 129 | 130 | # 运行序列标注任务 131 | python run_s2l.py 132 | 133 | # 运行匹配任务 134 | python run_mch.py 135 | 136 | # 运行生成任务 137 | python run_s2s.py 138 | 139 | # 运行多轮匹配任务 140 | python run_multi_mch.py 141 | 142 | # 运行多轮生成任务 143 | python run_multi_s2s.py 144 | 145 | ``` 146 | 各任务默认数据及模型说明 147 | [见上](#任务模型支持概览) 148 | 149 | #### 2.使用自有数据 150 | 根据输入数据文本格式修改```run_*.py```中的```preprocess_raw_data()```函数,决定如何读取自有数据。 151 | 各```run_*.py```中,均已有```preprocess_raw_data()```的函数参考示例,其中默认文本格式如下: 152 | * ```run_cls.py```:默认输入文本格式的每行为:```句子\t类别``` 153 | > 自然语言难处理\t科技类 154 | * ```run_s2l.py```:默认输入文本格式的每行为:```句子(以空格分好)\t标签(以空格分好)``` 155 | > 我 是 王 五 , 没 有 人 比 我 更 懂 N L P 。\tO O B-NAME I-NAME O O O O O O O O B-TECH I-TECH 
I-TECH O 156 | * ```run_mch.py```:默认输入文本格式的每行为:```句子1\t句子2```(正样本对,框架自行负采样) 157 | > 自然语言难处理\t自然语言处理难 158 | * ```run_s2s.py```:默认输入文本格式的每行为:```句子1\t句子2``` 159 | > 自然语言难处理\t机器也不学习 160 | * ```run_multi_mch.py```:默认输入文本格式的每行为:```多轮对话句子1\t多轮对话句子2\t...\t多轮对话句子n``` 161 | > 自然语言难处理\t机器也不学习\t还说是人工智能\t简直就是人工智障\t大佬所见略同\t握手 162 | * ```run_multi_s2s.py```:默认输入文本格式的每行为:```多轮对话句子1\t多轮对话句子2\t...\t多轮对话句子n``` 163 | > 自然语言难处理\t机器也不学习\t还说是人工智能\t简直就是人工智障\t大佬所见略同\t握手 164 | 165 | 然后在```run_*.py```中指定```train()```的数据处理函数为```preprocess_raw_data```,以```run_cls.py```为例: 166 | ```python 167 | # 参数分别为:模型ckpt保存路径、自有数据文件路径、数据处理函数、训练batch size 168 | rm_cls.train('cls_ckpt_1', '../data/cls_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512) 169 | 170 | # 注意:如果数据集不变的情况修改了模型想继续实验(这应该是调模型的大部分情况),在设置ckpt保存路径为'cls_ckpt_2'后,可设置参数 171 | # save_data_prefix='cls_ckpt_1'。表示使用前一次实验处理的已有数据,以节省时间。如下: 172 | # 修改了模型后的再次实验: 173 | rm_cls.train('cls_ckpt_2', '../data/cls_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512, save_data_prefix='cls_ckpt_1') 174 | ``` 175 | 具体更多细节可自行查阅代码,相信你能很快理解并根据自己的需求进行修改以适配自有数据 :) 176 | 177 | 这里有个彩蛋:第一次运行自有数据会不成功,需要对应修改```model_*.py```中conf与字典大小相关的参数,详情请参考下文:字典生成中的[提醒](#1生成词表字典) 178 | 179 | #### 3.加载预训练模型 180 | 181 | 默认使用了谷歌官方中文bert-base预训练模型,[下载](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) 并将对应模型文件放入当前工作目录的以下文件中: 182 | ```bash 183 | common/modules/bert/chinese_L-12_H-768_A-12 184 | ├── bert_model.ckpt.data-00000-of-00001 # 自行下载 185 | ├── bert_model.ckpt.index # 自行下载 186 | ├── bert_model.ckpt.meta # 自行下载 187 | ├── bert_config.json # 框架已提供 188 | └── vocab.txt # 框架已提供 189 | ``` 190 | model中则是通过该路径构造bert模型,以cls_model.py中的bert类模型为例: 191 | ```python 192 | bert_model_dir = f'{curr_dir}/../common/modules/bert/chinese_L-12_H-768_A-12' 193 | ``` 194 | 195 | ### 框架设计思路 196 | ——只有大体了解了框架设计才能更自由地进行适配 :) 197 | 198 | 一个NLP任务可以分为:数据选取及处理、模型输入输出设计、模型结构设计、训练及评估、推断、部署。 199 | 由此抽象出一些代码模块: 200 | * run模块负责维护以下实例对象及方法: 201 | * TF中的sess/graph/config/saver 202 | * 分词方式、字典实例 203 | * 分布式训练(hvd)相关设置 204 | * 一个model实例 205 | * 模型的保存与恢复方法 206 | * 训练、评估方法 207 | * 基本的推断方法 208 | * 原始数据处理方法(数据集切分、分词、构造字典) 209 | 210 | 211 | * model模块负责维护: 212 | * 输入输出设计 213 | * 模型结构设计 214 | * 输入输出签名暴露的接口 215 | * 从pb/meta恢复模型结构方法 216 | 217 | model之间的主要区别是维护了自己特有的输入输出(即tf.placeholder的设计),故有以下实践建议: 218 | 219 | **何时不需要新建model?** 220 | 原则上只要输入输出不变,只有模型结构改变,则直接在原有model中增加新模型代码并在初始化时选择新的模型结构。 221 | 这种情况经常出现在针对某个任务的结构微调及试错过程中。 222 | 223 | **何时需要新建model?** 224 | 当输入输出有调整。 225 | 例如想额外考虑词性特征,输入中要加入词性信息字段。或者要解决一个全新的任务。 226 | 227 | **model与run的关系?** 228 | 一般一个model对应一个专有run。新建model后则应新建一个相应run。 229 | 原因主要考虑到run的训练评估过程需要与model的输入输出对齐。同时,model的不同输入可能也依赖于run进行特别的数据处理(如分词还是分字,词表大小,unk特殊规则等) 230 | 231 | **model与run有哪些数据交互?** 232 | 不同任务的主要区别包括如何对文本进行向量化(即token转id),需要设计分词、字典、如何生成向量、如何对齐到网络的placeholder。 233 | 这里让model负责该向量化方法,run会将自己的分词、字典传过去。并且该向量化方法会被其它许多地方调用。举例: 234 | * 生成向量方式会被应用于run的数据处理(生成tfrecord或原生py的pkl数据),以及对原始数据进行推断时的预处理 235 | * 生成向量后对齐到placeholder的方式则会被应用在run的训练及推断。 236 | 237 | **为何使用bert类模型时输入改变了但不必新增model与run?** 238 | bert类模型的输入有额外的\[CLS]\\\[SEP]等特殊符号,但本质上是模型层面的输入适配而不是任务数据层面的改变,故直接在原有model中重写与输入有关的函数,包装一层处理成bert输入的方法即可。 239 | 240 | ### 公共模块 241 | qiznlp包的公共模块文件如下: 242 | (因为是基本不需更改的基础模块,故没有在```qiznlp_init```命令初始化随其他文件一起复制到当前目录, 通过```qiznlp.common.*```调用) 243 | ```bash 244 | └── common 245 | ├── modules # 封装少量常用模型及各种TF的神经网络常用操作函数。 246 | ├── tfrecord_utils.py # 封装TFRecord数据保存及读取操作。 247 | ├── train_helper.py # 封装train过程中对数据的处理及其它相关方法。 248 | └── utils.py # 
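    # typical steps (see the preprocess_raw_data examples in run_*.py):
    # read raw file -> filter -> tokenize -> build missing vocab dicts -> split into train/dev/test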
基础类,个人对python中一些数据IO的简单封装,框架的许多IO操作调用了里面封装的方法,建议详细看看。 249 | ``` 250 | 251 | ### 修改适配需关注点 252 | ##### 1、生成词表字典 253 | ```utils.Any2id```类封装了字典相关的功能,可通过传入文件进行初始化,在run中示例如下 254 | ```python 255 | token2id_dct = { 256 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiaoword2id.dct', use_line_no=True), 257 | } 258 | # use_line_no参数表示直接使用字典文件中的行号作为id 259 | ``` 260 | 如果传入的字典文件为空,则需要在run的数据处理函数中进行字典的构建,并保存到文件,方便下次直接读取 261 | ```python 262 | # 在迭代处理数据时循环调用 263 | token2id_dct['word2id'].to_count(cuted_sentent.split(' ')) # 迭代统计token信息,句子已分词,空格分隔 264 | # 结束迭代后构造字典 265 | # 参数包括 预留词、最小词频、最大词表大小 266 | token2id_dct['word2id'].rebuild_by_counter(restrict=['', ''], min_freq=1, max_vocab_size=20000) 267 | token2id_dct['word2id'].save(f'{curr_dir}/../data/toutiaoword2id.dct') # 保存到文件 268 | ``` 269 | **注意-1**: 在某个任务切换跑自有数据与公共数据集时,记得切换token2id_dct的字典文件名,如下: 270 | ```python 271 | self.token2id_dct = { 272 | # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/cls_word2id.dct', use_line_no=True), # 自有数据 273 | # 'label2id': utils.Any2Id.from_file(f'{curr_dir}/../data/cls_label2id.dct', use_line_no=True), # 自有数据 274 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_word2id.dct', use_line_no=True), # toutiao新闻 275 | 'label2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_label2id.dct', use_line_no=True), # toutiao新闻 276 | } 277 | ``` 278 | **注意-2**: 对自有数据进行训练时,由于模型的初始化比训练数据的处理更早,所以model源码中conf的相关参数(如:vocab size/label size等)只能先随意指定。之后等对训练数据的处理(分词、构造字典)完毕后才确认这些参数。 279 | 目前解决方式是运行两次run:第一次运行构造字典完毕后,会检查字典大小与model源码的vocab size等相关参数是否一致,不一致则自动更新model源码,请根据提示再次运行即可。如下: 280 | ```bash 281 | some param should be update: 282 | vocab_size => param: 40000 != dict: 4500 283 | update vocab_size success 284 | script will exit! please run the script again! e.g. python run_***.py 285 | ``` 286 | (当然也可以使用自己预定义的字典文件,然后在model源码conf中设置正确的相关参数后直接运行run) 287 | 288 | ##### 2、数据处理相关 289 | ```preprocess_raw_data```返回训练、验证、测试数据的元组: 290 | ``` 291 | def preprocess_raw_data(): 292 | # ... 
293 | return train_items, dev_items, test_items # 分别对应训练/验证/测试 294 | 295 | # 其中验证和测试可为None,此时模型将不进行相应验证或测试,如: 296 | # return train_items, dev_items, None # 不进行测试 297 | ``` 298 | 299 | ##### 3、run和model的conf参数 300 | run中的conf示例如下: 301 | ``` 302 | # run_cls.py 303 | conf = utils.dict2obj({ 304 | 'early_stop_patience': None, # 根据指标是否早停 305 | 'just_save_best': True, # 仅保存指标最好的模型(减少磁盘空间占用) 306 | 'n_epochs': 20, # 训练轮数 307 | 'data_type': 'tfrecord', # 训练数据处理成TFRecord 308 | # 'data_type': 'pkldata', # 训练数据处理成py原生数据 309 | }) 310 | # 前两者的具体指标通过修改train时相关方法的参数确定 311 | ``` 312 | model中conf示例如下: 313 | ``` 314 | # model_cls.py 315 | conf = utils.dict2obj({ 316 | 'vocab_size'"': 14180, # 词表大小,也就是上文所述构建字典后需注意对齐的参数 317 | 'label_size': 16, # 类别数量,需注意对齐 318 | 'embed_size': 300, 319 | 'hidden_size': 300, 320 | 'num_heads': 6, 321 | 'num_encoder_layers': 6, 322 | 'dropout_rate': 0.2, 323 | 'lr': 1e-3, 324 | 'pretrain_emb': None, # 不使用预训练词向量 325 | # 'pretrain_emb': np.load(f'{curr_dir}/pretrain_word_emb300.npy'), # 使用预训练词向量(np格式)[vocab_size,embed_size] 326 | }) 327 | ``` 328 | 具体参数可根据个人任务情况进行增删改。 329 | ##### 4、使用分布式 330 | 框架提供的分布式功能基于horovod(使用一种同步数据并行策略),即将batch数据分为多个小batch,分配到多机或多卡来训练。 331 | 332 | 前提:```pip install horovod``` 333 | 限制:只能用TFRecord数据格式(因为需利用其提供的分片shard功能)。但生成TFRecord的过程不方便多个worker并行,故实践建议分两次运行,第一次采用非分布式正常运行生成数据、字典并能跑通训练,第二次运行才进行分布式训练 334 | 操作步骤:先按照正常方式运行一遍,以```run_cls.py```为例,终端运行: 335 | ``` 336 | python run_cls.py 337 | ``` 338 | 其中run的初始化为: 339 | ``` 340 | rm_cls = Run_Model_Cls('trans_mhattnpool') 341 | rm_cls.train('cls_ckpt_taskname', 'raw_data_file', preprocess_raw_data=preprocess_raw_data_fn, batch_size=batch_size) # train 342 | ``` 343 | 等通过终端日志确定已生成完TFRecord数据后,并且字典相关大小也和model对齐之后,```ctrl+c```退出 344 | 继而修改初始化参数```use_hvd=True```: 345 | ``` 346 | rm_cls = Run_Model_Cls('trans_mhattnpool', use_hvd=True) 347 | rm_cls.train('cls_ckpt_taskname', 'raw_data_file', preprocess_raw_data=preprocess_raw_data_fn, batch_size=batch_size) # train 348 | ``` 349 | 并在终端按照horovod要求的格式运行命令: 350 | ``` 351 | horovodrun -np 2 -H localhost:2 python run_cls.py 352 | # -np 2 代表总的worker数量为2 353 | # -H localhost:2 代表使用本机的2块GPU 354 | # 注意此时需要在代码中事先设置好正确数量的可见GPU,如:os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 355 | ``` 356 | 提醒:分布式训练中```train()```指定的```batch_size```参数即为有效(真实)batch size,内部会将batch切分为对应每个机或卡的小batch。故分布式训练实践中可在```train()```中直接指定较大的```batch_size```。 357 | ### 类图 358 | 附上主要的类设计图说明 359 | ![main_class_diagram](main_class_diagram.png) 360 | ## TODO 361 | * 完善对model和run模块的单独说明 362 | * 完善对公共module模块相关说明 363 | * 完善对deploy模块相关说明 364 | * 继续增加各任务默认模型(尤其预训练模型),各任务数据集 365 | * 继续完善框架,保持灵活性的同时尽量增加易用性 366 | * README.md英文化 367 | * 增加更多其他任务(如~~多轮检索和生成~~、MRC、few-shot-learning等) 368 | 369 | ## 参考 370 | * [tensor2tensor](https://github.com/tensorflow/tensor2tensor) 371 | * [bert](https://github.com/google-research/bert) 372 | 373 | ## License 374 | Mozilla Public License 2.0 (MPL 2.0) 375 | 376 | ## 后记 377 | 框架形成历程: 378 | 最早是在研究T2T官方transformer时,将transformer相关代码抽取独立出来,方便其他任务。 379 | 之后增加了自己优化的S2S的beam_search代码(支持一些多样性方法),以及总结了TF模型的导出部署代码。 380 | 后续在解决各种任务类型时,考虑着代码复用,不断重构,追求设计方案的灵活,最终得到现版本。 381 | 382 | 深知目前本项目仍有许多可改进的地方,欢迎issue和PR,也希望感兴趣的人能一起加入来改进! 383 | 如觉得本项目有用,感谢您的star~~ 384 | 385 | -------------------------------------------------------------------------------- /qiznlp/common/modules/bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
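    # e.g. the call below turns "NLP难不难" into "NLP 难  不  难 ", so whitespace_tokenize then yields one token per Chinese character.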
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer`. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `char` is a whitespace character.""" 364 | # \t, \n, and \r are technically control characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `char` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `char` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency.
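# Illustrative examples (not in the upstream BERT file): "$" (ord 36) and "^"
# (ord 94) fall inside the ASCII ranges tested below even though their Unicode
# categories are "Sc" and "Sk", while a CJK full stop "。" is caught by the
# category check since "Po" starts with "P".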
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /qiznlp/model/mch_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | from qiznlp.common.modules.common_layers import mask_nonpad_from_embedding 8 | from qiznlp.common.modules.embedding import embedding 9 | from qiznlp.common.modules.birnn import Bi_RNN 10 | import qiznlp.common.utils as utils 11 | 12 | conf = utils.dict2obj({ 13 | 'vocab_size': 1142, 14 | 'embed_size': 300, 15 | 'birnn_hidden_size': 300, 16 | 'l2_reg': 0.0001, 17 | 'dropout_rate': 0.2, 18 | 'lr': 1e-3, 19 | 'pretrain_emb': None, 20 | }) 21 | 22 | 23 | class Model(object): 24 | def __init__(self, build_graph=True, **kwargs): 25 | self.conf = conf 26 | self.run_model = kwargs.get('run_model', None) # acquire outside run_model instance 27 | if build_graph: 28 | # build placeholder 29 | self.build_placeholder() 30 | # build model 31 | self.model_name = kwargs.get('model_name', 'esim') 32 | { 33 | 'esim': self.build_model1, 34 | # add new here 35 | }[self.model_name]() 36 | print(f'model_name: {self.model_name} build graph ok!') 37 | 38 | def build_placeholder(self): 39 | # placeholder 40 | # 原则上模型输入输出不变,不需换新model 41 | self.s1 = tf.placeholder(tf.int32, [None, None], name='s1') 42 | self.s2 = tf.placeholder(tf.int32, [None, None], name='s2') 43 | self.target = tf.placeholder(tf.int32, [None], name="target") 44 | self.dropout_rate = tf.placeholder(tf.float32, name="dropout_rate") 45 | 46 | def build_model1(self): 47 | # embedding 48 | # [batch,len,embed] 49 | s1_embed, _ = embedding(tf.expand_dims(self.s1, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb) 50 | s2_embed, _ = embedding(tf.expand_dims(self.s2, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb) 51 | 52 | s1_input_mask = mask_nonpad_from_embedding(s1_embed) # [batch,len1] 1 for nonpad; 0 for pad 53 | s2_input_mask = mask_nonpad_from_embedding(s2_embed) # [batch,len2] 1 for nonpad; 0 for pad 54 | s1_seq_len = tf.cast(tf.reduce_sum(s1_input_mask, axis=-1), tf.int32) # [batch] 55 | s2_seq_len = tf.cast(tf.reduce_sum(s2_input_mask, axis=-1), tf.int32) # [batch] 56 | 57 | # bilstm sent encoder 58 | self.bilstm_encoder1 = Bi_RNN(cell_name='LSTMCell', hidden_size=conf.birnn_hidden_size, dropout_rate=self.dropout_rate) 59 | s1_bar, _ = self.bilstm_encoder1(s1_embed, s1_seq_len) # [batch,len1,2hid] 60 | s2_bar, _ = self.bilstm_encoder1(s2_embed, s2_seq_len) # [batch,len2,2hid] 61 | 62 | # local inference 局部推理 63 | with tf.variable_scope('local_inference'): 64 | # 点积注意力 65 | attention_logits = tf.matmul(s1_bar, tf.transpose(s2_bar, [0, 2, 1])) # [batch,len1,len2] 66 | 67 | # 注意需attention mask pad_mask * -inf + logits 68 | attention_s1 = tf.nn.softmax(attention_logits + tf.expand_dims((1. - s2_input_mask) * -1e9, 1)) # [batch,len1,len2] 69 | 70 | attention_s2 = tf.nn.softmax(tf.transpose(attention_logits, [0, 2, 1]) + tf.expand_dims((1. 
- s1_input_mask) * -1e9, 1)) # [batch,len2,len1] 71 | 72 | s1_hat = tf.matmul(attention_s1, s2_bar) # [batch,len1,2hid] 73 | s2_hat = tf.matmul(attention_s2, s1_bar) # [batch,len2,2hid] 74 | 75 | s1_diff = s1_bar - s1_hat 76 | s1_mul = s1_bar * s1_hat 77 | 78 | s2_diff = s2_bar - s2_hat 79 | s2_mul = s2_bar * s2_hat 80 | 81 | m_s1 = tf.concat([s1_bar, s1_hat, s1_diff, s1_mul], axis=2) # [batch,len1,8hid] 82 | m_s2 = tf.concat([s2_bar, s2_hat, s2_diff, s2_mul], axis=2) # [batch,len2,8hid] 83 | 84 | # composition 推理组成 85 | with tf.variable_scope('composition'): 86 | self.bilstm_encoder2 = Bi_RNN(cell_name='LSTMCell', hidden_size=conf.birnn_hidden_size, dropout_rate=self.dropout_rate) 87 | v_s1, _ = self.bilstm_encoder2(m_s1, s1_seq_len) # [batch,len1,2hid] 88 | v_s2, _ = self.bilstm_encoder2(m_s2, s2_seq_len) # [batch,len2,2hid] 89 | 90 | # average pooling # 需将pad的vector变为0 91 | v_s1 = v_s1 * tf.expand_dims(s1_input_mask, -1) # [batch,len1,2hid] 92 | v_s2 = v_s2 * tf.expand_dims(s2_input_mask, -1) # [batch,len1,2hid] 93 | v_s1_avg = tf.reduce_sum(v_s1, axis=1) / tf.cast(tf.expand_dims(s1_seq_len, -1), tf.float32) # [batch,2hid] 94 | v_s2_avg = tf.reduce_sum(v_s2, axis=1) / tf.cast(tf.expand_dims(s2_seq_len, -1), tf.float32) # [batch,2hid] 95 | 96 | # max pooling # 需将pad的vector变为极小值 97 | v_s1_max = tf.reduce_max(v_s1 + tf.expand_dims((1. - s1_input_mask) * -1e9, -1), axis=1) # [batch,2hid] 98 | v_s2_max = tf.reduce_max(v_s2 + tf.expand_dims((1. - s2_input_mask) * -1e9, -1), axis=1) # [batch,2hid] 99 | 100 | v = tf.concat([v_s1_avg, v_s1_max, v_s2_avg, v_s2_max], axis=-1) # [batch,8hid] 101 | 102 | with tf.variable_scope('ffn'): 103 | h_ = tf.layers.dropout(v, rate=self.dropout_rate) 104 | h = tf.layers.dense(h_, 256, activation=tf.nn.relu, kernel_initializer=tf.random_normal_initializer(0.0, 0.1)) # [batch,256] 105 | o_ = tf.layers.dropout(h, rate=self.dropout_rate) 106 | o = tf.layers.dense(o_, 1, kernel_initializer=tf.random_normal_initializer(0.0, 0.1)) # [batch,1] 107 | self.logits = tf.squeeze(o, -1) # [batch] 108 | 109 | # loss 110 | with tf.name_scope('loss'): 111 | loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self.target, tf.float32), logits=self.logits) # [batch] 112 | loss = tf.reduce_mean(loss, -1) # scalar 113 | 114 | l2_reg = conf.l2_reg 115 | weights = [v for v in tf.trainable_variables() if ('w' in v.name) or ('kernel' in v.name)] 116 | l2_loss = tf.add_n([tf.nn.l2_loss(w) for w in weights]) * l2_reg 117 | loss += l2_loss 118 | 119 | self.loss = loss 120 | self.y_prob = tf.nn.sigmoid(self.logits) 121 | self.y_prob = tf.identity(self.y_prob, name='y_prob') 122 | 123 | with tf.name_scope("accuracy"): 124 | self.correct = tf.equal( 125 | tf.cast(tf.greater_equal(self.y_prob, 0.5), tf.int32), 126 | self.target) 127 | self.accuracy = tf.reduce_mean(tf.cast(self.correct, tf.float32)) 128 | 129 | self.global_step = tf.train.get_or_create_global_step() 130 | self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr) 131 | self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step) 132 | 133 | @classmethod 134 | def sent2ids(cls, sent, word2id, max_word_len=None): 135 | # sent 已分好词 ' '隔开 136 | # 形成batch时才动态补齐长度 137 | words = sent.split(' ') 138 | token_ids = [word2id.get(word, word2id['']) for word in words] 139 | if max_word_len: 140 | token_ids = token_ids[:max_word_len] 141 | # token_ids = utils.pad_sequences([token_ids], padding='post', maxlen=max_word_len)[0] 142 | return token_ids # [len] 143 | 144 | def create_feed_dict_from_data(self, 
data, ids, mode='train'): 145 | # data:数据已经转为id, data不同字段保存该段字段全量数据 146 | batch_s1 = [data['s1'][i] for i in ids] 147 | batch_s2 = [data['s2'][i] for i in ids] 148 | if len(set([len(e) for e in batch_s1])) != 1: # 长度不等 149 | batch_s1 = utils.pad_sequences(batch_s1, padding='post') 150 | if len(set([len(e) for e in batch_s2])) != 1: # 长度不等 151 | batch_s2 = utils.pad_sequences(batch_s2, padding='post') 152 | feed_dict = { 153 | self.s1: batch_s1, 154 | self.s2: batch_s2, 155 | self.target: [data['target'][i] for i in ids], 156 | } 157 | if mode == 'train': feed_dict['num'] = len(batch_s1) 158 | feed_dict[self.dropout_rate] = conf.dropout_rate if mode == 'train' else 0. 159 | return feed_dict 160 | 161 | def create_feed_dict_from_features(self, features, mode='train'): 162 | # feature:tfrecord数据的example, 每个features的不同字段包括该字段一个batch数据 163 | feed_dict = { 164 | self.s1: features['s1'], 165 | self.s2: features['s2'], 166 | self.target: features['target'], 167 | } 168 | if mode == 'train': feed_dict['num'] = len(features['s1']) 169 | feed_dict[self.dropout_rate] = conf.dropout_rate if mode == 'train' else 0. 170 | return feed_dict 171 | 172 | def create_feed_dict_from_raw(self, batch_s1, batch_s2, batch_y, token2id_dct, mode='infer'): 173 | word2id = token2id_dct['word2id'] 174 | 175 | feed_s1 = [self.sent2ids(s1, word2id) for s1 in batch_s1] 176 | feed_s2 = [self.sent2ids(s2, word2id) for s2 in batch_s2] 177 | 178 | feed_dict = { 179 | self.s1: utils.pad_sequences(feed_s1, padding='post'), 180 | self.s2: utils.pad_sequences(feed_s2, padding='post'), 181 | } 182 | feed_dict[self.dropout_rate] = conf.dropout_rate if mode == 'train' else 0. 183 | 184 | if mode == 'infer': 185 | return feed_dict 186 | 187 | if mode in ['train', 'dev']: 188 | assert batch_y, 'batch_y should not be None when mode is train or dev' 189 | feed_dict[self.target] = batch_y 190 | return feed_dict 191 | 192 | raise ValueError(f'mode type {mode} not support') 193 | 194 | @classmethod 195 | def generate_data(cls, file, token2id_dct): 196 | word2id = token2id_dct['word2id'] 197 | data = { 198 | 's1': [], 199 | 's2': [], 200 | 'target': [] 201 | } 202 | with open(file, 'r', encoding='U8') as f: 203 | for i, line in enumerate(f): 204 | item = line.strip().split('\t') 205 | if len(item) != 3: 206 | print('error', repr(line)) 207 | continue 208 | s1 = item[0] 209 | s2 = item[1] 210 | y = item[2] 211 | s1_ids = cls.sent2ids(s1, word2id, max_word_len=50) 212 | s2_ids = cls.sent2ids(s2, word2id, max_word_len=50) 213 | y_id = int(y) 214 | if i < 5: # check 215 | print(f'check {i}:') 216 | print(f'{s1} -> {s1_ids}') 217 | print(f'{s2} -> {s2_ids}') 218 | print(f'{y} -> {y_id}') 219 | data['s1'].append(s1_ids) 220 | data['s2'].append(s2_ids) 221 | data['target'].append(y_id) 222 | data['num_data'] = len(data['s1']) 223 | return data 224 | 225 | @classmethod 226 | def generate_tfrecord(cls, file, token2id_dct, tfrecord_file): 227 | from qiznlp.common.tfrecord_utils import items2tfrecord 228 | word2id = token2id_dct['word2id'] 229 | 230 | def items_gen(): 231 | with open(file, 'r', encoding='U8') as f: 232 | for i, line in enumerate(f): 233 | item = line.strip().split('\t') 234 | if len(item) != 3: 235 | print('error', repr(line)) 236 | continue 237 | try: 238 | s1 = item[0] 239 | s2 = item[1] 240 | y = item[2] 241 | s1_ids = cls.sent2ids(s1, word2id, max_word_len=50) 242 | s2_ids = cls.sent2ids(s2, word2id, max_word_len=50) 243 | y_id = int(y) 244 | if i < 5: # check 245 | print(f'check {i}:') 246 | print(f'{s1} -> {s1_ids}') 247 | 
print(f'{s2} -> {s2_ids}') 248 | print(f'{y} -> {y_id}') 249 | d = { 250 | 's1': s1_ids, 251 | 's2': s2_ids, 252 | 'target': y_id, 253 | } 254 | yield d 255 | except Exception as e: 256 | print('Exception occur in items_gen()!\n', e) 257 | continue 258 | 259 | count = items2tfrecord(items_gen(), tfrecord_file) 260 | return count 261 | 262 | @classmethod 263 | def load_tfrecord(cls, tfrecord_file, batch_size=128, index=None, shard=None): 264 | from qiznlp.common.tfrecord_utils import tfrecord2dataset 265 | if not os.path.exists(tfrecord_file): 266 | return None, None 267 | feat_dct = { 268 | # 's1': tf.FixedLenFeature([50], tf.int64), 269 | 's1': tf.VarLenFeature(tf.int64), 270 | 's2': tf.VarLenFeature(tf.int64), 271 | 'target': tf.FixedLenFeature([], tf.int64), 272 | } 273 | dataset, count = tfrecord2dataset(tfrecord_file, feat_dct, batch_size=batch_size, auto_pad=True, index=index, shard=shard) 274 | return dataset, count 275 | 276 | def get_signature_export_model(self): 277 | inputs_dct = { 278 | 's1': self.s1, 279 | 's2': self.s2, 280 | 'dropout_rate': self.dropout_rate, 281 | } 282 | outputs_dct = { 283 | 'y_prob': self.y_prob, 284 | } 285 | return inputs_dct, outputs_dct 286 | 287 | @classmethod 288 | def get_signature_load_pbmodel(cls): 289 | inputs_lst = ['s1', 's2', 'dropout_rate'] 290 | outputs_lst = ['y_prob'] 291 | return inputs_lst, outputs_lst 292 | 293 | @classmethod 294 | def from_pbmodel(cls, pbmodel_dir, sess): 295 | meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], pbmodel_dir) # 从pb模型载入graph和variable,绑定到sess 296 | signature = meta_graph_def.signature_def # 取出signature_def 297 | default_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # or 'chitchat_predict' 签名 298 | inputs_lst, output_lst = cls.get_signature_load_pbmodel() # 类方法 299 | pb_dict = {} 300 | for k in inputs_lst: 301 | pb_dict[k] = sess.graph.get_tensor_by_name(signature[default_key].inputs[k].name) # 从signature中获取输入输出的tensor_name,并从graph中取出 302 | for k in output_lst: 303 | pb_dict[k] = sess.graph.get_tensor_by_name(signature[default_key].outputs[k].name) 304 | model = cls(build_graph=False) # 里面不再构造图 305 | for k, v in pb_dict.items(): 306 | setattr(model, k, v) # 绑定必要的输入输出到实例 307 | return model 308 | 309 | @classmethod 310 | def from_ckpt_meta(cls, ckpt_name, sess, graph): 311 | with graph.as_default(): 312 | saver = tf.train.import_meta_graph(ckpt_name + '.meta', clear_devices=True) 313 | sess.run(tf.global_variables_initializer()) 314 | 315 | model = cls(build_graph=False) # 里面不再构造图 316 | # 绑定必要的输入输出到实例 317 | model.s1 = graph.get_tensor_by_name('s1:0') 318 | model.s2 = graph.get_tensor_by_name('s2:0') 319 | model.dropout_rate = graph.get_tensor_by_name('dropout_rate:0') 320 | # self.target = self.graph.get_tensor_by_name('target:0') 321 | model.y_prob = graph.get_tensor_by_name('y_prob:0') 322 | 323 | saver.restore(sess, ckpt_name) 324 | print(f':: restore success! 
{ckpt_name}') 325 | return model, saver 326 | -------------------------------------------------------------------------------- /qiznlp/run/run_base.py: -------------------------------------------------------------------------------- 1 | import os, pickle, glob, re 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | import qiznlp.common.utils as utils 6 | utils.suppress_tf_warning(tf) 7 | import qiznlp.common.train_helper as train_helper 8 | 9 | 10 | class Run_Model_Base(): 11 | def __init__(self): 12 | self.model_name = None 13 | self.sess = None 14 | self.graph = None 15 | self.config = None 16 | self.saver = None 17 | self.token2id_dct = None 18 | self.tokenize = None 19 | self.cut = None 20 | self.use_hvd = None 21 | self.model = None 22 | 23 | def save(self, ckpt_dir, model_name=None, epo=None, global_step=None, info_str=''): 24 | if hasattr(self, 'hvd_rank') and self.hvd_rank != 0: # 分布式训练时只需1个进程master来保存 25 | return 26 | if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) 27 | if model_name is None: model_name = self.model_name 28 | if global_step is None: global_step = self.model.global_step 29 | epo = '' if epo is None else f'{epo}-' 30 | 31 | save_path = f'{ckpt_dir}/{model_name}-{epo}{info_str}.ckpt' 32 | 33 | exist_ckpt_path = glob.glob(f'{ckpt_dir}/{model_name}-{epo}*') 34 | if exist_ckpt_path: 35 | ckpt_path = exist_ckpt_path[0].rsplit('.', 1)[0] 36 | [utils.delete_file(file) for file in [ 37 | ckpt_path + '.index', 38 | ckpt_path + '.meta', 39 | ckpt_path + '.data-00000-of-00001' 40 | ]] 41 | 42 | self.saver.save(self.sess, save_path, global_step=global_step) 43 | # e.g. 44 | # trans-2-1.23-1.18.ckpt-228.index 45 | # trans-2-1.23-1.18.ckpt-228.meta 46 | print(f'>>>>>> save ckpt ok! {save_path}') 47 | 48 | def restore(self, ckpt_dir, model_name=None, epo=None, step=None): 49 | if model_name is None: model_name = self.model_name 50 | if epo is not None: 51 | restore_path = glob.glob(f'{ckpt_dir}/{model_name}-{epo}-*.ckpt*')[0].rsplit('.', 1)[0] 52 | elif step is not None: 53 | restore_path = glob.glob(f'{ckpt_dir}/{model_name}*.ckpt-{step}.*')[0].rsplit('.', 1)[0] 54 | else: 55 | restore_path = tf.train.latest_checkpoint(ckpt_dir) 56 | self.saver.restore(self.sess, restore_path) 57 | print(f'<<<<<< restoring ckpt from {restore_path}') 58 | 59 | def delete_ckpt(self, ckpt_dir, model_name=None, epo=None): 60 | if model_name is None: model_name = self.model_name 61 | if epo is None: # 删除所有的epo 62 | epo = '' 63 | exist_ckpt_path = glob.glob(f'{ckpt_dir}/{model_name}-{epo}*') 64 | if exist_ckpt_path: # 如果存在就删除 65 | ckpt_path = exist_ckpt_path[0].rsplit('.', 1)[0] 66 | res = [utils.delete_file(file, verbose=False) for file in [ 67 | ckpt_path + '.index', 68 | ckpt_path + '.meta', 69 | ckpt_path + '.data-00000-of-00001']] 70 | if all(res): 71 | print(f'------ delete ckpt ok! 
{ckpt_path}') 72 | 73 | def prepare_data(self, data_type, raw_data_file, 74 | preprocess_raw_data, batch_size, 75 | save_data_prefix, 76 | **kwargs): 77 | index = self.hvd_rank if hasattr(self, 'hvd_rank') else None 78 | shard = self.hvd_size if hasattr(self, 'hvd_size') else None 79 | 80 | if data_type == 'tfrecord': 81 | train_tfrecord_file, dev_tfrecord_file, test_tfrecord_file = train_helper.prepare_tfrecord( 82 | raw_data_file, self.model, self.token2id_dct, self.tokenize, 83 | preprocess_raw_data_fn=preprocess_raw_data, 84 | save_data_prefix=save_data_prefix, 85 | **kwargs, 86 | ) 87 | 88 | with self.graph.as_default(): 89 | # 如果载入不成功将返回None, None 90 | train_dataset, train_data_size = self.model.load_tfrecord(train_tfrecord_file, batch_size=batch_size, index=index, shard=shard) 91 | dev_dataset, dev_data_size = self.model.load_tfrecord(dev_tfrecord_file, batch_size=batch_size) 92 | test_dataset, test_data_size = self.model.load_tfrecord(test_tfrecord_file, batch_size=batch_size) 93 | 94 | # 获得迭代的tfrecord example Tensor 95 | train_features = train_dataset.make_one_shot_iterator().get_next() if train_dataset else None 96 | dev_features = dev_dataset.make_one_shot_iterator().get_next() if dev_dataset else None 97 | test_features = test_dataset.make_one_shot_iterator().get_next() if test_dataset else None 98 | 99 | def gen_feed_dict(i, epo, mode='train'): 100 | nonlocal train_features, dev_features, test_features 101 | if mode == 'train': 102 | assert train_features 103 | features = self.sess.run(train_features) 104 | if i == 0 and epo == 1: 105 | print('inspect tfrecord features: (show first two element)') 106 | for k, v in features.items(): 107 | print(f'{k}: {v.shape}{v.tolist()[:2]}') 108 | return self.model.create_feed_dict_from_features(features, 'train') 109 | elif mode == 'dev': 110 | assert dev_features 111 | features = self.sess.run(dev_features) 112 | return self.model.create_feed_dict_from_features(features, 'dev') 113 | elif mode == 'test': 114 | assert test_features 115 | features = self.sess.run(test_features) 116 | return self.model.create_feed_dict_from_features(features, 'test') 117 | else: 118 | raise Exception('unsupport mode type') 119 | 120 | 121 | elif data_type == 'pkldata': 122 | train_pkl_file, dev_pkl_file, test_pkl_file = train_helper.prepare_pkldata( 123 | raw_data_file, self.model, self.token2id_dct, self.tokenize, 124 | preprocess_raw_data_fn=preprocess_raw_data, 125 | save_data_prefix=save_data_prefix, 126 | **kwargs, 127 | ) 128 | train_data, dev_data, test_data, train_data_size, dev_data_size, test_data_size = (None,) * 6 129 | trn_total_ids, dev_total_ids, test_total_ids = (None,) * 3 130 | if os.path.exists(train_pkl_file): 131 | train_data = pickle.load(open(train_pkl_file, 'rb')) 132 | train_data_size = train_data['num_data'] 133 | print(f'loading exist train pkl file ok! {train_pkl_file}') 134 | if os.path.exists(dev_pkl_file): 135 | dev_data = pickle.load(open(dev_pkl_file, 'rb')) 136 | dev_data_size = dev_data['num_data'] 137 | print(f'loading exist dev pkl file ok! {dev_pkl_file}') 138 | dev_total_ids = list(range(dev_data_size)) # dev 按顺序就行 139 | if os.path.exists(test_pkl_file): 140 | test_data = pickle.load(open(test_pkl_file, 'rb')) 141 | test_data_size = test_data['num_data'] 142 | print(f'loading exist test pkl file ok! 
{test_pkl_file}') 143 | test_total_ids = list(range(test_data_size)) # test 按顺序就行 144 | 145 | curr_epo = -1 146 | 147 | def gen_feed_dict(i, epo, mode='train'): 148 | nonlocal curr_epo, trn_total_ids, dev_total_ids, test_total_ids, train_data, dev_data, test_data 149 | if mode == 'train': 150 | assert train_data 151 | if curr_epo != epo: # 换了一个epo了 152 | np.random.seed(epo) 153 | trn_total_ids = np.random.permutation(train_data_size) # 根据不同epo从新打乱数据 154 | curr_epo = epo 155 | if i == 0 and epo == 1: 156 | print('inspect pkl data:') 157 | for k in train_data: 158 | if k != 'num_data': 159 | v = [train_data[k][i] for i in trn_total_ids[:2]] 160 | print(f'{k}: {v}') 161 | ids = trn_total_ids[i * batch_size:(i + 1) * batch_size] 162 | return self.model.create_feed_dict_from_data(train_data, ids, 'train') 163 | elif mode == 'dev': 164 | assert dev_data 165 | dev_ids = dev_total_ids[i * batch_size:(i + 1) * batch_size] 166 | return self.model.create_feed_dict_from_data(dev_data, dev_ids, 'dev') 167 | elif mode == 'test': 168 | assert test_data 169 | test_ids = test_total_ids[i * batch_size:(i + 1) * batch_size] 170 | return self.model.create_feed_dict_from_data(test_data, test_ids, 'test') 171 | else: 172 | raise Exception('unsupport mode type') 173 | else: 174 | raise Exception('unsupport data type') 175 | 176 | train_epo_steps, dev_epo_steps, test_epo_steps = (None,) * 3 177 | if train_data_size: 178 | train_epo_steps = (train_data_size - 1) // batch_size + 1 179 | if dev_data_size: 180 | dev_epo_steps = (dev_data_size - 1) // batch_size + 1 181 | if test_data_size: 182 | test_epo_steps = (test_data_size - 1) // batch_size + 1 183 | 184 | print('\nTraining Data INFO') 185 | print('batch_size:', batch_size) 186 | print('train_data_size:', train_data_size) 187 | print('train_epo_steps:', train_epo_steps) 188 | print('dev_data_size:', dev_data_size) 189 | print('dev_epo_steps:', dev_epo_steps) 190 | print('test_data_size:', test_data_size) 191 | print('test_epo_steps:', test_epo_steps) 192 | print('') 193 | 194 | return train_epo_steps, dev_epo_steps, test_epo_steps, gen_feed_dict 195 | 196 | def stop_training(self, early_stop_patience, train_info, indicator='dev_acc', greater_is_better=True): 197 | # e.g. 
patience=3 第2、3、4epo都低于第1epo, 则停止 198 | if not train_info: 199 | return False 200 | assert indicator in list(train_info.values())[0], f'indicator {indicator} not in train_info' 201 | patience = early_stop_patience 202 | epo_list = sorted(train_info.keys()) 203 | if len(epo_list) <= patience: 204 | return False 205 | pivot = epo_list[-(patience + 1)] 206 | flags = [] 207 | for e in epo_list[-patience:]: 208 | if greater_is_better: # 指标越大越好,如准确率 209 | flags.append(train_info[pivot][indicator] >= train_info[e][indicator]) # 第1轮比第2、3、4指标都好则停止 210 | else: # 指标越小越好,如loss 211 | flags.append(train_info[pivot][indicator] <= train_info[e][indicator]) # 第1轮比第2、3、4指标都好则停止 212 | return all(flags) 213 | 214 | def should_save(self, curr_epo, train_info, indicator='dev_acc', greater_is_better=True): 215 | if not train_info: 216 | return True 217 | if len(train_info) == 1: 218 | return True 219 | assert indicator in list(train_info.values())[0], f'indicator {indicator} not in train_info' 220 | indicator_lst = [train_info[e][indicator] for e in train_info if e != curr_epo] 221 | best_indicator = max(indicator_lst) if greater_is_better else min(indicator_lst) 222 | if greater_is_better: 223 | if train_info[curr_epo][indicator] > best_indicator: 224 | return True 225 | else: 226 | return False 227 | else: 228 | if train_info[curr_epo][indicator] < best_indicator: 229 | return True 230 | else: 231 | return False 232 | 233 | def get_best_epo(self, train_info, indicator='dev_acc', greater_is_better=True): 234 | if not train_info: 235 | return 0 236 | if len(train_info) == 1: 237 | return 1 238 | assert indicator in list(train_info.values())[0], f'indicator {indicator} not in train_info' 239 | indicator_lst = [[epo, v[indicator]] for epo, v in train_info.items()] 240 | indicator_lst.sort(key=lambda e: e[0], reverse=True) # 优先后面的epo 241 | indicator_lst.sort(key=lambda e: e[1], reverse=True if greater_is_better else False) 242 | return indicator_lst[0][0] 243 | 244 | def export_model(self, pbmodel_dir): 245 | with self.graph.as_default(): # 坑 246 | builder = tf.saved_model.builder.SavedModelBuilder(pbmodel_dir) 247 | 248 | inputs, outputs = self.model.get_signature_export_model() 249 | signature_def_inputs = {k: tf.saved_model.utils.build_tensor_info(v) for k, v in inputs.items()} 250 | signature_def_outputs = {k: tf.saved_model.utils.build_tensor_info(v) for k, v in outputs.items()} 251 | 252 | default_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # or 'chitchat_predict' 签名 253 | signature_def_map = {default_key: 254 | tf.saved_model.signature_def_utils.build_signature_def( 255 | inputs=signature_def_inputs, 256 | outputs=signature_def_outputs, 257 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME 258 | ) 259 | } 260 | builder.add_meta_graph_and_variables( 261 | self.sess, 262 | [tf.saved_model.tag_constants.SERVING], 263 | signature_def_map=signature_def_map 264 | ) 265 | builder.save(as_text=False) 266 | print(f'export pb model ok! {pbmodel_dir}') 267 | return pbmodel_dir 268 | 269 | 270 | def check_and_update_param_of_model_pyfile(param_dict, model_inst): 271 | # 字典大小自动对齐 272 | # param_dict = { 273 | # 'vocab_size': (param_value, dict_value), 274 | # 'label_size': (param_value, dict_value), 275 | # } 276 | param_check = {k: p_v == d_v for k, (p_v, d_v) in param_dict.items()} 277 | if not all(list(param_check.values())): 278 | # 获取要修改的model.py文件路径 e.g. 
**/**/cls_model.py 279 | module_package_str = type(model_inst).__module__ 280 | pyfile = __import__(module_package_str, fromlist=module_package_str.split('.')).__file__ 281 | print('some param should be update:') 282 | for param_name, check_success in param_check.items(): 283 | if not check_success: 284 | param_value, dict_value = param_dict[param_name] 285 | print(f'{param_name} => param: {param_value} != dict: {dict_value}') 286 | change_param_of_pyfile(pyfile, dict_value, param=param_name) 287 | print(f'update {param_name} success') 288 | print('script will exit! please run the script again! e.g. python run_***.py') 289 | exit(0) 290 | 291 | 292 | def change_param_of_pyfile(py_filename, value, param='vocab_size'): 293 | with open(py_filename, 'r', encoding='U8') as f: 294 | pycode_str = f.read() 295 | # print(repr(pycode_str)) 296 | changed_pycode_str = change_param(pycode_str, value, param) 297 | with open(py_filename, 'w', encoding='U8') as f: 298 | f.write(changed_pycode_str) 299 | 300 | 301 | def change_param(pycode_str, value, param): 302 | def _change_param(m): 303 | # print(repr(m.group(1))) 304 | # print(repr(m.group(2))) 305 | return f'{m.group(1)}{value}{m.group(2)}' 306 | 307 | # sample of pycode_str: "conf = utils.dict2obj({\\n 'vocab_size': 123,\\n 'label_size': 321,\\n" 308 | changed_pycode_str = re.sub(r'([\'\"]' + param + r'[\'\"]:\s*)\d+(,\s*)', _change_param, pycode_str) 309 | return changed_pycode_str -------------------------------------------------------------------------------- /qiznlp/run/run_s2l.py: -------------------------------------------------------------------------------- 1 | import os, sys, re 2 | import time 3 | import jieba 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import qiznlp.common.utils as utils 8 | from qiznlp.run.run_base import Run_Model_Base, check_and_update_param_of_model_pyfile 9 | 10 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 11 | sys.path.append(curr_dir + '/..') # 添加上级目录即默认qiznlp根目录 12 | from model.s2l_model import Model as S2L_Model 13 | 14 | try: 15 | import horovod.tensorflow as hvd 16 | # 示例:horovodrun -np 2 -H localhost:2 python run_s2l.py 17 | except: 18 | HVD_ENABLE = False 19 | else: 20 | HVD_ENABLE = True 21 | 22 | conf = utils.dict2obj({ 23 | 'early_stop_patience': None, 24 | 'just_save_best': True, 25 | 'n_epochs': 10, 26 | 'data_type': 'tfrecord', 27 | # 'data_type': 'pkldata', 28 | }) 29 | 30 | 31 | class Run_Model_S2L(Run_Model_Base): 32 | def __init__(self, model_name, tokenize=None, pbmodel_dir=None, use_hvd=False): 33 | # 维护sess graph config saver 34 | self.model_name = model_name 35 | if tokenize is None: 36 | self.jieba = jieba.Tokenizer() 37 | # self.jieba.load_userdict(f'{curr_dir}/segword.dct') 38 | self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t)) 39 | else: 40 | self.tokenize = tokenize 41 | self.cut = lambda t: ' '.join(self.tokenize(t)) 42 | self.token2id_dct = { 43 | # 'char2id': utils.Any2Id.from_file(f'{curr_dir}/../data/s2l_char2id.dct', use_line_no=True), # 自有数据 44 | # 'bmeo2id': utils.Any2Id.from_file(f'{curr_dir}/../data/s2l_bmeo2id.dct', use_line_no=True), # 自有数据 45 | 'char2id': utils.Any2Id.from_file(f'{curr_dir}/../data/rner_s2l_char2id.dct', use_line_no=True), # ResumeNER 46 | 'bmeo2id': utils.Any2Id.from_file(f'{curr_dir}/../data/rner_s2l_bmeo2id.dct', use_line_no=True), # ResumeNER 47 | } 48 | self.config = tf.ConfigProto(allow_soft_placement=True, 49 | gpu_options=tf.GPUOptions(allow_growth=True), 50 | ) 51 | self.use_hvd = use_hvd if HVD_ENABLE else 
False 52 | if self.use_hvd: 53 | hvd.init() 54 | self.hvd_rank = hvd.rank() 55 | self.hvd_size = hvd.size() 56 | self.config.gpu_options.visible_device_list = str(hvd.local_rank()) 57 | self.graph = tf.Graph() 58 | self.sess = tf.Session(config=self.config, graph=self.graph) 59 | 60 | if pbmodel_dir is not None: # 只能做predict 61 | self.model = S2L_Model.from_pbmodel(pbmodel_dir, self.sess) 62 | else: 63 | with self.graph.as_default(): 64 | self.model = S2L_Model(model_name=self.model_name, run_model=self) 65 | if self.use_hvd: 66 | self.model.optimizer._lr = self.model.optimizer._lr * self.hvd_size # 分布式训练大batch增大学习率 67 | self.model.hvd_optimizer = hvd.DistributedOptimizer(self.model.optimizer) 68 | self.model.train_op = self.model.hvd_optimizer.minimize(self.model.loss, global_step=self.model.global_step) 69 | self.sess.run(tf.global_variables_initializer()) 70 | if self.use_hvd: 71 | self.sess.run(hvd.broadcast_global_variables(0)) 72 | 73 | with self.graph.as_default(): 74 | self.saver = tf.train.Saver(max_to_keep=100) # must in the graph context 75 | 76 | def train_step(self, feed_dict): 77 | _, step, loss, accuracy = self.sess.run([self.model.train_op, 78 | self.model.global_step, 79 | self.model.loss, 80 | self.model.accuracy, 81 | ], 82 | feed_dict=feed_dict) 83 | return step, loss, accuracy 84 | 85 | def eval_step(self, feed_dict): 86 | loss, accuracy = self.sess.run([self.model.loss, 87 | self.model.accuracy, 88 | ], 89 | feed_dict=feed_dict) 90 | return loss, accuracy 91 | 92 | def train(self, ckpt_dir, raw_data_file, preprocess_raw_data, batch_size=100, save_data_prefix=None): 93 | save_data_prefix = os.path.basename(ckpt_dir) if save_data_prefix is None else save_data_prefix 94 | train_epo_steps, dev_epo_steps, test_epo_steps, gen_feed_dict = self.prepare_data(conf.data_type, 95 | raw_data_file, 96 | preprocess_raw_data, 97 | batch_size, 98 | save_data_prefix=save_data_prefix, 99 | update_txt=False, 100 | ) 101 | self.is_master = True 102 | if hasattr(self, 'hvd_rank') and self.hvd_rank != 0: # 分布式训练且非master 103 | dev_epo_steps, test_epo_steps = None, None # 不进行验证和测试 104 | self.is_master = False 105 | 106 | # 字典大小自动对齐 107 | check_and_update_param_of_model_pyfile({ 108 | 'vocab_size': (self.model.conf.vocab_size, len(self.token2id_dct['char2id'])), 109 | 'label_size': (self.model.conf.label_size, len(self.token2id_dct['bmeo2id'])), 110 | }, self.model) 111 | 112 | train_info = {} 113 | for epo in range(1, 1 + conf.n_epochs): 114 | train_info[epo] = {} 115 | 116 | # train 117 | time0 = time.time() 118 | epo_num_example = 0 119 | trn_epo_loss = [] 120 | trn_epo_acc = [] 121 | for i in range(train_epo_steps): 122 | feed_dict = gen_feed_dict(i, epo, 'train') 123 | epo_num_example += feed_dict.pop('num') 124 | 125 | step_start_time = time.time() 126 | step, loss, acc = self.train_step(feed_dict) 127 | trn_epo_loss.append(loss) 128 | trn_epo_acc.append(acc) 129 | 130 | if self.is_master: 131 | print(f'\repo:{epo} step:{i + 1}/{train_epo_steps} num:{epo_num_example} ' 132 | f'cur_loss:{loss:.3f} epo_loss:{np.mean(trn_epo_loss):.3f} ' 133 | f'epo_acc:{np.mean(trn_epo_acc):.3f} ' 134 | f'sec/step:{time.time() - step_start_time:.2f}', 135 | end=f'{os.linesep if i == train_epo_steps - 1 else ""}', 136 | ) 137 | 138 | trn_loss = np.mean(trn_epo_loss) 139 | trn_acc = np.mean(trn_epo_acc) 140 | if self.is_master: 141 | print(f'epo:{epo} trn loss {trn_loss:.3f} ' 142 | f'trn acc {trn_acc:.3f} ' 143 | f'elapsed {(time.time() - time0) / 60:.2f} min') 144 | train_info[epo]['trn_loss'] = trn_loss 
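# Rough sketch of what train_info accumulates per epoch (illustrative values, not from the original file):
#   {1: {'trn_loss': 1.23, 'trn_acc': 0.85, 'dev_loss': 1.30, 'dev_acc': 0.82, 'test_loss': 1.28, 'test_acc': 0.83}, 2: {...}}
# should_save() and stop_training() in run_base.py later read the 'dev_acc' entries of this dict
# to decide whether to keep the checkpoint and whether to stop early.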
145 | train_info[epo]['trn_acc'] = trn_acc 146 | 147 | if not self.is_master: 148 | continue 149 | 150 | # dev or test 151 | for mode in ['dev', 'test']: 152 | epo_steps = {'dev': dev_epo_steps, 'test': test_epo_steps}[mode] 153 | if epo_steps is None: 154 | continue 155 | time0 = time.time() 156 | epo_loss = [] 157 | epo_acc = [] 158 | for i in range(epo_steps): 159 | feed_dict = gen_feed_dict(i, epo, mode) 160 | loss, acc = self.eval_step(feed_dict) 161 | 162 | epo_loss.append(loss) 163 | epo_acc.append(acc) 164 | 165 | loss = np.mean(epo_loss) 166 | acc = np.mean(epo_acc) 167 | print(f'epo:{epo} {mode} loss {loss:.3f} ' 168 | f'{mode} acc {acc:.3f} ' 169 | f'elapsed {(time.time() - time0) / 60:.2f} min') 170 | train_info[epo][f'{mode}_loss'] = loss 171 | train_info[epo][f'{mode}_acc'] = acc 172 | 173 | info_str = f'{trn_loss:.2f}-{train_info[epo]["dev_loss"]:.2f}-{train_info[epo]["test_loss"]:.2f}' 174 | info_str += f'-{trn_acc:.3f}-{train_info[epo]["dev_acc"]:.3f}-{train_info[epo]["test_acc"]:.3f}' 175 | 176 | if conf.just_save_best: 177 | if self.should_save(epo, train_info, 'dev_acc', greater_is_better=True): 178 | self.delete_ckpt(ckpt_dir=ckpt_dir) # 删掉已存在的 179 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 180 | else: 181 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 182 | 183 | utils.obj2json(train_info, f'{ckpt_dir}/metrics.json') 184 | print('=' * 40, end='\n') 185 | if conf.early_stop_patience: 186 | if self.stop_training(conf.early_stop_patience, train_info, 'dev_acc'): 187 | print('early stop training!') 188 | print('train_info', train_info) 189 | break 190 | 191 | def predict(self, s1_lst, need_cut=True, batch_size=100): 192 | if need_cut: 193 | s1_lst = [self.cut(s1) for s1 in s1_lst] 194 | if not hasattr(self, 'bmeo2id'): self.bmeo2id = self.token2id_dct['bmeo2id'] 195 | if not hasattr(self, 'id2bmeo'): self.id2bmeo = self.token2id_dct['bmeo2id'].get_reverse() 196 | pred_lst = [] 197 | for i in range(0, len(s1_lst), batch_size): 198 | batch_s1 = s1_lst[i:i + batch_size] 199 | feed_dict = self.model.create_feed_dict_from_raw(batch_s1, [], self.token2id_dct, mode='infer') 200 | ner_pred, ner_prob = self.sess.run([self.model.ner_pred, self.model.ner_prob], feed_dict) # [batch] 201 | # ner_pred [batch,len] 202 | # ner_prob [batch] 203 | ner_pred = ner_pred.tolist() 204 | for i, pred in enumerate(ner_pred): 205 | while pred[-1] == self.bmeo2id['']: 206 | pred.pop(-1) 207 | ner_pred[i] = ' '.join([self.id2bmeo.get(bmeoid, '') for bmeoid in pred]) 208 | pred_lst.extend(ner_pred) 209 | return pred_lst 210 | 211 | 212 | def preprocess_raw_data(file, tokenize, token2id_dct, **kwargs): 213 | """ 214 | # 处理自有数据函数模板 215 | # file文件数据格式: 句子(以空格分好)\t标签(以空格分好) 216 | # [filter] 过滤 217 | # [segment] 分词 ner一般仅分字,用空格隔开,不需分词步骤 218 | # [build vocab] 构造词典 219 | # [split] train-dev-test 220 | """ 221 | items = utils.file2items(file) 222 | # 过滤 223 | # filter here 224 | 225 | print('过滤后数据量', len(items)) 226 | 227 | # 划分 228 | train_items, dev_items, test_items = utils.split_file(items, ratio='18:1:1', shuffle=True, seed=1234) 229 | 230 | # 构造词典(option) 231 | need_to_rebuild = [] 232 | for token2id_name in token2id_dct: 233 | if not token2id_dct[token2id_name]: 234 | print(f'字典{token2id_name} 载入不成功, 将生成并保存') 235 | need_to_rebuild.append(token2id_name) 236 | 237 | if need_to_rebuild: 238 | print(f'生成缺失词表文件...{need_to_rebuild}') 239 | for items in [train_items, dev_items]: # 字典只统计train和dev 240 | for item in items: 241 | if 'char2id' in need_to_rebuild: 242 | 
token2id_dct['char2id'].to_count(item[0].split(' ')) 243 | if 'bmeo2id' in need_to_rebuild: 244 | token2id_dct['bmeo2id'].to_count(item[1].split(' ')) 245 | if 'char2id' in need_to_rebuild: 246 | token2id_dct['char2id'].rebuild_by_counter(restrict=['', ''], min_freq=1, max_vocab_size=5000) 247 | token2id_dct['char2id'].save(f'{curr_dir}/../data/s2l_char2id.dct') 248 | if 'bmeo2id' in need_to_rebuild: 249 | token2id_dct['bmeo2id'].rebuild_by_counter(restrict=['', '']) 250 | token2id_dct['bmeo2id'].save(f'{curr_dir}/../data/s2l_bmeo2id.dct') 251 | else: 252 | print('使用已有词表文件...') 253 | 254 | return train_items, dev_items, test_items 255 | 256 | 257 | def preprocess_common_dataset_ResumeNER(file, tokenize, token2id_dct, **kwargs): 258 | train_file = f'{curr_dir}/../data/train.char.bmes.txt' 259 | dev_file = f'{curr_dir}/../data/dev.char.bmes.txt' 260 | test_file = f'{curr_dir}/../data/test.char.bmes.txt' 261 | 262 | # 转为行 用空格分隔 263 | def change2line(file): 264 | exm_lst = [] 265 | items = utils.file2items(file, deli=' ') 266 | curr_sent = [] 267 | curr_bmeo = [] 268 | 269 | for item in items: 270 | if len(item) == 1: # 分隔标志 [''] 271 | if curr_sent and curr_bmeo: 272 | exm_lst.append([' '.join(curr_sent), ' '.join(curr_bmeo)]) 273 | curr_sent, curr_bmeo = [], [] 274 | continue 275 | curr_sent.append(item[0]) 276 | curr_bmeo.append(item[1]) 277 | if curr_sent and curr_bmeo: 278 | exm_lst.append([' '.join(curr_sent), ' '.join(curr_bmeo)]) 279 | return exm_lst 280 | 281 | train_items = change2line(train_file) 282 | dev_items = change2line(dev_file) 283 | test_items = change2line(test_file) 284 | 285 | # 构造词典(option) 286 | need_to_rebuild = [] 287 | for token2id_name in token2id_dct: 288 | if not token2id_dct[token2id_name]: 289 | print(f'字典{token2id_name} 载入不成功, 将生成并保存') 290 | need_to_rebuild.append(token2id_name) 291 | 292 | if need_to_rebuild: 293 | print(f'生成缺失词表文件...{need_to_rebuild}') 294 | for items in [train_items, dev_items]: # 字典只统计train和dev 295 | for item in items: 296 | if 'char2id' in need_to_rebuild: 297 | token2id_dct['char2id'].to_count(item[0].split(' ')) 298 | if 'bmeo2id' in need_to_rebuild: 299 | token2id_dct['bmeo2id'].to_count(item[1].split(' ')) 300 | if 'char2id' in need_to_rebuild: 301 | token2id_dct['char2id'].rebuild_by_counter(restrict=['', ''], min_freq=1, max_vocab_size=5000) 302 | token2id_dct['char2id'].save(f'{curr_dir}/../data/rner_s2l_char2id.dct') 303 | if 'bmeo2id' in need_to_rebuild: 304 | token2id_dct['bmeo2id'].rebuild_by_counter(restrict=['', '']) 305 | token2id_dct['bmeo2id'].save(f'{curr_dir}/../data/rner_s2l_bmeo2id.dct') 306 | else: 307 | print('使用已有词表文件...') 308 | 309 | return train_items, dev_items, test_items 310 | 311 | 312 | if __name__ == '__main__': 313 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用CPU设为'-1' 314 | 315 | rm_s2l = Run_Model_S2L('birnn') # use BiLSTM 316 | # rm_s2l = Run_Model_S2L('idcnn') # use IDCNN 317 | # rm_s2l = Run_Model_S2L('bert_crf') # use bert+crf 318 | 319 | # 训练自有数据 320 | # rm_s2l.train('s2l_ckpt_1', '../data/s2l_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512) # train 321 | 322 | # 训练ResumeNER语料 323 | rm_s2l.train('s2l_ckpt_RNER1', '', preprocess_raw_data=preprocess_common_dataset_ResumeNER, batch_size=512) # train 324 | 325 | # demo命名实体识别ResumeNER 326 | rm_s2l.restore('s2l_ckpt_RNER1') # for infer 327 | import readline 328 | while True: 329 | try: 330 | inp = input('enter:') 331 | sent1 = ' '.join(inp) # NER分字 332 | time0 = time.time() 333 | ret = rm_s2l.predict([sent1], need_cut=False) 334 
| print(ret[0]) 335 | print('elapsed:', time.time() - time0) 336 | except KeyboardInterrupt: 337 | exit(0) 338 | -------------------------------------------------------------------------------- /qiznlp/run/run_multi_s2s.py: -------------------------------------------------------------------------------- 1 | import os, sys, re 2 | import time 3 | import jieba 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import qiznlp.common.utils as utils 8 | from qiznlp.run.run_base import Run_Model_Base, check_and_update_param_of_model_pyfile 9 | 10 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 11 | sys.path.append(curr_dir + '/..') # 添加上级目录即默认qiznlp根目录 12 | from model.multi_s2s_model import Model as MS2S_Model 13 | 14 | try: 15 | import horovod.tensorflow as hvd 16 | # 示例:horovodrun -np 2 -H localhost:2 python run_s2s.py 17 | # 可根据local_rank设置shard的数据,保证各个gpu采样的数据不重叠。 18 | except: 19 | HVD_ENABLE = False 20 | else: 21 | HVD_ENABLE = True 22 | 23 | conf = utils.dict2obj({ 24 | 'early_stop_patience': None, 25 | 'just_save_best': True, 26 | 'n_epochs': 10, 27 | 'data_type': 'tfrecord', 28 | # 'data_type': 'pkldata', 29 | }) 30 | 31 | 32 | class Run_Model_MS2S(Run_Model_Base): 33 | def __init__(self, model_name, tokenize=None, pbmodel_dir=None, use_hvd=False): 34 | # 维护sess graph config saver 35 | self.model_name = model_name 36 | if tokenize is None: 37 | self.jieba = jieba.Tokenizer() 38 | # self.jieba.load_userdict(f'{curr_dir}/segword.dct') 39 | self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t)) 40 | else: 41 | self.tokenize = tokenize 42 | self.cut = lambda t: ' '.join(self.tokenize(t)) 43 | self.token2id_dct = { 44 | # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/ms2s_char2id.dct', use_line_no=True), # 自有数据 45 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/XHJDB_ms2s_char2id.dct', use_line_no=True), # 小黄鸡+豆瓣 46 | } 47 | self.config = tf.ConfigProto(allow_soft_placement=True, 48 | gpu_options=tf.GPUOptions(allow_growth=True), 49 | ) 50 | self.use_hvd = use_hvd if HVD_ENABLE else False 51 | if self.use_hvd: 52 | hvd.init() 53 | self.hvd_rank = hvd.rank() 54 | self.hvd_size = hvd.size() 55 | self.config.gpu_options.visible_device_list = str(hvd.local_rank()) 56 | self.graph = tf.Graph() 57 | self.sess = tf.Session(config=self.config, graph=self.graph) 58 | 59 | if pbmodel_dir is not None: # 只能做predict 60 | self.model = MS2S_Model.from_pbmodel(pbmodel_dir, self.sess) 61 | else: 62 | with self.graph.as_default(): 63 | self.model = MS2S_Model(model_name=self.model_name, run_model=self) 64 | if self.use_hvd: 65 | self.model.optimizer._lr = self.model.optimizer._lr * self.hvd_size # 分布式训练大batch增大学习率 66 | self.model.hvd_optimizer = hvd.DistributedOptimizer(self.model.optimizer) 67 | self.model.train_op = self.model.hvd_optimizer.minimize(self.model.loss, global_step=self.model.global_step) 68 | self.sess.run(tf.global_variables_initializer()) 69 | if self.use_hvd: 70 | self.sess.run(hvd.broadcast_global_variables(0)) 71 | 72 | with self.graph.as_default(): 73 | self.saver = tf.train.Saver(max_to_keep=100) # must in the graph context 74 | 75 | def train_step(self, feed_dict): 76 | _, step, loss = self.sess.run([self.model.train_op, 77 | self.model.global_step, 78 | self.model.loss, 79 | ], 80 | feed_dict=feed_dict) 81 | return step, loss 82 | 83 | def eval_step(self, feed_dict): 84 | loss, = self.sess.run([self.model.loss, 85 | ], 86 | feed_dict=feed_dict) 87 | return loss 88 | 89 | def train(self, ckpt_dir, raw_data_file, preprocess_raw_data, 
batch_size = 100, save_data_prefix = None): 90 | save_data_prefix = os.path.basename(ckpt_dir) if save_data_prefix is None else save_data_prefix 91 | train_epo_steps, dev_epo_steps, test_epo_steps, gen_feed_dict = self.prepare_data(conf.data_type, 92 | raw_data_file, 93 | preprocess_raw_data, 94 | batch_size, 95 | save_data_prefix=save_data_prefix, 96 | update_txt=False, 97 | ) 98 | self.is_master = True 99 | if hasattr(self, 'hvd_rank') and self.hvd_rank != 0: # 分布式训练且非master 100 | dev_epo_steps, test_epo_steps = None, None # 不进行验证和测试 101 | self.is_master = False 102 | 103 | # 字典大小自动对齐 104 | check_and_update_param_of_model_pyfile({ 105 | 'vocab_size': (self.model.conf.vocab_size, len(self.token2id_dct['word2id'])), 106 | }, self.model) 107 | 108 | train_info = {} 109 | for epo in range(1, 1 + conf.n_epochs): 110 | train_info[epo] = {} 111 | 112 | # train 113 | time0 = time.time() 114 | epo_num_example = 0 115 | trn_epo_loss = [] 116 | for i in range(train_epo_steps): 117 | feed_dict = gen_feed_dict(i, epo, 'train') 118 | epo_num_example += feed_dict.pop('num') 119 | 120 | step_start_time = time.time() 121 | step, loss = self.train_step(feed_dict) 122 | trn_epo_loss.append(loss) 123 | 124 | if self.is_master: 125 | print(f'\repo:{epo} step:{i + 1}/{train_epo_steps} num:{epo_num_example} ' 126 | f'cur_loss:{loss:.3f} epo_loss:{np.mean(trn_epo_loss):.3f} ' 127 | f'sec/step:{time.time() - step_start_time:.2f}', 128 | end=f'{os.linesep if i == train_epo_steps - 1 else ""}', 129 | ) 130 | 131 | trn_loss = np.mean(trn_epo_loss) 132 | if self.is_master: 133 | print(f'epo:{epo} trn loss {trn_loss:.3f} ' 134 | f'elapsed {(time.time() - time0) / 60:.2f} min') 135 | train_info[epo]['trn_loss'] = trn_loss 136 | 137 | if not self.is_master: 138 | continue 139 | 140 | # dev or test 141 | for mode in ['dev', 'test']: 142 | epo_steps = {'dev': dev_epo_steps, 'test': test_epo_steps}[mode] 143 | if epo_steps is None: 144 | continue 145 | time0 = time.time() 146 | epo_loss = [] 147 | for i in range(epo_steps): 148 | feed_dict = gen_feed_dict(i, epo, mode) 149 | loss = self.eval_step(feed_dict) 150 | 151 | epo_loss.append(loss) 152 | 153 | loss = np.mean(epo_loss) 154 | print(f'epo:{epo} {mode} loss {loss:.3f} ' 155 | f'elapsed {(time.time() - time0) / 60:.2f} min') 156 | train_info[epo][f'{mode}_loss'] = loss 157 | 158 | info_str = f'{trn_loss:.2f}' 159 | info_str += f'-{train_info[epo]["dev_loss"]:.2f}' 160 | 161 | if conf.just_save_best: 162 | if self.should_save(epo, train_info, 'dev_loss', greater_is_better=False): 163 | self.delete_ckpt(ckpt_dir=ckpt_dir) # 删掉已存在的 164 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 165 | else: 166 | self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str) 167 | 168 | utils.obj2json(train_info, f'{ckpt_dir}/metrics.json') 169 | print('=' * 40, end='\n') 170 | if conf.early_stop_patience: 171 | if self.stop_training(conf.early_stop_patience, train_info, 'dev_loss', greater_is_better=False): 172 | print('early stop training!') 173 | print('train_info', train_info) 174 | break 175 | 176 | def predict(self, multi_s1_lst, need_cut=True, batch_size=100): 177 | if need_cut: 178 | multi_s1_lst = ['$$$'.join([self.cut(s1) for s1 in multi_s1.split('$$$')]) for multi_s1 in multi_s1_lst] 179 | if not hasattr(self, 'word2id'): self.word2id = self.token2id_dct['word2id'] 180 | if not hasattr(self, 'id2word'): self.id2word = self.token2id_dct['word2id'].get_reverse() 181 | pred_s2_lst = [] 182 | for i in range(0, len(multi_s1_lst), batch_size): 183 | batch_multi_s1 = 
multi_s1_lst[i:i + batch_size] 184 | feed_dict = self.model.create_feed_dict_from_raw(batch_multi_s1, [], self.token2id_dct, mode='infer') 185 | s2_ids, s2_score = self.sess.run([self.model.decoded_ids, self.model.scores], feed_dict) 186 | # s2_ids: [batch, beam, len] 187 | # s2_score: [batch, beam] 188 | s2_ids = s2_ids.tolist() 189 | for batch_idx in range(len(s2_ids)): 190 | beam_sents = s2_ids[batch_idx] 191 | for beam_idx, sent in enumerate(beam_sents): 192 | while sent[-1] == self.word2id['']: 193 | sent.pop(-1) 194 | beam_sents[beam_idx] = ''.join([self.id2word.get(wid, '') for wid in sent]) 195 | pred_s2_lst.append(beam_sents) 196 | return pred_s2_lst 197 | 198 | 199 | def preprocess_raw_data(file, tokenize, token2id_dct, **kwargs): 200 | """ 201 | # 处理自有数据函数模板 202 | # file文件数据格式: 多轮对话句子1\t多轮对话句子2\t...\t多轮对话句子n 203 | # [filter] 过滤 204 | # [segment] 分词 205 | # [build vocab] 构造词典 206 | # [split] train-dev-test 207 | """ 208 | seg_file = file.rsplit('.', 1)[0] + '_seg.txt' 209 | if not os.path.exists(seg_file): 210 | sess_lst = utils.file2items(file) 211 | # 过滤 212 | # filter here 213 | 214 | print('过滤后数据量', len(sess_lst)) 215 | 216 | # 分词 217 | for i, sess in enumerate(sess_lst): 218 | sess_lst[i] = [' '.join(s) for s in sess] # 按字分 219 | # sess_lst[i] = [' '.join(tokenize(s)) for s in sess] # 按词分 220 | utils.list2file(seg_file, sess_lst) 221 | print('保存分词后数据成功', '数据量', len(sess_lst), seg_file) 222 | else: 223 | # 读取分词好的数据 224 | sess_lst = utils.file2items(seg_file) 225 | 226 | # 转为多轮格式 multi-turn之间用$$$分隔 227 | items = [] 228 | for sess in sess_lst: 229 | for i in range(1, len(sess)): 230 | multi_src = '$$$'.join(sess[:i]) 231 | tgt = sess[i] 232 | items.append([multi_src, tgt]) 233 | # items: [['w w w$$$w w', 'w w w'],...] 234 | 235 | # 划分 不分测试集 236 | train_items, dev_items = utils.split_file(items, ratio='19:1', shuffle=True, seed=1234) 237 | 238 | # 构造词典(option) 239 | need_to_rebuild = [] 240 | for token2id_name in token2id_dct: 241 | if not token2id_dct[token2id_name]: 242 | print(f'字典{token2id_name} 载入不成功, 将生成并保存') 243 | need_to_rebuild.append(token2id_name) 244 | 245 | if need_to_rebuild: 246 | print(f'生成缺失词表文件...{need_to_rebuild}') 247 | for items in [train_items, dev_items]: # 字典只统计train和dev 248 | for item in items: 249 | if 'word2id' in need_to_rebuild: 250 | for sent in item[0].split('$$$'): 251 | token2id_dct['word2id'].to_count(sent.split(' ')) 252 | token2id_dct['word2id'].to_count(item[1].split(' ')) 253 | if 'word2id' in need_to_rebuild: 254 | token2id_dct['word2id'].rebuild_by_counter(restrict=['', '', ''], min_freq=1, max_vocab_size=4000) 255 | token2id_dct['word2id'].save(f'{curr_dir}/../data/ms2s_char2id.dct') 256 | else: 257 | print('使用已有词表文件...') 258 | 259 | return train_items, dev_items, None 260 | 261 | 262 | def preprocess_common_dataset_XiaoHJ_and_Douban(file, tokenize, token2id_dct, **kwargs): 263 | # 小黄鸡单轮+豆瓣多轮语料(都按子分) 264 | XHJ_file = f'{curr_dir}/../data/XHJ_5w.txt' 265 | DB_file = f'{curr_dir}/../data/Douban_Sess662.txt' 266 | 267 | def XiaoHJchange2items(file): 268 | # 转为[src, tgt]格式 按字分 269 | lines = utils.file2list(file) 270 | items = [line.split(' ', 1) for line in lines] 271 | exm_lst = [] 272 | sent_lst = [] 273 | for item in items: 274 | if len(item) == 1 and item[0] == 'E': # 分隔标志 275 | if sent_lst: 276 | src_tgt_lst = zip(sent_lst, sent_lst[1:]) 277 | exm_lst.extend([[' '.join(src), ' '.join(tgt)] for src, tgt in src_tgt_lst]) 278 | sent_lst = [] 279 | continue 280 | if item[0] == 'M' and item[1]: # 有些数据只有M 281 | sent_lst.append(item[1]) 282 | 
if sent_lst: 283 | src_tgt_lst = zip(sent_lst, sent_lst[1:]) 284 | exm_lst.extend([[' '.join(src), ' '.join(tgt)] for src, tgt in src_tgt_lst]) 285 | return exm_lst 286 | 287 | def Doubanchange2items(file): 288 | # 转为[multi_src, tgt]格式 按字分 289 | exm_lst = [] 290 | sess_lst = utils.file2items(file) 291 | for sess in sess_lst: 292 | sess = [' '.join(s) for s in sess] # 按字分 293 | for i in range(1, len(sess)): 294 | multi_src = '$$$'.join(sess[:i]) 295 | tgt = sess[i] 296 | exm_lst.append([multi_src, tgt]) 297 | return exm_lst 298 | 299 | # XiaoHJ和Douban数据不分词,直接按字分,但为了方便词典仍旧叫word2id 300 | items = XiaoHJchange2items(XHJ_file) # [['w w w$$$w w', 'w w w'],...] 301 | items += Doubanchange2items(DB_file) 302 | 303 | # 划分 不分测试集 304 | train_items, dev_items = utils.split_file(items, ratio='19:1', shuffle=True, seed=1234) 305 | 306 | # 构造词典(option) 307 | need_to_rebuild = [] 308 | for token2id_name in token2id_dct: 309 | if not token2id_dct[token2id_name]: 310 | print(f'字典{token2id_name} 载入不成功, 将生成并保存') 311 | need_to_rebuild.append(token2id_name) 312 | 313 | if need_to_rebuild: 314 | print(f'生成缺失词表文件...{need_to_rebuild}') 315 | for items in [train_items, dev_items]: # 字典只统计train和dev 316 | for item in items: 317 | if 'word2id' in need_to_rebuild: 318 | for sent in item[0].split('$$$'): 319 | token2id_dct['word2id'].to_count(sent.split(' ')) 320 | token2id_dct['word2id'].to_count(item[1].split(' ')) 321 | if 'word2id' in need_to_rebuild: 322 | token2id_dct['word2id'].rebuild_by_counter(restrict=['', '', ''], min_freq=1, max_vocab_size=4000) 323 | token2id_dct['word2id'].save(f'{curr_dir}/../data/XHJDB_ms2s_char2id.dct') 324 | else: 325 | print('使用已有词表文件...') 326 | 327 | return train_items, dev_items, None 328 | 329 | 330 | if __name__ == '__main__': 331 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用CPU设为'-1' 332 | 333 | # rm_ms2s = Run_Model_MS2S('HRED') # use HRED 334 | # rm_ms2s = Run_Model_MS2S('HRAN') # use HRAN 335 | rm_ms2s = Run_Model_MS2S('RECOSA') # use RECOSA 336 | 337 | # 训练自有数据 338 | # rm_ms2s.train('multi_s2s_ckpt_1', '../data/multi_s2s_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512) # train 339 | 340 | # 训练小黄鸡单轮+豆瓣多轮语料 341 | rm_ms2s.train('multi_s2s_ckpt_XHJDB1', '', preprocess_raw_data=preprocess_common_dataset_XiaoHJ_and_Douban, batch_size=512) # train 342 | 343 | # demo小黄鸡+豆瓣多轮聊天机器人 344 | rm_ms2s.restore('multi_s2s_ckpt_XHJDB1') # for infer 345 | import readline 346 | while True: 347 | try: 348 | inp = input('enter:($$$分隔多轮句子)') 349 | sent1 = '$$$'.join([' '.join(s) for s in inp.split('$$$')]) # 分字 350 | time0 = time.time() 351 | ret = rm_ms2s.predict([sent1], need_cut=False) 352 | print(ret[0]) 353 | print('elapsed:', time.time() - time0) 354 | except KeyboardInterrupt: 355 | exit(0) 356 | 357 | -------------------------------------------------------------------------------- /qiznlp/run/run_cls.py: -------------------------------------------------------------------------------- 1 | import os, sys, re 2 | import time 3 | import jieba 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import qiznlp.common.utils as utils 8 | from qiznlp.run.run_base import Run_Model_Base, check_and_update_param_of_model_pyfile 9 | 10 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 11 | sys.path.append(curr_dir + '/..') # 添加上级目录即默认qiznlp根目录 12 | from model.cls_model import Model as CLS_Model 13 | 14 | try: 15 | import horovod.tensorflow as hvd 16 | # 示例:horovodrun -np 2 -H localhost:2 python run_cls.py 17 | # 可根据local_rank设置shard的数据,保证各个gpu采样的数据不重叠。 18 | 
except: 19 | HVD_ENABLE = False 20 | else: 21 | HVD_ENABLE = True 22 | 23 | conf = utils.dict2obj({ 24 | 'early_stop_patience': None, 25 | 'just_save_best': True, 26 | 'n_epochs': 20, 27 | 'data_type': 'tfrecord', 28 | # 'data_type': 'pkldata', 29 | }) 30 | 31 | 32 | def print_(*args, should_print=True, **kwargs): 33 | if should_print: 34 | print(*args, **kwargs) 35 | 36 | 37 | class Run_Model_Cls(Run_Model_Base): 38 | def __init__(self, model_name, tokenize=None, pbmodel_dir=None, use_hvd=False): 39 | # 维护sess graph config saver 40 | self.model_name = model_name 41 | if tokenize is None: 42 | self.jieba = jieba.Tokenizer() 43 | # self.jieba.load_userdict(f'{curr_dir}/segword.dct') 44 | self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t)) 45 | else: 46 | self.tokenize = tokenize 47 | self.cut = lambda t: ' '.join(self.tokenize(t)) 48 | self.token2id_dct = { 49 | # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/cls_word2id.dct', use_line_no=True), # 自有数据 50 | # 'label2id': utils.Any2Id.from_file(f'{curr_dir}/../data/cls_label2id.dct', use_line_no=True), # 自有数据 51 | 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_word2id.dct', use_line_no=True), # toutiao新闻 52 | 'label2id': utils.Any2Id.from_file(f'{curr_dir}/../data/toutiao_cls_label2id.dct', use_line_no=True), # toutiao新闻 53 | } 54 | self.config = tf.ConfigProto(allow_soft_placement=True, 55 | gpu_options=tf.GPUOptions(allow_growth=True), 56 | ) 57 | self.use_hvd = use_hvd if HVD_ENABLE else False 58 | if self.use_hvd: 59 | hvd.init() 60 | self.hvd_rank = hvd.rank() 61 | self.hvd_size = hvd.size() 62 | self.config.gpu_options.visible_device_list = str(hvd.local_rank()) 63 | self.graph = tf.Graph() 64 | self.sess = tf.Session(config=self.config, graph=self.graph) 65 | 66 | if pbmodel_dir is not None: # 只能做predict 67 | self.model = CLS_Model.from_pbmodel(pbmodel_dir, self.sess) 68 | else: 69 | with self.graph.as_default(): 70 | self.model = CLS_Model(model_name=self.model_name, run_model=self) 71 | if self.use_hvd: 72 | self.model.optimizer._lr = self.model.optimizer._lr * self.hvd_size # 分布式训练大batch增大学习率 73 | self.model.hvd_optimizer = hvd.DistributedOptimizer(self.model.optimizer) 74 | self.model.train_op = self.model.hvd_optimizer.minimize(self.model.loss, global_step=self.model.global_step) 75 | self.sess.run(tf.global_variables_initializer()) 76 | if self.use_hvd: 77 | self.sess.run(hvd.broadcast_global_variables(0)) 78 | 79 | with self.graph.as_default(): 80 | self.saver = tf.train.Saver(max_to_keep=100) # must in the graph context 81 | 82 | def train_step(self, feed_dict): 83 | _, step, loss, accuracy = self.sess.run([self.model.train_op, 84 | self.model.global_step, 85 | self.model.loss, 86 | self.model.accuracy, 87 | ], 88 | feed_dict=feed_dict) 89 | return step, loss, accuracy 90 | 91 | def eval_step(self, feed_dict): 92 | loss, accuracy, y_prob = self.sess.run([self.model.loss, 93 | self.model.accuracy, 94 | self.model.y_prob, 95 | ], 96 | feed_dict=feed_dict) 97 | return loss, accuracy, y_prob 98 | 99 | def train(self, ckpt_dir, raw_data_file, preprocess_raw_data, batch_size=100, save_data_prefix=None): 100 | save_data_prefix = os.path.basename(ckpt_dir) if save_data_prefix is None else save_data_prefix 101 | train_epo_steps, dev_epo_steps, test_epo_steps, gen_feed_dict = self.prepare_data(conf.data_type, 102 | raw_data_file, 103 | preprocess_raw_data, 104 | batch_size, 105 | save_data_prefix=save_data_prefix, 106 | update_txt=True, 107 | ) 108 | self.is_master = True 109 | if self.use_hvd 
109 |         if self.use_hvd and self.hvd_rank != 0:  # distributed training and not the master worker
110 |             dev_epo_steps, test_epo_steps = None, None  # skip validation and test on non-master workers
111 |             self.is_master = False
112 | 
113 |         # automatically align vocab sizes between the model config and the dictionaries
114 |         check_and_update_param_of_model_pyfile({
115 |             'vocab_size': (self.model.conf.vocab_size, len(self.token2id_dct['word2id'])),
116 |             'label_size': (self.model.conf.label_size, len(self.token2id_dct['label2id']))
117 |         }, self.model)
118 | 
119 |         train_info = {}
120 |         for epo in range(1, 1 + conf.n_epochs):
121 |             train_info[epo] = {}
122 | 
123 |             # train
124 |             time0 = time.time()
125 |             epo_num_example = 0
126 |             trn_epo_loss = []
127 |             trn_epo_acc = []
128 |             for i in range(train_epo_steps):
129 |                 feed_dict = gen_feed_dict(i, epo, 'train')
130 |                 epo_num_example += feed_dict.pop('num')
131 | 
132 |                 step_start_time = time.time()
133 |                 step, loss, acc = self.train_step(feed_dict)
134 |                 trn_epo_loss.append(loss)
135 |                 trn_epo_acc.append(acc)
136 | 
137 |                 if self.is_master:
138 |                     print(f'\repo:{epo} step:{i + 1}/{train_epo_steps} num:{epo_num_example} '
139 |                           f'cur_loss:{loss:.3f} epo_loss:{np.mean(trn_epo_loss):.3f} '
140 |                           f'epo_acc:{np.mean(trn_epo_acc):.3f} '
141 |                           f'sec/step:{time.time() - step_start_time:.2f}',
142 |                           end=f'{os.linesep if i == train_epo_steps - 1 else ""}',
143 |                           )
144 | 
145 |             trn_loss = np.mean(trn_epo_loss)
146 |             trn_acc = np.mean(trn_epo_acc)
147 |             if self.is_master:
148 |                 print(f'epo:{epo} trn loss {trn_loss:.3f} '
149 |                       f'trn acc {trn_acc:.3f} '
150 |                       f'elapsed {(time.time() - time0) / 60:.2f} min')
151 |             train_info[epo]['trn_loss'] = trn_loss
152 |             train_info[epo]['trn_acc'] = trn_acc
153 | 
154 |             if not self.is_master:
155 |                 continue
156 | 
157 |             # dev or test
158 |             for mode in ['dev', 'test']:
159 |                 epo_steps = {'dev': dev_epo_steps, 'test': test_epo_steps}[mode]
160 |                 if epo_steps is None:
161 |                     continue
162 |                 time0 = time.time()
163 |                 epo_loss = []
164 |                 epo_acc = []
165 |                 for i in range(epo_steps):
166 |                     feed_dict = gen_feed_dict(i, epo, mode)
167 |                     loss, acc, y_prob = self.eval_step(feed_dict)
168 | 
169 |                     epo_loss.append(loss)
170 |                     epo_acc.append(acc)
171 | 
172 |                 loss = np.mean(epo_loss)
173 |                 acc = np.mean(epo_acc)
174 |                 print(f'epo:{epo} {mode} loss {loss:.3f} '
175 |                       f'{mode} acc {acc:.3f} '
176 |                       f'elapsed {(time.time() - time0) / 60:.2f} min')
177 |                 train_info[epo][f'{mode}_loss'] = loss
178 |                 train_info[epo][f'{mode}_acc'] = acc
179 | 
180 |             info_str = f'{trn_loss:.2f}-{trn_acc:.3f}'
181 |             # info_str += f'-{train_info[epo]["dev_loss"]:.2f}-{train_info[epo]["dev_acc"]:.3f}'
182 |             # info_str += f'-{train_info[epo]["test_loss"]:.2f}-{train_info[epo]["test_acc"]:.3f}'
183 | 
184 |             if conf.just_save_best:
185 |                 if self.should_save(epo, train_info, 'dev_acc', greater_is_better=True):
186 |                     self.delete_ckpt(ckpt_dir=ckpt_dir)  # delete the existing checkpoint first
187 |                     self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str)
188 |             else:
189 |                 self.save(ckpt_dir=ckpt_dir, epo=epo, info_str=info_str)
190 | 
191 |             utils.obj2json(train_info, f'{ckpt_dir}/metrics.json')
192 |             print('=' * 40, end='\n\n')
193 |             if conf.early_stop_patience:
194 |                 if self.stop_training(conf.early_stop_patience, train_info, 'dev_acc'):
195 |                     print('early stop training!')
196 |                     print('train_info', train_info)
197 |                     break
198 | 
199 |     def predict(self, s1_lst, need_cut=True, batch_size=100):
200 |         if need_cut:
201 |             s1_lst = [self.cut(s1) for s1 in s1_lst]
202 |         if not hasattr(self, 'label2id'): self.label2id = self.token2id_dct['label2id']
203 |         if not hasattr(self, 'id2label'): self.id2label = self.token2id_dct['label2id'].get_reverse()
204 | 
205 |         pred_lst = []
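        # Batched inference: y_prob has shape [batch, label_size]; for each sentence the
        # argmax label and its probability are returned. Hypothetical example (actual labels
        # depend on the entries of label2id.dct):
        #   rm_cls.predict(['今天 天气 不错'], need_cut=False)  ->  [('news_xxx', 0.93)]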
206 |         for i in range(0, len(s1_lst), batch_size):
207 |             batch_s1 = s1_lst[i:i + batch_size]
208 |             feed_dict = self.model.create_feed_dict_from_raw(batch_s1, [], self.token2id_dct, mode='infer')
209 |             probs = self.sess.run(self.model.y_prob, feed_dict)  # [batch, num_cls]
210 |             pred = np.argmax(probs, -1)  # [batch]
211 |             pred_label = [self.id2label[p] for p in pred]
212 |             pred_prob = [prob[p] for prob, p in zip(probs, pred)]  # probability of the predicted class for each sample
213 |             pred_lst.extend(zip(pred_label, pred_prob))
214 |         return pred_lst
215 | 
216 | 
217 | def preprocess_raw_data(file, tokenize, token2id_dct, **kwargs):
218 |     """
219 |     # template for preprocessing your own data
220 |     # input file format per line: sentence\tlabel
221 |     # [filter] filtering
222 |     # [segment] word segmentation
223 |     # [build vocab] build the vocabularies
224 |     # [split] train-dev-test
225 |     """
226 |     seg_file = file.rsplit('.', 1)[0] + '_seg.txt'
227 |     if not os.path.exists(seg_file):
228 |         items = utils.file2items(file)
229 |         # filter
230 |         # filter here
231 | 
232 |         print('number of examples after filtering', len(items))
233 | 
234 |         # segment
235 |         for i, item in enumerate(items):
236 |             item[0] = ' '.join(tokenize(item[0]))
237 |         utils.list2file(seg_file, items)
238 |         print('saved segmented data', 'num examples', len(items), seg_file)
239 |     else:
240 |         # load the already-segmented data
241 |         items = utils.file2items(seg_file)
242 | 
243 |     # split
244 |     train_items, dev_items, test_items = utils.split_file(items, ratio='18:1:1', shuffle=True, seed=1234)
245 | 
246 |     # build vocabularies (optional)
247 |     need_to_rebuild = []
248 |     for token2id_name in token2id_dct:
249 |         if not token2id_dct[token2id_name]:
250 |             print(f'vocab {token2id_name} not loaded, will be built and saved')
251 |             need_to_rebuild.append(token2id_name)
252 | 
253 |     if need_to_rebuild:
254 |         print(f'building missing vocab files...{need_to_rebuild}')
255 |         for items in [train_items, dev_items]:  # vocabularies are built from train and dev only
256 |             for item in items:
257 |                 if 'word2id' in need_to_rebuild:
258 |                     token2id_dct['word2id'].to_count(item[0].split(' '))
259 |                 if 'label2id' in need_to_rebuild:
260 |                     token2id_dct['label2id'].to_count([item[1]])
261 |         if 'word2id' in need_to_rebuild:
262 |             token2id_dct['word2id'].rebuild_by_counter(restrict=['', '', ''], min_freq=5, max_vocab_size=30000)
263 |             token2id_dct['word2id'].save(f'{curr_dir}/../data/cls_word2id.dct')
264 |         if 'label2id' in need_to_rebuild:
265 |             token2id_dct['label2id'].rebuild_by_counter(restrict=[''])
266 |             token2id_dct['label2id'].save(f'{curr_dir}/../data/cls_label2id.dct')
267 |     else:
268 |         print('using existing vocab files...')
269 | 
270 |     return train_items, dev_items, test_items
271 | 
272 | 
273 | def preprocess_common_dataset_Toutiao(file, tokenize, token2id_dct, **kwargs):
274 |     train_file = f'{curr_dir}/../data/train.toutiao.cls.txt'
275 |     dev_file = f'{curr_dir}/../data/valid.toutiao.cls.txt'
276 |     test_file = f'{curr_dir}/../data/test.toutiao.cls.txt'
277 |     items_lst = []
278 |     for file in [train_file, dev_file, test_file]:
279 |         seg_file = file.rsplit('.', 1)[0] + '_seg.txt'  # segment the raw text and save it with a _seg.txt suffix
280 |         if not os.path.exists(seg_file):
281 |             items = utils.file2items(file, deli='\t')
282 |             # filter
283 |             # filter here
284 | 
285 |             print('number of examples after filtering', len(items))
286 | 
287 |             # segment
288 |             for i, item in enumerate(items):
289 |                 item[0] = ' '.join(tokenize(item[0]))
290 |             utils.list2file(seg_file, items)
291 |             print('saved segmented data', 'num examples', len(items), seg_file)
292 |             items_lst.append(items)
293 |         else:
294 |             # load the already-segmented data
295 |             items_lst.append(utils.file2items(seg_file))
296 | 
297 |     train_items, dev_items, test_items = items_lst
298 | 
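    # The three Toutiao files are assumed to be tab-separated with one "sentence<TAB>label"
    # pair per line (file2items is called with deli='\t'); after segmentation each item
    # therefore looks like (hypothetical example):
    #   ['今天 天气 不错', 'news_xxx']
    # i.e. item[0] is the space-joined tokens and item[1] is the class label counted below.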
299 |     # build vocabularies (optional)
300 |     need_to_rebuild = []
301 |     for token2id_name in token2id_dct:
302 |         if not token2id_dct[token2id_name]:
303 |             print(f'vocab {token2id_name} not loaded, will be built and saved')
304 |             need_to_rebuild.append(token2id_name)
305 | 
306 |     if need_to_rebuild:
307 |         print(f'building missing vocab files...{need_to_rebuild}')
308 |         for items in [train_items, dev_items]:  # vocabularies are built from train and dev only
309 |             for item in items:
310 |                 if 'word2id' in need_to_rebuild:
311 |                     token2id_dct['word2id'].to_count(item[0].split(' '))
312 |                 if 'label2id' in need_to_rebuild:
313 |                     token2id_dct['label2id'].to_count([item[1]])
314 |         if 'word2id' in need_to_rebuild:
315 |             token2id_dct['word2id'].rebuild_by_counter(restrict=['', ''], min_freq=1, max_vocab_size=20000)
316 |             token2id_dct['word2id'].save(f'{curr_dir}/../data/toutiao_cls_word2id.dct')
317 |         if 'label2id' in need_to_rebuild:
318 |             token2id_dct['label2id'].rebuild_by_counter(restrict=[''])
319 |             token2id_dct['label2id'].save(f'{curr_dir}/../data/toutiao_cls_label2id.dct')
320 |     else:
321 |         print('using existing vocab files...')
322 | 
323 |     return train_items, dev_items, test_items
324 | 
325 | 
326 | if __name__ == '__main__':
327 |     os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set to '-1' to run on CPU
328 | 
329 |     # rm_cls = Run_Model_Cls('trans_meanpool')  # use transformer_encoder with mean_pooling
330 |     rm_cls = Run_Model_Cls('trans_mhattnpool')  # use transformer_encoder with multi_head_pooling
331 |     # rm_cls = Run_Model_Cls('bert_cls')  # use the bert encoder
332 | 
333 |     # train on your own data
334 |     # rm_cls.train('cls_ckpt_1', '../data/cls_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512)  # train
335 | 
336 |     # train on the Toutiao news corpus
337 |     rm_cls.train('cls_ckpt_toutiao1', '', preprocess_raw_data=preprocess_common_dataset_Toutiao, batch_size=128)  # train
338 | 
339 |     # exit(0)
340 |     # demo: Toutiao news classification
341 |     rm_cls.restore('cls_ckpt_toutiao1')  # for infer
342 |     import readline
343 |     while True:
344 |         try:
345 |             sent = input('enter:')
346 |             need_cut = False if ' ' in sent else True
347 |             time0 = time.time()
348 |             ret = rm_cls.predict([sent], need_cut=need_cut)
349 |             print(ret[0])
350 |             print('elapsed:', time.time() - time0)
351 |         except KeyboardInterrupt:
352 |             exit(0)
353 | 
354 | 
--------------------------------------------------------------------------------
/qiznlp/common/modules/DAM/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from . import operations as op
3 | def similarity(x, y, x_lengths, y_lengths):
4 |     '''calculate the similarity of two 3d tensors.
5 | 
6 |     Args:
7 |         x: a tensor with shape [batch, time_x, dimension]
8 |         y: a tensor with shape [batch, time_y, dimension]
9 | 
10 |     Returns:
11 |         a tensor with shape [batch, time_x, time_y]
12 | 
13 |     Raises:
14 |         ValueError: if
15 |             the dimensions of x and y are not equal.
16 |     '''
17 |     with tf.variable_scope('x_attend_y'):
18 |         try:
19 |             x_a_y = block(
20 |                 x, y, y,
21 |                 Q_lengths=x_lengths, K_lengths=y_lengths)
22 |         except ValueError:
23 |             tf.get_variable_scope().reuse_variables()
24 |             x_a_y = block(
25 |                 x, y, y,
26 |                 Q_lengths=x_lengths, K_lengths=y_lengths)
27 | 
28 |     with tf.variable_scope('y_attend_x'):
29 |         try:
30 |             y_a_x = block(
31 |                 y, x, x,
32 |                 Q_lengths=y_lengths, K_lengths=x_lengths)
33 |         except ValueError:
34 |             tf.get_variable_scope().reuse_variables()
35 |             y_a_x = block(
36 |                 y, x, x,
37 |                 Q_lengths=y_lengths, K_lengths=x_lengths)
38 | 
39 |     return tf.matmul(x + x_a_y, y + y_a_x, transpose_b=True)
40 | 
41 | 
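# A minimal usage sketch for similarity(), assuming x: [batch, time_x, dim],
# y: [batch, time_y, dim] and int32 length vectors of shape [batch]:
#   with tf.variable_scope('dam_similarity'):
#       sim = similarity(x, y, x_lengths, y_lengths)   # -> [batch, time_x, time_y]
# block() below creates trainable variables, so calls should live inside a variable
# scope that is reused if the same sub-graph is built more than once.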
42 | def dynamic_L(x):
43 |     '''Attention mechanism to combine the information,
44 |     from https://arxiv.org/pdf/1612.01627.pdf.
45 | 
46 |     Args:
47 |         x: a tensor with shape [batch, time, dimension]
48 | 
49 |     Returns:
50 |         a tensor with shape [batch, dimension]
51 | 
52 |     Raises:
53 |     '''
54 |     key_0 = tf.get_variable(
55 |         name='key',
56 |         shape=[x.shape[-1]],
57 |         dtype=tf.float32,
58 |         initializer=tf.random_uniform_initializer(
59 |             -tf.sqrt(6./tf.cast(x.shape[-1], tf.float32)),
60 |             tf.sqrt(6./tf.cast(x.shape[-1], tf.float32))))
61 | 
62 |     key = op.dense(x, add_bias=False)  # [batch, time, dimension]
63 |     weight = tf.reduce_sum(tf.multiply(key, key_0), axis=-1)  # [batch, time]
64 |     weight = tf.expand_dims(tf.nn.softmax(weight), -1)  # [batch, time, 1]
65 | 
66 |     L = tf.reduce_sum(tf.multiply(x, weight), axis=1)  # [batch, dimension]
67 |     return L
68 | 
69 | def loss(x, y, num_classes=2, is_clip=True, clip_value=10):
70 |     '''Compute logits from the input x and return the loss.
71 | 
72 |     Args:
73 |         x: a tensor with shape [batch, dimension]
74 |         num_classes: a number
75 | 
76 |     Returns:
77 |         loss: a scalar tensor, the average loss over one batch
78 |         logits: a tensor with shape [batch]
79 | 
80 |     Raises:
81 |         AssertionError: if
82 |             num_classes is not an int greater than or equal to 2.
83 |     TODO:
84 |         num_classes > 2 may not be supported.
85 |     '''
86 |     assert isinstance(num_classes, int)
87 |     assert num_classes >= 2
88 | 
89 |     W = tf.get_variable(
90 |         name='weights',
91 |         shape=[x.shape[-1], num_classes-1],
92 |         initializer=tf.orthogonal_initializer())
93 |     bias = tf.get_variable(
94 |         name='bias',
95 |         shape=[num_classes-1],
96 |         initializer=tf.zeros_initializer())
97 | 
98 |     logits = tf.reshape(tf.matmul(x, W) + bias, [-1])
99 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(
100 |         labels=tf.cast(y, tf.float32),
101 |         logits=logits)
102 |     loss = tf.reduce_mean(tf.clip_by_value(loss, -clip_value, clip_value))
103 | 
104 |     return loss, logits
105 | 
106 | def attention(
107 |         Q, K, V,
108 |         Q_lengths, K_lengths,
109 |         attention_type='dot',
110 |         is_mask=True, mask_value=-2**32+1,
111 |         drop_prob=None):
112 |     '''Add an attention layer.
113 |     Args:
114 |         Q: a tensor with shape [batch, Q_time, Q_dimension]
115 |         K: a tensor with shape [batch, time, K_dimension]
116 |         V: a tensor with shape [batch, time, V_dimension]
117 | 
118 |         Q_lengths: a tensor with shape [batch]
119 |         K_lengths: a tensor with shape [batch]
120 | 
121 |     Returns:
122 |         a tensor with shape [batch, Q_time, V_dimension]
123 | 
124 |     Raises:
125 |         AssertionError: if
126 |             Q_dimension is not equal to K_dimension when attention_type is dot.
127 |     '''
128 |     assert attention_type in ('dot', 'bilinear')
129 |     if attention_type == 'dot':
130 |         assert Q.shape[-1] == K.shape[-1]
131 | 
132 |     Q_time = tf.shape(Q)[1]
133 |     K_time = tf.shape(K)[1]
134 | 
135 |     if attention_type == 'dot':
136 |         logits = op.dot_sim(Q, K)  # [batch, Q_time, time]
137 |     if attention_type == 'bilinear':
138 |         logits = op.bilinear_sim(Q, K)
139 | 
140 |     if is_mask:
141 |         mask = op.mask(Q_lengths, K_lengths, Q_time, K_time)  # [batch, Q_time, K_time]
142 |         logits = mask * logits + (1 - mask) * mask_value
143 | 
144 |     attention = tf.nn.softmax(logits)
145 | 
146 |     if drop_prob is not None:
147 |         # print('use attention drop')
148 |         attention = tf.layers.dropout(attention, rate=drop_prob)
149 | 
150 |     return op.weighted_sum(attention, V)
151 | 
152 | def FFN(x, out_dimension_0=None, out_dimension_1=None):
153 |     '''Add two densely connected layers, max(0, x*W0+b0)*W1+b1.
154 | 155 | Args: 156 | x: a tensor with shape [batch, time, dimension] 157 | out_dimension: a number which is the output dimension 158 | 159 | Returns: 160 | a tensor with shape [batch, time, out_dimension] 161 | 162 | Raises: 163 | ''' 164 | with tf.variable_scope('FFN_1'): 165 | y = op.dense(x, out_dimension_0) 166 | y = tf.nn.relu(y) 167 | with tf.variable_scope('FFN_2'): 168 | z = op.dense(y, out_dimension_1) #, add_bias=False) #!!!! 169 | return z 170 | 171 | def block( 172 | Q, K, V, 173 | Q_lengths, K_lengths, 174 | attention_type='dot', 175 | is_layer_norm=True, 176 | is_mask=True, mask_value=-2**32+1, 177 | drop_prob=None): 178 | '''Add a block unit from https://arxiv.org/pdf/1706.03762.pdf. 179 | Args: 180 | Q: a tensor with shape [batch, Q_time, Q_dimension] 181 | K: a tensor with shape [batch, time, K_dimension] 182 | V: a tensor with shape [batch, time, V_dimension] 183 | 184 | Q_length: a tensor with shape [batch] 185 | K_length: a tensor with shape [batch] 186 | 187 | Returns: 188 | a tensor with shape [batch, time, dimension] 189 | 190 | Raises: 191 | ''' 192 | att = attention(Q, K, V, 193 | Q_lengths, K_lengths, 194 | attention_type=attention_type, 195 | is_mask=is_mask, mask_value=mask_value, 196 | drop_prob=drop_prob) 197 | # attn [batch, q_len, hid] 198 | # K_lengths [batch] 199 | tmp = tf.expand_dims(tf.expand_dims(tf.cast(tf.not_equal(K_lengths, 0), tf.float32), -1), -1) # [batch,1,1] 200 | att = att * tmp 201 | 202 | if is_layer_norm: 203 | with tf.variable_scope('attention_layer_norm'): 204 | y = op.layer_norm_debug(Q + att) 205 | else: 206 | y = Q + att 207 | 208 | z = FFN(y) 209 | if is_layer_norm: 210 | with tf.variable_scope('FFN_layer_norm'): 211 | w = op.layer_norm_debug(y + z) 212 | else: 213 | w = y + z 214 | return w 215 | 216 | def CNN(x, out_channels, filter_size, pooling_size, add_relu=True): 217 | '''Add a convlution layer with relu and max pooling layer. 218 | 219 | Args: 220 | x: a tensor with shape [batch, in_height, in_width, in_channels] 221 | out_channels: a number 222 | filter_size: a number 223 | pooling_size: a number 224 | 225 | Returns: 226 | a flattened tensor with shape [batch, num_features] 227 | 228 | Raises: 229 | ''' 230 | #calculate the last dimension of return 231 | num_features = ((tf.shape(x)[1]-filter_size+1)/pooling_size * 232 | (tf.shape(x)[2]-filter_size+1)/pooling_size) * out_channels 233 | 234 | in_channels = x.shape[-1] 235 | weights = tf.get_variable( 236 | name='filter', 237 | shape=[filter_size, filter_size, in_channels, out_channels], 238 | dtype=tf.float32, 239 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 240 | bias = tf.get_variable( 241 | name='bias', 242 | shape=[out_channels], 243 | dtype=tf.float32, 244 | initializer=tf.zeros_initializer()) 245 | 246 | conv = tf.nn.conv2d(x, weights, strides=[1, 1, 1, 1], padding="VALID") 247 | conv = conv + bias 248 | 249 | if add_relu: 250 | conv = tf.nn.relu(conv) 251 | 252 | pooling = tf.nn.max_pool( 253 | conv, 254 | ksize=[1, pooling_size, pooling_size, 1], 255 | strides=[1, pooling_size, pooling_size, 1], 256 | padding="VALID") 257 | 258 | return tf.contrib.layers.flatten(pooling) 259 | 260 | def CNN_3d(x, out_channels_0, out_channels_1, add_relu=True): 261 | '''Add a 3d convlution layer with relu and max pooling layer. 
262 | 263 | Args: 264 | x: a tensor with shape [batch, in_depth, in_height, in_width, in_channels] 265 | out_channels: a number 266 | filter_size: a number 267 | pooling_size: a number 268 | 269 | Returns: 270 | a flattened tensor with shape [batch, num_features] 271 | 272 | Raises: 273 | ''' 274 | in_channels = x.shape[-1] 275 | weights_0 = tf.get_variable( 276 | name='filter_0', 277 | shape=[3, 3, 3, in_channels, out_channels_0], 278 | dtype=tf.float32, 279 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 280 | bias_0 = tf.get_variable( 281 | name='bias_0', 282 | shape=[out_channels_0], 283 | dtype=tf.float32, 284 | initializer=tf.zeros_initializer()) 285 | 286 | conv_0 = tf.nn.conv3d(x, weights_0, strides=[1, 1, 1, 1, 1], padding="SAME") 287 | print('conv_0 shape: %s' %conv_0.shape) 288 | conv_0 = conv_0 + bias_0 289 | 290 | if add_relu: 291 | conv_0 = tf.nn.elu(conv_0) 292 | 293 | pooling_0 = tf.nn.max_pool3d( 294 | conv_0, 295 | ksize=[1, 3, 3, 3, 1], 296 | strides=[1, 3, 3, 3, 1], 297 | padding="SAME") 298 | print('pooling_0 shape: %s' %pooling_0.shape) 299 | 300 | #layer_1 301 | weights_1 = tf.get_variable( 302 | name='filter_1', 303 | shape=[3, 3, 3, out_channels_0, out_channels_1], 304 | dtype=tf.float32, 305 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 306 | bias_1 = tf.get_variable( 307 | name='bias_1', 308 | shape=[out_channels_1], 309 | dtype=tf.float32, 310 | initializer=tf.zeros_initializer()) 311 | 312 | conv_1 = tf.nn.conv3d(pooling_0, weights_1, strides=[1, 1, 1, 1, 1], padding="SAME") 313 | print('conv_1 shape: %s' %conv_1.shape) 314 | conv_1 = conv_1 + bias_1 315 | 316 | if add_relu: 317 | conv_1 = tf.nn.elu(conv_1) 318 | 319 | pooling_1 = tf.nn.max_pool3d( 320 | conv_1, 321 | ksize=[1, 3, 3, 3, 1], 322 | strides=[1, 3, 3, 3, 1], 323 | padding="SAME") 324 | print('pooling_1 shape: %s' %pooling_1.shape) 325 | 326 | return tf.contrib.layers.flatten(pooling_1) 327 | 328 | def CNN_3d_2d(x, out_channels_0, out_channels_1, add_relu=True): 329 | '''Add a 3d convlution layer with relu and max pooling layer. 
330 | 331 | Args: 332 | x: a tensor with shape [batch, in_depth, in_height, in_width, in_channels] 333 | out_channels: a number 334 | filter_size: a number 335 | pooling_size: a number 336 | 337 | Returns: 338 | a flattened tensor with shape [batch, num_features] 339 | 340 | Raises: 341 | ''' 342 | in_channels = x.shape[-1] 343 | weights_0 = tf.get_variable( 344 | name='filter_0', 345 | shape=[1, 3, 3, in_channels, out_channels_0], 346 | dtype=tf.float32, 347 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 348 | bias_0 = tf.get_variable( 349 | name='bias_0', 350 | shape=[out_channels_0], 351 | dtype=tf.float32, 352 | initializer=tf.zeros_initializer()) 353 | 354 | conv_0 = tf.nn.conv3d(x, weights_0, strides=[1, 1, 1, 1, 1], padding="SAME") 355 | print('conv_0 shape: %s' %conv_0.shape) 356 | conv_0 = conv_0 + bias_0 357 | 358 | if add_relu: 359 | conv_0 = tf.nn.elu(conv_0) 360 | 361 | pooling_0 = tf.nn.max_pool3d( 362 | conv_0, 363 | ksize=[1, 1, 3, 3, 1], 364 | strides=[1, 1, 3, 3, 1], 365 | padding="SAME") 366 | print('pooling_0 shape: %s' %pooling_0.shape) 367 | 368 | #layer_1 369 | weights_1 = tf.get_variable( 370 | name='filter_1', 371 | shape=[1, 3, 3, out_channels_0, out_channels_1], 372 | dtype=tf.float32, 373 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 374 | bias_1 = tf.get_variable( 375 | name='bias_1', 376 | shape=[out_channels_1], 377 | dtype=tf.float32, 378 | initializer=tf.zeros_initializer()) 379 | 380 | conv_1 = tf.nn.conv3d(pooling_0, weights_1, strides=[1, 1, 1, 1, 1], padding="SAME") 381 | print('conv_1 shape: %s' %conv_1.shape) 382 | conv_1 = conv_1 + bias_1 383 | 384 | if add_relu: 385 | conv_1 = tf.nn.elu(conv_1) 386 | 387 | pooling_1 = tf.nn.max_pool3d( 388 | conv_1, 389 | ksize=[1, 1, 3, 3, 1], 390 | strides=[1, 1, 3, 3, 1], 391 | padding="SAME") 392 | print('pooling_1 shape: %s' %pooling_1.shape) 393 | 394 | return tf.contrib.layers.flatten(pooling_1) 395 | 396 | def CNN_3d_change(x, out_channels_0, out_channels_1, add_relu=True): 397 | '''Add a 3d convlution layer with relu and max pooling layer. 
398 | 399 | Args: 400 | x: a tensor with shape [batch, in_depth, in_height, in_width, in_channels] 401 | out_channels: a number 402 | filter_size: a number 403 | pooling_size: a number 404 | 405 | Returns: 406 | a flattened tensor with shape [batch, num_features] 407 | 408 | Raises: 409 | ''' 410 | in_channels = x.shape[-1] 411 | weights_0 = tf.get_variable( 412 | name='filter_0', 413 | shape=[3, 3, 3, in_channels, out_channels_0], 414 | dtype=tf.float32, 415 | #initializer=tf.random_normal_initializer(0, 0.05)) 416 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 417 | bias_0 = tf.get_variable( 418 | name='bias_0', 419 | shape=[out_channels_0], 420 | dtype=tf.float32, 421 | initializer=tf.zeros_initializer()) 422 | #Todo 423 | g_0 = tf.get_variable(name='scale_0', 424 | shape = [out_channels_0], 425 | dtype=tf.float32, 426 | initializer=tf.ones_initializer()) 427 | weights_0 = tf.reshape(g_0, [1, 1, 1, out_channels_0]) * tf.nn.l2_normalize(weights_0, [0, 1, 2]) 428 | 429 | conv_0 = tf.nn.conv3d(x, weights_0, strides=[1, 1, 1, 1, 1], padding="VALID") 430 | print('conv_0 shape: %s' %conv_0.shape) 431 | conv_0 = conv_0 + bias_0 432 | ####### 433 | ''' 434 | with tf.variable_scope('layer_0'): 435 | conv_0 = op.layer_norm(conv_0, axis=[1, 2, 3, 4]) 436 | print('layer_norm in cnn') 437 | ''' 438 | if add_relu: 439 | conv_0 = tf.nn.elu(conv_0) 440 | 441 | pooling_0 = tf.nn.max_pool3d( 442 | conv_0, 443 | ksize=[1, 2, 3, 3, 1], 444 | strides=[1, 2, 3, 3, 1], 445 | padding="VALID") 446 | print('pooling_0 shape: %s' %pooling_0.shape) 447 | 448 | #layer_1 449 | weights_1 = tf.get_variable( 450 | name='filter_1', 451 | shape=[2, 2, 2, out_channels_0, out_channels_1], 452 | dtype=tf.float32, 453 | initializer=tf.random_uniform_initializer(-0.01, 0.01)) 454 | 455 | bias_1 = tf.get_variable( 456 | name='bias_1', 457 | shape=[out_channels_1], 458 | dtype=tf.float32, 459 | initializer=tf.zeros_initializer()) 460 | 461 | g_1 = tf.get_variable(name='scale_1', 462 | shape = [out_channels_1], 463 | dtype=tf.float32, 464 | initializer=tf.ones_initializer()) 465 | weights_1 = tf.reshape(g_1, [1, 1, 1, out_channels_1]) * tf.nn.l2_normalize(weights_1, [0, 1, 2]) 466 | 467 | conv_1 = tf.nn.conv3d(pooling_0, weights_1, strides=[1, 1, 1, 1, 1], padding="VALID") 468 | print('conv_1 shape: %s' %conv_1.shape) 469 | conv_1 = conv_1 + bias_1 470 | #with tf.variable_scope('layer_1'): 471 | # conv_1 = op.layer_norm(conv_1, axis=[1, 2, 3, 4]) 472 | 473 | if add_relu: 474 | conv_1 = tf.nn.elu(conv_1) 475 | 476 | pooling_1 = tf.nn.max_pool3d( 477 | conv_1, 478 | ksize=[1, 3, 3, 3, 1], 479 | strides=[1, 3, 3, 3, 1], 480 | padding="VALID") 481 | print('pooling_1 shape: %s' %pooling_1.shape) 482 | 483 | return tf.contrib.layers.flatten(pooling_1) 484 | 485 | def RNN_last_state(x, lengths, hidden_size): 486 | '''encode x with a gru cell and return the last state. 487 | 488 | Args: 489 | x: a tensor with shape [batch, time, dimension] 490 | length: a tensor with shape [batch] 491 | 492 | Return: 493 | a tensor with shape [batch, hidden_size] 494 | 495 | Raises: 496 | ''' 497 | cell = tf.nn.rnn_cell.GRUCell(hidden_size) 498 | outputs, last_states = tf.nn.dynamic_rnn(cell, x, lengths, dtype=tf.float32) 499 | return outputs, last_states 500 | 501 | 502 | --------------------------------------------------------------------------------
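A minimal usage sketch for the DAM layers above (an illustration under stated assumptions, not part of the repository): it assumes TensorFlow 1.x, toy tensor shapes, and that operations.py provides the op.* helpers referenced by layers.py.

import tensorflow as tf
from qiznlp.common.modules.DAM import layers

# toy inputs: batch=2, sequences of max length 10 and 12, hidden size 64
x = tf.random_normal([2, 10, 64])
y = tf.random_normal([2, 12, 64])
x_len = tf.constant([10, 8])   # true lengths per example
y_len = tf.constant([12, 12])

with tf.variable_scope('dam_demo'):
    sim = layers.similarity(x, y, x_len, y_len)  # expected shape [2, 10, 12]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(sim).shape)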