├── examples ├── data │ ├── templates │ │ ├── neo4j_config.txt │ │ ├── utter_search.txt │ │ └── text_templates.txt │ └── build_upload.py ├── train │ ├── run_train.sh │ ├── run_train.py │ └── train_config.json └── deploy │ ├── run_deploy.sh │ └── run_deploy.py ├── requirements.txt ├── keras_bert_kbqa ├── __init__.py ├── predict.py ├── utils │ ├── __init__.py │ ├── callbacks.py │ ├── decoder.py │ ├── metrics.py │ ├── graph_builder.py │ ├── processor.py │ ├── tokenizer.py │ ├── models.py │ └── bert.py ├── helper.py └── train.py ├── LICENSE ├── README_ZH.md └── README.md /examples/data/templates/neo4j_config.txt: -------------------------------------------------------------------------------- 1 | { 2 | "host": "http://192.168.110.8:12251", 3 | "auth": ("neo4j", "liushaoweihua.") 4 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask == 1.1.1 2 | keras == 2.3.1 3 | numpy == 1.18.1 4 | loguru == 0.4.1 5 | requests == 2.22.0 6 | termcolor == 1.1.0 7 | tensorflow == 1.15.2 8 | keras_contrib == 2.0.8 -------------------------------------------------------------------------------- /examples/train/run_train.sh: -------------------------------------------------------------------------------- 1 | CONFIG_FILE="train_config.json" 2 | SAVE_PATH="../models" 3 | 4 | python run_train.py \ 5 | -config ${CONFIG_FILE} \ 6 | -save_path "../models" \ 7 | -device_map "2" 8 | -------------------------------------------------------------------------------- /keras_bert_kbqa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: __init__.py 8 | @Time: 2020/3/9 10:13 AM 9 | """ 10 | 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function -------------------------------------------------------------------------------- /examples/deploy/run_deploy.sh: -------------------------------------------------------------------------------- 1 | MODEL_DIR="../models" 2 | DATA_DIR="../data" 3 | 4 | python run_deploy.py \ 5 | -model_configs ${MODEL_DIR}/model_configs.json \ 6 | -log_path deploy_log/ \ 7 | -prior_checks ${DATA_DIR}/data/prior_check.txt \ 8 | -database ${DATA_DIR}/data/database.txt \ 9 | -utter_search ${DATA_DIR}/templates/utter_search.txt \ 10 | -device_map "cpu" 11 | -------------------------------------------------------------------------------- /examples/data/templates/utter_search.txt: -------------------------------------------------------------------------------- 1 | { 2 | "豆瓣评分": "','.join([str(item['rate']) for item in app.database if item['title']==\"{}\"])", 3 | "演员有谁": "','.join([item['actor'] for item in app.database if item['title']==\"{}\"][0])", 4 | "演员的作品有什么": "','.join([item['title'] for item in app.database if \"{}\" in item['actor']])", 5 | "电影类型是什么": "','.join([item['category'] for item in app.database if item['title']==\"{}\"][0])", 6 | "电影语言是什么": "','.join([item['language'] for item in app.database if item['title']==\"{}\"][0])", 7 | "电影上映时间是什么": "','.join([str(item['showtime']) for item in app.database if item['title']==\"{}\"])" 8 | } 9 | -------------------------------------------------------------------------------- /keras_bert_kbqa/predict.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: predict.py 8 | @Time: 2020/3/16 3:39 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | from keras.models import load_model 17 | from keras.utils import CustomObjectScope 18 | from .utils import custom_objects 19 | 20 | 21 | def predict(model_path): 22 | """模型预测流程 23 | """ 24 | # 环境设置 25 | with CustomObjectScope(custom_objects): 26 | model = load_model(model_path) 27 | 28 | return model 29 | -------------------------------------------------------------------------------- /examples/train/run_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: models.py 8 | @Time: 2020/3/13 03:58 PM 9 | """ 10 | 11 | 12 | import sys 13 | sys.path.append("../..") 14 | 15 | from keras_bert_kbqa.train import train 16 | from keras_bert_kbqa.helper import train_args_parser 17 | 18 | 19 | def run_train(): 20 | args = train_args_parser() 21 | if True: 22 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 23 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 24 | train(args=args) 25 | 26 | 27 | if __name__ == "__main__": 28 | run_train() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2 | 3 |   Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 |    5 |   The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
7 | -------------------------------------------------------------------------------- /examples/train/train_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "train_data": "../data/data/train_data.txt", 4 | "dev_data": "../data/data/dev_data.txt", 5 | "tag_padding": "X", 6 | "max_len": 15 7 | }, 8 | "bert": { 9 | "bert_config": "/home1/liushaoweihua/pretrained_lm/albert_small_chinese/albert_config.json", 10 | "bert_checkpoint": "/home1/liushaoweihua/pretrained_lm/albert_small_chinese/albert_model.ckpt", 11 | "bert_vocab": "/home1/liushaoweihua/pretrained_lm/albert_small_chinese/vocab.txt", 12 | "albert": "True" 13 | }, 14 | "model": { 15 | "lr": 1e-4, 16 | "batch_size": 256, 17 | "max_epochs": 256, 18 | "early_stop_patience": 10, 19 | "reduce_lr_patience": 3, 20 | "reduce_lr_factor": 0.5, 21 | "all_train_threshold": 0.99, 22 | "clf_configs": { 23 | "clf_type": "dense", 24 | "dense_units": 128 25 | }, 26 | "ner_configs": { 27 | "ner_type": "idcnn", 28 | "filters": 128, 29 | "kernel_size": 3, 30 | "blocks": 4 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: __init__.py 8 | @Time: 2020/3/9 10:13 AM 9 | """ 10 | 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | 17 | from keras.utils import get_custom_objects 18 | from .bert import MultiHeadAttention, LayerNormalization, PositionEmbedding, FeedForward, EmbeddingDense 19 | from .models import ExpandDims, MultiLossLayer 20 | from .metrics import CrfAcc, CrfLoss 21 | from .models import CRF, gelu_erf, gelu_tanh 22 | 23 | 24 | custom_objects = { 25 | "MultiHeadAttention": MultiHeadAttention, 26 | "LayerNormalization": LayerNormalization, 27 | "PositionEmbedding": PositionEmbedding, 28 | "FeedForward": FeedForward, 29 | "EmbeddingDense": EmbeddingDense, 30 | "ExpandDims": ExpandDims, 31 | "MultiLossLayer": MultiLossLayer, 32 | "CrfAcc": CrfAcc, 33 | "CrfLoss": CrfLoss, 34 | "CRF": CRF, 35 | "gelu_erf": gelu_erf, 36 | "gelu_tanh": gelu_tanh, 37 | "gelu": gelu_tanh, 38 | } 39 | 40 | 41 | get_custom_objects().update(custom_objects) -------------------------------------------------------------------------------- /examples/data/templates/text_templates.txt: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "豆瓣评分": [ 4 | "{title}的评分高吗", 5 | "{title}的评分有多少", 6 | "{title}好看吗", 7 | "你觉得{title}怎么样", 8 | "你知道{title}的豆瓣评分怎么样" 9 | ], 10 | "演员有谁": [ 11 | "{title}的主演都有谁啊", 12 | "你知道都有哪些演员演了{title}", 13 | "{title}的演员都有谁", 14 | "{title}是谁的作品啊", 15 | "{title}是谁拍的,这么烂", 16 | "都有谁参演了{title}" 17 | ], 18 | "演员的作品有什么": [ 19 | "{actor}拍了几部什么片啊", 20 | "{actor}的经典作品有些啥", 21 | "哪些电影是{actor}拍的", 22 | "{actor}的作品有什么", 23 | "{actor}演过哪些电影啊", 24 | "有哪些片是{actor}参演的呀", 25 | "{actor}都演过什么片" 26 | ], 27 | "电影类型是什么": [ 28 | "{title}是什么类型的电影", 29 | "{title}是喜剧片吗", 30 | "{title}是一部什么样的电影", 31 | "{title}属于哪种类型" 32 | ], 33 | "电影语言是什么": [ 34 | "{title}的语言是什么", 35 | "{title}是中文电影吗", 36 | "{title}有英文配音版本吗" 37 | ], 38 | "电影上映时间是什么": [ 39 | "{title}是最近刚上的电影吗", 40 | "{title}上映多久了", 41 | "{title}啥时候上映的呀" 42 | ] 43 | }, 44 | "dev": { 45 | "豆瓣评分": [ 46 | "{title}是个好电影吗" 47 | 
], 48 | "演员有谁": [ 49 | "都几个谁演了{title}啊" 50 | ], 51 | "演员的作品有什么": [ 52 | "{actor}演过的电影都有啥" 53 | ], 54 | "电影类型是什么": [ 55 | "{title}的电影类型是什么" 56 | ], 57 | "电影语言是什么": [ 58 | "{title}是闽南语电影吗" 59 | ], 60 | "电影上映时间是什么": [ 61 | "{title}啥时候上电影院的" 62 | ] 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: callbacks.py 8 | @Time: 2020/3/9 10:16 AM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import codecs 17 | import numpy as np 18 | from keras.utils import to_categorical 19 | from keras.callbacks import Callback, ReduceLROnPlateau, EarlyStopping 20 | from .decoder import Viterbi 21 | 22 | 23 | class TaskSwitch(Callback): 24 | """模型开关,达到阈值时,同时训练所有层参数 25 | """ 26 | def __init__(self, monitor, threshold): 27 | super(TaskSwitch, self).__init__() 28 | self.monitor = monitor 29 | self.threshold = threshold 30 | if "acc" in self.monitor: 31 | self.monitor_op = np.greater 32 | elif "loss" in self.monitor: 33 | self.monitor_op = np.less 34 | else: 35 | raise ValueError("monitor is not either 'acc' or 'loss'") 36 | 37 | def on_epoch_end(self, epoch, logs=None): 38 | if self.monitor_op(logs.get(self.monitor), self.threshold): 39 | self.model.stop_training = True 40 | 41 | 42 | def KbqaCallbacks(best_fit_params): 43 | """Kbqa模型训练的指标回调函数 44 | """ 45 | callbacks = [] 46 | early_stopping = EarlyStopping( 47 | monitor="val_loss", 48 | patience=best_fit_params.get("early_stop_patience"), 49 | verbose=1) 50 | reduce_lr_on_plateau = ReduceLROnPlateau( 51 | monitor="val_loss", 52 | factor=best_fit_params.get("reduce_lr_factor"), 53 | patience=best_fit_params.get("reduce_lr_patience"), 54 | verbose=1) 55 | callbacks.extend([early_stopping, reduce_lr_on_plateau]) 56 | return callbacks 57 | -------------------------------------------------------------------------------- /keras_bert_kbqa/helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: help.py 8 | @Time: 2020/3/11 10:50 AM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import os 17 | import sys 18 | import argparse 19 | 20 | 21 | if os.name == "nt": 22 | bert_dir = "" 23 | root_dir = "" 24 | else: 25 | bert_dir = "/home/liushaoweihua/pretrained_lm/bert_chinese/" 26 | root_dir = "/home/projects/kbqa/tools/Keras-Bert-Kbqa/" 27 | 28 | 29 | def train_args_parser(): 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | config_group = parser.add_argument_group( 34 | "Config File Paths", "Config all train information") 35 | config_group.add_argument("-config", 36 | type=str, 37 | required=True, 38 | help="(REQUIRED) train_config.json") 39 | 40 | save_group = parser.add_argument_group( 41 | "Model Output Paths", "Config the output paths for model") 42 | save_group.add_argument("-save_path", 43 | type=str, 44 | default=os.path.join(root_dir, "models"), 45 | help="Model output paths") 46 | 47 | action_group = parser.add_argument_group( 48 | "Action Configs", "Config the 
actions during running") 49 | action_group.add_argument("-device_map", 50 | type=str, 51 | default="cpu", 52 | help="Use CPU/GPU to train. If use CPU, then 'cpu'. " 53 | "If use GPU, then assign the devices, such as '0'. Default is 'cpu'") 54 | 55 | return parser.parse_args() 56 | 57 | 58 | if __name__ == "__main__": 59 | parser_type = sys.argv[1].lower() 60 | if parser_type == "train": 61 | parser = train_args_parser() 62 | else: 63 | raise ValueError("Parser type should be 'train'") 64 | parser.parse_args() 65 | -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: decoder.py 8 | @Time: 2020/3/9 1:46 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | class Viterbi: 20 | 21 | def __init__(self, model, numb_tags): 22 | self.model = model 23 | self.numb_tags = numb_tags 24 | self._get_crf_trans() 25 | 26 | def _get_crf_trans(self): 27 | """CRF转移矩阵 28 | """ 29 | self.crf_trans = {} 30 | crf_weights = self.model.layers[-1].get_weights()[0] 31 | for i in range(self.numb_tags): 32 | for j in range(self.numb_tags): 33 | self.crf_trans[str(i) + "-" + str(j)] = crf_weights[i, j] 34 | 35 | def _viterbi(self, nodes): 36 | """生成路径表 37 | """ 38 | paths = nodes[0] 39 | for l in range(1, len(nodes)): 40 | paths_old, paths = paths, {} 41 | for n, ns in nodes[l].items(): 42 | max_path, max_score = "", -1e10 43 | for p, ps in paths_old.items(): 44 | score = ns + ps + self.crf_trans[p.split("-")[-1] + "-" + str(n)] 45 | if score > max_score: 46 | max_path, max_score = p + "-" + n, score 47 | paths[max_path] = max_score 48 | 49 | return self._max_in_dict(paths) 50 | 51 | def _max_in_dict(self, paths): 52 | """获取路径表中的最大值 53 | """ 54 | paths_inv = {v: k for k, v in paths.items()} 55 | 56 | return paths_inv[max(paths_inv)] 57 | 58 | def decode(self, data): 59 | """解码过程 60 | """ 61 | preds = np.array(self.model.predict(data)) 62 | decodes = [] 63 | for pred in preds: 64 | nodes = [dict([[str(idx), item] for idx, item in enumerate(term)]) for term in pred] 65 | decodes.append([int(item) for item in self._viterbi(nodes).split("-")]) 66 | 67 | return np.array(decodes) -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: metrics.py 8 | @Time: 2020/3/9 1:48 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import keras.backend as K 17 | 18 | 19 | class CrfAcc: 20 | """训练过程中显示的CRF精度 21 | """ 22 | def __init__(self, tag_to_id, mask_tag=None): 23 | self.tag_to_id = tag_to_id 24 | self.mask_tag_id = tag_to_id.get(mask_tag) 25 | self.numb_tags = len(tag_to_id) 26 | 27 | def crf_accuracy(self, y_true, y_pred): 28 | """计算viterbi-crf精度 29 | """ 30 | crf, idx = y_pred._keras_history[:2] 31 | X = crf._inbound_nodes[idx].input_tensors[0] 32 | y_pred = crf.viterbi_decoding(X, None) 33 | return 
self._get_accuracy(y_true, y_pred, crf.sparse_target) 34 | 35 | def _get_accuracy(self, y_true, y_pred, sparse_target=False): 36 | y_pred = K.argmax(y_pred, -1) 37 | mask = K.cast(1. - K.one_hot( 38 | K.squeeze(K.cast(y_true, "int32"), axis=-1), 39 | num_classes=self.numb_tags)[:, :, self.mask_tag_id], K.floatx()) 40 | if sparse_target: 41 | y_true = K.cast(y_true[:, :, 0], K.dtype(y_pred)) 42 | else: 43 | y_true = K.argmax(y_true, -1) 44 | judge = K.cast(K.equal(y_true, y_pred), K.floatx()) 45 | if self.mask_tag_id is None: 46 | return K.mean(judge) 47 | else: 48 | return K.sum(judge * mask) / K.sum(mask) 49 | 50 | class CrfLoss: 51 | """训练过程中显示的CRF损失 52 | """ 53 | def __init__(self, tag_to_id, mask_tag=None): 54 | self.tag_to_id = tag_to_id 55 | self.mask_tag_id = tag_to_id.get(mask_tag) 56 | self.numb_tags = len(tag_to_id) 57 | 58 | def crf_loss(self, y_true, y_pred): 59 | """计算viterbi-crf损失 60 | """ 61 | crf, idx = y_pred._keras_history[:2] 62 | if crf.sparse_target: 63 | y_true = K.one_hot(K.cast(y_true[:, :, 0], "int32"), crf.units) 64 | X = crf._inbound_nodes[idx].input_tensors[0] 65 | mask = K.cast(1. - y_true[:, :, self.mask_tag_id], K.floatx()) if self.mask_tag_id else None 66 | nloglik = crf.get_negative_log_likelihood(y_true, X, mask) 67 | return nloglik -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/graph_builder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: graph_builder.py 8 | @Time: 2020/3/9 1:58 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import json 17 | import codecs 18 | from copy import deepcopy 19 | from py2neo import Graph, Node, Relationship, Subgraph, NodeMatcher, Schema 20 | 21 | 22 | def from_json(path, use_properties=None): 23 | with codecs.open(path, "r", encoding="utf-8") as f: 24 | data = json.load(f) 25 | if use_properties: 26 | data = [{prop:datum[prop] for prop in use_properties} for datum in data] 27 | return data 28 | 29 | 30 | class Creator: 31 | 32 | def __init__(self, graph): 33 | self.graph = Graph(graph.get("host"), auth=graph.get("auth")) 34 | 35 | def __call__(self, *args, **kwargs): 36 | raise NotImplementedError("__call__ function not defined.") 37 | 38 | def _create(self): 39 | raise NotImplementedError("_create function not defined.") 40 | 41 | 42 | class NodeCreator(Creator): 43 | 44 | def __init__(self, labels): 45 | super(NodeCreator, self).__init__() 46 | self.labels = labels 47 | self.node_template = Node(labels) 48 | self.schema = Schema(self.graph) 49 | 50 | def __call__(self, data, indexes=None, *args, **kwargs): 51 | assert isinstance(data, list), "except data to be list, but got %s" % type(data) 52 | nodes = [] 53 | for datum in data: 54 | new_node = deepcopy(self.node_template) 55 | for attr in datum: 56 | new_node[attr] = datum[attr] 57 | nodes.append(new_node) 58 | nodes = Subgraph(nodes) 59 | self._create(nodes, indexes) 60 | 61 | def _create(self, nodes, indexes): 62 | self.graph.create(nodes) 63 | if indexes: 64 | self.schema.create_index(self.labels, *indexes) 65 | 66 | 67 | class RelationCreator(Creator): 68 | """Hint:Py2neo的matcher太慢了,原生的cypher语句也很慢,最好用neo4j的import工具 69 | Reference:https://www.zhihu.com/question/45401120?sort=created 70 | """ 71 | def __init__(self, 
left_node_label, relations, right_node_label): 72 | super(RelationCreator, self).__init__() 73 | self.matcher = NodeMatcher(self.graph) 74 | if isinstance(relations, str): 75 | relations = [relations] 76 | assert 1 <= len(relations) <= 2, "except len(relations) to be either 1 or 2, but got %s" % len(relations) 77 | self.relations = relations 78 | self.relation_type = "unidirectional" if len(relations) == 1 else "bidirectional" 79 | self.left_node_label = left_node_label 80 | self.right_node_label = right_node_label 81 | 82 | def _node_searcher(self, *label, **properties): 83 | return self.matcher.match(*label, **properties).__iter__() 84 | 85 | def __call__(self, data, *args, **kwargs): 86 | assert isinstance(data, list), "except data to be list, but got %s" % type(data) 87 | relations = [] 88 | for datum in data: 89 | left_node = self._node_searcher(self.left_node_label, **datum["left_node"]) 90 | right_node = self._node_searcher(self.right_node_label, **datum["right_node"]) 91 | relations.extend([Relationship(l, self.relations[0], r) for l in left_node for r in right_node]) 92 | if self.relation_type == "bidirectional": 93 | relations.extend([Relationship(r, self.relations[1], l) for l in left_node for r in right_node]) 94 | relations = Subgraph(relations) 95 | self._create(relations) 96 | 97 | def _create(self, relations): 98 | self.graph.create(relations) -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/processor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: processor.py 8 | @Time: 2020/3/3 10:37 AM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | import json 16 | import codecs 17 | import numpy as np 18 | from keras.utils import to_categorical 19 | from .tokenizer import Tokenizer 20 | 21 | 22 | class Processor: 23 | 24 | def __init__(self, data_path, dict_path, tag_padding=None): 25 | self.data = self._load_data(data_path) 26 | if tag_padding is not None: 27 | self.tag_padding = tag_padding 28 | else: 29 | self.tag_padding = "X" 30 | self._load_tags() 31 | self._load_labels() 32 | self.tokenizer = Tokenizer(dict_path) 33 | 34 | def _load_data(self, path): 35 | """加载数据集 36 | """ 37 | with codecs.open(path, "r", encoding="utf-8") as f: 38 | data = json.load(f) 39 | 40 | return data 41 | 42 | def _load_tags(self): 43 | """tag转换为对应id 44 | """ 45 | tags = set() 46 | for item in self.data: 47 | for tag in item[2].split(" "): 48 | if len(tag) == 1: 49 | tags.add(tag) 50 | else: 51 | entity_type = tag.split("-")[-1] 52 | tags.add("B-%s" % entity_type) 53 | tags.add("I-%s" % entity_type) 54 | tags.add("S-%s" % entity_type) 55 | tags = list(tags) 56 | self.tag_to_id = {tags[i]: i for i in range(len(tags))} 57 | if self.tag_padding not in self.tag_to_id: 58 | self.tag_to_id[self.tag_padding] = len(self.tag_to_id) 59 | else: 60 | raise ValueError("tag_padding %s already exists" % self.tag_padding) 61 | self.id_to_tag = {v: k for k, v in self.tag_to_id.items()} 62 | self.numb_tags = len(self.tag_to_id) 63 | 64 | def _load_labels(self): 65 | """label转换为对应id 66 | """ 67 | labels = set() 68 | for item in self.data: 69 | labels.add(item[1]) 70 | labels = list(labels) 71 | self.label_to_id = {labels[i]: i for i in range(len(labels))} 72 | self.id_to_label = {v: k 
for k, v in self.label_to_id.items()} 73 | self.numb_labels = len(self.label_to_id) 74 | 75 | def process(self, path, max_len): 76 | """适配于Bert/Albert的训练数据生成 77 | """ 78 | data = self._load_data(path) 79 | np.random.shuffle(data) 80 | origin_texts, origin_labels, origin_tags = np.stack(data, axis=-1) 81 | tokens, segs = [], [] 82 | for text in origin_texts: 83 | token, seg = self.tokenizer.encode(text, first_length=max_len) 84 | tokens.append(token) 85 | segs.append(seg) 86 | labels = self._norm_labels(origin_labels) 87 | tags = self._pad_and_truncate(origin_tags, max_len - 2) 88 | 89 | return tokens, segs, labels, tags 90 | 91 | def _norm_labels(self, origin_labels): 92 | """将label转化为指定数据形式 93 | """ 94 | return to_categorical([self.label_to_id[i] for i in origin_labels], self.numb_labels) 95 | 96 | def _pad_and_truncate(self, origin_tags, max_len): 97 | """填充或截断至指定长度 98 | """ 99 | tags = [] 100 | for tag in origin_tags: 101 | tag_len = len(tag.split(" ")) 102 | if tag_len >= max_len: 103 | tags.append([self.tag_padding] + 104 | tag.split(" ")[:max_len] + 105 | [self.tag_padding]) 106 | else: 107 | tags.append([self.tag_padding] + 108 | tag.split(" ") + 109 | [self.tag_padding] * (max_len - tag_len + 1)) 110 | tags = np.expand_dims( 111 | [[self.tag_to_id[item] for item in term[1:]] for term in tags], 112 | axis=-1) 113 | 114 | return tags -------------------------------------------------------------------------------- /keras_bert_kbqa/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: train.py 8 | @Time: 2020/3/9 5:58 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import os 17 | import json 18 | import keras 19 | import codecs 20 | import pickle 21 | import numpy as np 22 | from .utils.processor import Processor 23 | from .utils.models import KbqaModel 24 | from .utils.callbacks import KbqaCallbacks, TaskSwitch 25 | from .utils.metrics import CrfAcc, CrfLoss 26 | 27 | 28 | def train(args): 29 | """模型训练流程 30 | """ 31 | # 环境设置 32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_map if args.device_map != "cpu" else "" 33 | if not os.path.exists(args.save_path): 34 | os.makedirs(args.save_path) 35 | with codecs.open(args.config, "r", encoding="utf-8") as f: 36 | args.config = json.load(f) 37 | args.data_params = args.config.get("data") 38 | args.bert_params = args.config.get("bert") 39 | args.model_params = args.config.get("model") 40 | # 数据准备 41 | processor = Processor(args.data_params.get("train_data"), args.bert_params.get("bert_vocab"), args.data_params.get("tag_padding")) 42 | train_tokens, train_segs, train_labels, train_tags = processor.process(args.data_params.get("train_data"), args.data_params.get("max_len")) 43 | train_x = [np.array(train_tokens), np.array(train_segs), np.array(train_labels), np.array(train_tags)] 44 | train_y = None 45 | if args.data_params.get("dev_data") is not None: 46 | dev_tokens, dev_segs, dev_labels, dev_tags = processor.process(args.data_params.get("dev_data"), args.data_params.get("max_len")) 47 | devs = [[np.array(dev_tokens), np.array(dev_segs), np.array(dev_labels), np.array(dev_tags)], None] 48 | else: 49 | devs = None 50 | # 模型准备 51 | model = KbqaModel( 52 | bert_config=args.bert_params.get("bert_config"), 53 | 
bert_checkpoint=args.bert_params.get("bert_checkpoint"), 54 | albert=args.bert_params.get("albert"), 55 | clf_configs=args.model_params.get("clf_configs"), 56 | ner_configs=args.model_params.get("ner_configs"), 57 | max_len=args.data_params.get("max_len"), 58 | numb_labels=processor.numb_labels, 59 | numb_tags=processor.numb_tags, 60 | tag_to_id=processor.tag_to_id, 61 | tag_padding=args.data_params.get("tag_padding")) 62 | model.build() 63 | crf_accuracy = CrfAcc(processor.tag_to_id, args.data_params.get("tag_padding")).crf_accuracy 64 | crf_loss = CrfLoss(processor.tag_to_id, args.data_params.get("tag_padding")).crf_loss 65 | # 模型基础信息 66 | bert_type = "ALBERT" if args.bert_params.get("albert") is "True" else "BERT" 67 | clf_type = args.model_params.get("clf_configs").get("clf_type").upper() 68 | ner_type = args.model_params.get("ner_configs").get("ner_type").upper() + "-CRF" 69 | model_save_path = os.path.abspath( 70 | os.path.join(args.save_path, "%s-%s-%s.h5" % (bert_type, clf_type, ner_type))) 71 | # 训练较难任务 72 | model.hard_train_model.compile( 73 | optimizer=keras.optimizers.Adam(lr=args.model_params.get("lr"), beta_1=0.9, beta_2=0.999, epsilon=1e-8), 74 | loss=crf_loss, 75 | metrics=[crf_accuracy]) 76 | hard_train_model_callbacks = KbqaCallbacks({ 77 | "early_stop_patience": args.model_params.get("early_stop_patience"), 78 | "reduce_lr_patience": args.model_params.get("reduce_lr_patience"), 79 | "reduce_lr_factor": args.model_params.get("reduce_lr_factor")}) 80 | hard_train_model_callbacks.append(TaskSwitch("val_crf_accuracy", args.model_params.get("all_train_threshold"))) 81 | model.hard_train_model.fit( 82 | x=train_x[:2], 83 | y=train_x[3], 84 | batch_size=args.model_params.get("batch_size"), 85 | epochs=args.model_params.get("max_epochs"), 86 | validation_data=[[devs[0][0], devs[0][1]], devs[0][3]], 87 | callbacks=hard_train_model_callbacks) 88 | # 训练所有任务 89 | model.train_all_tasks() 90 | model.full_train_model.compile( 91 | optimizer=model.hard_train_model.optimizer) 92 | model.full_train_model.fit( 93 | x=train_x, 94 | y=train_y, 95 | batch_size=args.model_params.get("batch_size"), 96 | epochs=args.model_params.get("max_epochs"), 97 | validation_data=devs, 98 | callbacks=KbqaCallbacks({ 99 | "early_stop_patience": args.model_params.get("early_stop_patience"), 100 | "reduce_lr_patience": args.model_params.get("reduce_lr_patience"), 101 | "reduce_lr_factor": args.model_params.get("reduce_lr_factor")})) 102 | # 保存信息 103 | with codecs.open(os.path.join(args.save_path, "label_to_id.pkl"), "wb") as f: 104 | pickle.dump(processor.label_to_id, f) 105 | with codecs.open(os.path.join(args.save_path, "id_to_label.pkl"), "wb") as f: 106 | pickle.dump(processor.id_to_label, f) 107 | with codecs.open(os.path.join(args.save_path, "tag_to_id.pkl"), "wb") as f: 108 | pickle.dump(processor.tag_to_id, f) 109 | with codecs.open(os.path.join(args.save_path, "id_to_tag.pkl"), "wb") as f: 110 | pickle.dump(processor.id_to_tag, f) 111 | model_configs = { 112 | "tag_padding": args.data_params.get("tag_padding"), 113 | "max_len": args.data_params.get("max_len"), 114 | "bert_vocab": os.path.abspath(args.bert_params.get("bert_vocab")), 115 | "model_path": model_save_path, 116 | "id_to_label": os.path.abspath(os.path.join(args.save_path, "id_to_label.pkl")), 117 | "id_to_tag": os.path.abspath(os.path.join(args.save_path, "id_to_tag.pkl"))} 118 | with codecs.open(os.path.join(args.save_path, "model_configs.json"), "w") as f: 119 | json.dump(model_configs, f, ensure_ascii=False, indent=4) 120 | 
model.pred_model.save(model_save_path) 121 | -------------------------------------------------------------------------------- /examples/data/build_upload.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: build_upload.py 8 | @Time: 2020/3/13 5:50 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import sys 17 | sys.path.append("../..") 18 | 19 | import os 20 | import json 21 | import codecs 22 | import numpy as np 23 | # from keras_bert_kbqa.utils.graph_builder import NodeCreator, RelationCreator 24 | 25 | 26 | def generate(sample_size, templates, fills, data_type): 27 | """填充式数据生成 28 | """ 29 | results = [] 30 | data = templates.get(data_type) 31 | for label in data: 32 | for text_template in data[label]: 33 | suffix = text_template.split("{")[1].split("}")[0] 34 | tag_template = "O "*(len(text_template.split(suffix)[0])-1) + "{" + suffix + "}" + " O"*(len(text_template.split(suffix)[1])-1) 35 | choose_fills = np.random.choice(fills[suffix], sample_size) 36 | tags = ["B-"+suffix+(" I-"+suffix)*(len(item)-1) if len(item) > 1 else "S-"+suffix for item in choose_fills] 37 | for choose_fill, tag in zip(choose_fills, tags): 38 | results.append([text_template.replace("{"+suffix+"}", choose_fill), label, tag_template.replace("{"+suffix+"}", tag)]) 39 | return results 40 | 41 | 42 | def build(raw_data_path, use_attrs, template_path, save_path, train_sample_size_per_class, dev_sample_size_per_class): 43 | """建立训练集与测试集 44 | """ 45 | with codecs.open(raw_data_path, "r", encoding="utf-8") as f: 46 | origin_data = json.load(f) 47 | data = [] 48 | for datum in origin_data: 49 | new_datum = {} 50 | for attr in use_attrs: 51 | new_datum[attr] = datum.get(attr) or [] 52 | data.append(new_datum) 53 | fills = {attr: [] for attr in use_attrs} 54 | for datum in data: 55 | for attr in datum: 56 | if isinstance(datum[attr], list): 57 | fills[attr].extend(datum[attr]) 58 | else: 59 | fills[attr].append(str(datum[attr])) 60 | with codecs.open(template_path, "r", encoding="utf-8") as f: 61 | templates = json.load(f) 62 | train_data = generate(train_sample_size_per_class, templates, fills, "train") 63 | dev_data = generate(dev_sample_size_per_class, templates, fills, "dev") 64 | save_path = os.path.abspath(save_path) 65 | if not os.path.exists(save_path): 66 | os.makedirs(save_path) 67 | with codecs.open(os.path.join(save_path, "train_data.txt"), "w", encoding="utf-8") as f: 68 | json.dump(train_data, f, ensure_ascii=False, indent=4) 69 | with codecs.open(os.path.join(save_path, "dev_data.txt"), "w", encoding="utf-8") as f: 70 | json.dump(dev_data, f, ensure_ascii=False, indent=4) 71 | with codecs.open(os.path.join(save_path, "database.txt"), "w", encoding="utf-8") as f: 72 | json.dump(data, f, ensure_ascii=False, indent=4) 73 | with codecs.open(os.path.join(save_path, "prior_check.txt"), "w", encoding="utf-8") as f: 74 | json.dump(fills, f, ensure_ascii=False, indent=4) 75 | 76 | 77 | # class Uploader: 78 | # """图数据库内容上传器(有需要的可以用neo4j跑一下试试,预测部分没写) 79 | # """ 80 | # def __init__(self, graph_config_path, raw_data_path): 81 | # self.graph = self._read(graph_config_path) 82 | # self.raw_data = self._read(raw_data_path) 83 | # self.nodes = {} 84 | # self.relations = {} 85 | # 86 | # def _read(self, path): 87 | # with codecs.open(path, 
"r", encoding="utf-8") as f: 88 | # data = json.load(f) 89 | # return data 90 | # 91 | # def define_node(self, node_name, key_property, properties): 92 | # assert isinstance(properties, list), "param `properties` should be type list" 93 | # nodes = [{item: term[item] for item in properties} for term in self.raw_data] 94 | # self.nodes[node_name] = {"key_property": key_property, "properties": properties} 95 | # self._build_node(node_name, nodes, [key_property]) 96 | # return nodes 97 | # 98 | # def define_relation(self, left_node_name, relation_name, right_node_name, build=True): 99 | # assert left_node_name in self.nodes, "`left_node_name` not found" 100 | # assert right_node_name in self.nodes, "`right_node_name` not found" 101 | # left_key_property = self.nodes[left_node_name]["key_property"] 102 | # right_key_property = self.nodes[right_node_name]["key_property"] 103 | # relations = [ 104 | # { 105 | # "left_node": { 106 | # left_key_property: item[left_key_property] 107 | # }, 108 | # "right_node": { 109 | # right_key_property: item[right_key_property] 110 | # } 111 | # } for item in self.raw_data] 112 | # self.relations[relation_name] = { 113 | # "left_node_name": left_node_name, 114 | # "left_key_property": left_key_property, 115 | # "right_node_name": right_node_name, 116 | # "right_key_property": right_key_property 117 | # } 118 | # if build: 119 | # self._build_relation(left_node_name, relation_name, right_node_name, relations) 120 | # return relations 121 | # 122 | # def define_twin_relations(self, left_node_name, relation_names, right_node_name): 123 | # assert isinstance(relation_names, list), "`relation_names` should be type list" 124 | # assert len(relation_names) == 2, "length of `relation_names` should be 2" 125 | # relations = self.define_relation(left_node_name, relation_names[0], right_node_name, build=False) 126 | # _ = self.define_relation(right_node_name, relation_names[1], left_node_name, build=False) 127 | # self._build_relation(left_node_name, relation_names, right_node_name, relations) 128 | # 129 | # def _build_node(self, node_name, nodes, key_property): 130 | # NodeCreator(self.graph, node_name)(nodes, indexes=key_property) 131 | # 132 | # def _build_relation(self, left_node_name, relation_names, right_node_name, relations): 133 | # RelationCreator(self.graph, left_node_name, relation_names, right_node_name)(relations) 134 | 135 | 136 | if __name__ == "__main__": 137 | # build train/dev data 138 | raw_data_path = "./origin_data/douban_movies.txt" 139 | use_attrs = ["title", "rate", "actor", "director", "category", "language", "showtime"] 140 | template_path = "./templates/text_templates.txt" 141 | save_path = "data" 142 | train_sample_size_per_class = 50 143 | dev_sample_size_per_class = 5 144 | build(raw_data_path, use_attrs, template_path, save_path, train_sample_size_per_class, dev_sample_size_per_class) 145 | # # upload data to neo4j 146 | # graph_config_path = "../neo4j_config.txt" 147 | # uploader = Uploader(graph_config_path, raw_data_path) 148 | # uploader.define_node( 149 | # node_name="豆瓣_影片", 150 | # key_property="title", 151 | # properties=use_attrs) 152 | # uploader.define_node( 153 | # node_name="豆瓣_导演", 154 | # key_property="director", 155 | # properties=["director"]) 156 | # uploader.define_node( 157 | # node_name="豆瓣_演员", 158 | # key_property="actor", 159 | # properties=["actor"]) 160 | # uploader.define_twin_relations( 161 | # left_node_name="豆瓣_影片", 162 | # relation_names=["导演", "拍摄"], 163 | # right_node_name="豆瓣_导演") 164 | # 
uploader.define_twin_relations( 165 | # left_node_name="豆瓣_影片", 166 | # relation_names=["演员", "参演"], 167 | # right_node_name="豆瓣_演员") 168 | -------------------------------------------------------------------------------- /examples/deploy/run_deploy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: run_deploy.py 8 | @Time: 2020/3/16 3:33 PM 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import sys 17 | sys.path.append("../..") 18 | 19 | import os 20 | import json 21 | import keras 22 | import codecs 23 | import pickle 24 | import numpy as np 25 | import tensorflow as tf 26 | from loguru import logger 27 | from termcolor import colored 28 | from flask import Flask, Response, request 29 | from keras_bert_kbqa.helper import deploy_args_parser 30 | from keras_bert_kbqa.predict import predict 31 | from keras_bert_kbqa.utils.tokenizer import Tokenizer 32 | 33 | 34 | app = Flask(__name__) 35 | app.model_configs = {} 36 | 37 | 38 | def log_init(log_path): 39 | log_file_path = os.path.join(log_path, "info.log") 40 | err_file_path = os.path.join(log_path, "error.log") 41 | if not os.path.exists(log_path): 42 | os.makedirs(log_path) 43 | logger.add(sys.stderr, format="{time} {level} {message}", 44 | filter="my_module", level="INFO") 45 | logger.add(log_file_path, rotation="12:00", retention="14 days", 46 | encoding="utf-8") 47 | logger.add(err_file_path, rotation="100 MB", retention="14 days", 48 | encoding="utf-8", level="ERROR") 49 | logger.debug("logger initialized") 50 | return logger 51 | 52 | 53 | def run_deploy(): 54 | # 基础配置 55 | args = deploy_args_parser() 56 | if True: 57 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 58 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_map if args.device_map != "cpu" else "" 60 | sess = tf.Session() 61 | graph = tf.get_default_graph() 62 | keras.backend.set_session(sess) 63 | # 属性设置 64 | with codecs.open(args.model_configs, "r", encoding="utf-8") as f: 65 | model_configs = json.load(f) 66 | tokenizer = Tokenizer(model_configs.get("bert_vocab")) 67 | max_len = model_configs.get("max_len") 68 | tag_padding = model_configs.get("tag_padding") 69 | with codecs.open(model_configs.get("id_to_label"), "rb") as f: 70 | id_to_label = pickle.load(f) 71 | with codecs.open(model_configs.get("id_to_tag"), "rb") as f: 72 | id_to_tag = pickle.load(f) 73 | model = predict(model_configs.get("model_path")) 74 | # 新增属性 75 | app.model_configs["logger"] = log_init(args.log_path) 76 | app.model_configs["tokenizer"] = tokenizer 77 | app.model_configs["max_len"] = max_len 78 | app.model_configs["tag_padding"] = tag_padding 79 | app.model_configs["id_to_label"] = id_to_label 80 | app.model_configs["id_to_tag"] = id_to_tag 81 | app.model_configs["model"] = model 82 | with codecs.open(args.prior_checks, "r", encoding="utf-8") as f: 83 | app.prior_checks = json.load(f) 84 | with codecs.open(args.utter_search, "r", encoding="utf-8") as f: 85 | app.utter_search = json.load(f) 86 | with codecs.open(args.database, "r", encoding="utf-8") as f: 87 | app.database = json.load(f) 88 | return graph, sess 89 | 90 | 91 | graph, sess = run_deploy() 92 | 93 | 94 | def 
parse(text): 95 | # 编码 96 | text = text[:app.model_configs["max_len"]] 97 | token, seg = app.model_configs["tokenizer"].encode(text, first_length=app.model_configs["max_len"]) 98 | # 解码 99 | # 显存不足时注释掉以下代码 100 | # config = tf.ConfigProto() 101 | # config.gpu_options.allow_growth = True 102 | # keras.backend.tensorflow_backend.set_session(tf.Session(config=config)) 103 | global graph, sess 104 | with graph.as_default(): 105 | keras.backend.set_session(sess) 106 | clf_pred, ner_pred = app.model_configs["model"].predict([[token], [seg]]) 107 | clf_res = app.model_configs["id_to_label"][np.argmax(clf_pred)] 108 | ner_pred = [app.model_configs["id_to_tag"][item] for item in np.argmax(ner_pred, axis=-1)[0]] 109 | ner_res = get_entity_with_check(text, ner_pred) 110 | subject, response = utter(clf_res, ner_res) 111 | return json.dumps({ 112 | "text": text[:app.model_configs["max_len"]], 113 | "predicate": clf_res, 114 | "subject": subject, 115 | "response": response}, 116 | ensure_ascii=False, indent=4) 117 | 118 | 119 | def get_entity_with_check(text, tokens): 120 | """通过模糊匹配原始数据进行二次确认 121 | 避免小数据集下的CRF转移矩阵训练不准确带来的误识别 122 | """ 123 | entities = get_entity(text, tokens) 124 | checked_entities = [] 125 | for entity in entities: 126 | check_list = app.prior_checks[entity[0]] 127 | if entity[1] not in check_list: 128 | # less_list:表示算法获取的实体值字符少于原始值,进行最大匹配,如【周星】->【周星驰】 129 | # more_list:表示算法获取的实体值字符多于原始值,进行最小匹配,如【周星驰的】->【周星驰】 130 | less_list, more_list = [], [] 131 | for item in check_list: 132 | if entity[1] in item: 133 | less_list.append(item) 134 | if item in entity[1]: 135 | more_list.append(item) 136 | # 粗暴做法:直接获取两个list中的最长串,如果有多个就返回首个 137 | max_match_len = max([len(i) for i in list(set(less_list + more_list))]) if (less_list + more_list) != [] else None 138 | entity[1] = [i for i in list(set(less_list + more_list)) if len(i) == max_match_len][0] \ 139 | if max_match_len is not None else entity[1] 140 | checked_entities.append(entity) 141 | return checked_entities 142 | 143 | 144 | def get_entity(text, tokens): 145 | """获取ner结果 146 | """ 147 | # 如果text长度小于规定的max_len长度,则只保留text长度的tokens 148 | text_len = len(text) 149 | tokens = tokens[:text_len] 150 | 151 | entities = [] 152 | entity = "" 153 | for idx, char, token in zip(range(text_len), text, tokens): 154 | if token.startswith("O") or token.startswith(app.model_configs["tag_padding"]): 155 | token_prefix = token 156 | token_suffix = None 157 | else: 158 | token_prefix, token_suffix = token.split("-") 159 | if token_prefix == "S": 160 | entities.append([token_suffix, char]) 161 | entity = "" 162 | elif token_prefix == "B": 163 | if entity != "": 164 | entities.append([tokens[idx-1].split("-")[-1], entity]) 165 | entity = "" 166 | else: 167 | entity += char 168 | elif token_prefix == "I": 169 | if entity != "": 170 | entity += char 171 | else: 172 | entity = "" 173 | else: 174 | if entity != "": 175 | entities.append([tokens[idx-1].split("-")[-1], entity]) 176 | entity = "" 177 | else: 178 | continue 179 | 180 | return entities 181 | 182 | 183 | def utter(clf_res, ner_res): 184 | subject, response = [], [] 185 | for ner_item in ner_res: 186 | if ner_item[0] in app.utter_search[clf_res]: 187 | print(ner_item) 188 | print(app.database[:3]) 189 | print(app.utter_search[clf_res]) 190 | print(app.utter_search[clf_res].format(ner_item[1])) 191 | subject.append({ner_item[0]: ner_item[1]}) 192 | response = eval(app.utter_search[clf_res].format(ner_item[1])) 193 | else: 194 | response = None 195 | return subject, response 196 | 197 | 198 | def 
first_predict(): 199 | """第一次使用模型时需要加载,否则会降低预测速度 200 | """ 201 | parse("") 202 | 203 | 204 | first_predict() 205 | 206 | 207 | @app.route("/query", methods=["POST"]) 208 | def decode(): 209 | app.model_configs["logger"].info(colored("[RECEIVE]: ", "red") + colored(request.json["text"], "cyan")) 210 | res = parse(request.json["text"]) 211 | app.model_configs["logger"].info(colored("[SEND]: ", "green") + colored(res, "cyan")) 212 | return res 213 | 214 | 215 | if __name__ == "__main__": 216 | app.run(host="0.0.0.0", port=2020, debug=True, use_reloader=False) 217 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | [English Version](https://github.com/liushaoweihua/keras-bert-kbqa/blob/master/README.md) | [中文版说明](https://github.com/liushaoweihua/keras-bert-kbqa/blob/master/README_ZH.md) 2 | 3 | # Keras-Bert-Kbqa 4 | 5 | **预训练语言模型BERT系列**在**知识图谱问答领域**的纯神经网络模型实现尝试,支持**BERT/RoBERTa/ALBERT**。 6 | 7 | ## KBQA的核心任务 8 | 9 | * **知识体系构建(KB)** 10 | * 基于业务特点,梳理知识体系; 11 | * 非结构化输入文本抽取三元组`(主实体Subject,关系Predicate,客实体Object)`,并以特定方式进行存储(通常为图数据库)。 12 | * 如:"周星驰的电影功夫上映于2004年",包含两对三元组`(周星驰,拍摄的电影,功夫)`,`(功夫,上映时间,2004年)`; 13 | 14 | * **标准问答查询(QA)** 15 | * 关系实体抽取 16 | * 查询语句抽取二元组`(主实体Subject,关系Predicate)`; 17 | * 如:"功夫上映于哪一年",包含一对二元组`(功夫,上映时间)`; 18 | * 实体消歧 19 | * 解决同名实体产生歧义的问题; 20 | * 如:周星驰和星爷应对应同一实体; 21 | * 关系链接 22 | * 将抽取得到的实体与关系进行链接,保证链接后的实体关系在知识体系中是有效的; 23 | * 如:豆瓣影评任务下询问"周星驰的母亲叫什么名字",所得到的二元组`(周星驰,母亲)`是非法的,因为知识体系中未建立该关系; 24 | * 结果查询 25 | * 在知识体系中检索合法的关系实体对,获取结果输出。 26 | 27 | ## 涉及内容 28 | 29 | 本项目主要关注**标准问答查询(QA)**任务中的**关系实体抽取**部分。常规KBQA的Query包含以下类别: 30 | * 单跳推导 31 | * 如:"功夫上映于哪一年",二元组为`(功夫,上映时间)`; 32 | * 推导比较 33 | * 如:"功夫的上映时间和赌圣比哪个早",二元组为`((功夫,上映时间)~(赌圣,上映时间))`,需要分别查询结果进行比较; 34 | * 嵌套推导 35 | * 如:"周星驰的母亲的年龄",二元组为`((周星驰,母亲),年龄)`,需要进行嵌套查询; 36 | * 嵌套推导比较 37 | * 如:"周星驰的母亲的年龄和吴孟达的年龄谁大",二元组为`(((周星驰,母亲),年龄)~(吴孟达,年龄))`。 38 | 39 | 项目仅处理最常用的第一种情况,**同时采用全局语义进行分类获取关系,并逐字进行序列标注获取实体**。 40 | 41 | 对于后三种情况,需要**先逐字进行序列标注获取实体,并采用全局语义与实体的局部语义(先验信息)获取多个关系**,目前难点在于所获取的**多关系与多实体的准确链接**与**多任务处理的损失放大**上,纯模型较难处理。 42 | 43 | ## 处理方式 44 | 45 | ### 模型结构 46 | 47 | 单跳推导的两类方法: 48 | * 流水线式(Pipeline)方法:关系分类和实体抽取分成两个任务进行,分别计算loss,不互相影响 49 | * 模型训练简单,为常规NLP下游任务:分类和序列标注; 50 | * 预测速度慢,需要同时输入两个模型; 51 | * 因为模型预测误差,容易出现非法的二元组,如:`(周星驰,上映时间)`,需要执行关系链接操作; 52 | * 联合式(Joint)方法:关系分类和实体抽取采用同一个公共Embedding层进行编码,并采用multi-task的处理方式计算loss,互相影响 53 | * 模型训练困难,两类下游任务的loss及梯度下降速度均不在一个量级上; 54 | * 预测速度快,仅通过单模型即可获得两个下游任务的输出; 55 | * 公共Embedding层编码共用信息的情况下,极少出现非法二元组,避免后续再执行关系链接的操作。 56 | 57 | **项目采用联合式方法进行构建。** 58 | 59 | ### 训练方式 60 | * 先训练较难的序列标注任务,冻结分类任务的下游权重,直至验证集精度达到设定阈值; 61 | * 采用[Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115.pdf)提到的MultiLoss,魔改并计算分类任务和序列标注任务的multi-loss,继承前一阶段的optimizer,同时训练两类下游任务。 62 | 63 | ## 项目框架 64 | 65 | ```bash 66 | keras_bert_kbqa 67 | ├── helper.py # 训练参数帮助文件 68 | ├── __init__.py 69 | ├── predict.py # 加载已训练模型 70 | ├── train.py # 新模型训练与保存 71 | └── utils 72 | ├── bert.py # bert模型的keras实现 73 | ├── callbacks.py # EarlyStopping,ReduceLROnPlateau,TaskSwitch 74 | ├── decoder.py # 序列标注任务的Viterbi解码器 75 | ├── graph_builder.py # neo4j图数据库处理函数,在本项目中未进行使用 76 | ├── __init__.py 77 | ├── metrics.py # 序列标注任务的Crf_accuracy和Crf_loss,支持mask 78 | ├── models.py # 分类任务支持textcnn和dense,序列标注任务支持idcnn-crf和bilstm-crf 79 | ├── processor.py # 标准化训练/验证数据集 80 | └── tokenizer.py # bert模型的分词器 81 | ``` 82 | 83 | ## 依赖项 84 | 85 | * flask == 1.1.1 86 | * keras == 2.3.1 87 | * 
numpy == 1.18.1 88 | * loguru == 0.4.1 89 | * requests == 2.22.0 90 | * termcolor == 1.1.0 91 | * tensorflow == 1.15.2 92 | * keras_contrib == 2.0.8 93 | 94 | ## 案例:豆瓣影评 95 | 96 | * 由于项目主要关注**标准问答查询(QA)**任务中的**关系实体抽取**部分,故而对于其余部分用较为简陋的方式进行实现(未采用图数据库); 97 | * 该案例模仿了工程实践中的线上案例,**存在问法数据量较低的问题,通过模板+填充的方式生成数据与实际数据分布存在较大差异,但仍是目前的主流做法**; 98 | * 测试结果表明在数据量较低的情况下: 99 | * **泛化误差较大**,仅采用模型效果不佳,应与正则、规则结合使用; 100 | * **模型难以训练**。 101 | 102 | 103 | 104 | ### 案例框架 105 | 106 | ```bash 107 | examples 108 | ├── data 109 | │   ├── build_upload.py # 从原始数据中生成训练数据、验证数据等 110 | │   ├── data 111 | │ │ ├── database.txt # 从原始数据中生成的数据库,用于查询结果检索(未使用图数据库) 112 | │   │   ├── dev_data.txt # 验证数据 113 | │   │   ├── prior_check.txt # 双重验证兜底,对于算法识别实体错误的结果进行纠正 114 | │   │   └── train_data.txt # 训练数据 115 | │   ├── origin_data 116 | │   │   └── douban_movies.txt # 原始数据 117 | │   └── templates 118 | │   ├── neo4j_config.txt # 图数据库配置文件,在本项目中未进行使用 119 | │   ├── text_templates.txt # 训练/验证数据生成模板 120 | │   └── utter_search.txt # 问题查询数据库命令(未使用图数据库,因此写的比较丑陋) 121 | ├── deploy # 发布使用 122 | │   ├── run_deploy.py 123 | │   └── run_deploy.sh 124 | ├── models # 模型信息保存位置 125 | │   ├── ALBERT-IDCNN-CRF.h5 126 | │   ├── id_to_label.pkl 127 | │   ├── id_to_tag.pkl 128 | │   ├── label_to_id.pkl 129 | │   ├── model_configs.json 130 | │   └── tag_to_id.pkl 131 | └── train # 模型训练 132 | ├── run_train.py 133 | ├── run_train.sh 134 | └── train_config.json # 训练配置文件 135 | 136 | ``` 137 | 138 | ### 数据形式 139 | 140 | 训练、验证数据形式为`[文本信息,类别信息,序列标注信息]`,如下: 141 | 142 | ```json 143 | [ 144 | [ 145 | "骗中骗的评分高吗", 146 | "豆瓣评分", 147 | "B-title I-title I-title O O O O O" 148 | ], 149 | [ 150 | "安东尼娅家族啥时候上映的呀", 151 | "电影上映时间是什么", 152 | "B-title I-title I-title I-title I-title I-title O O O O O O O" 153 | ], 154 | ... 155 | ] 156 | ``` 157 | 158 | ### 训练参数配置的一些技巧 159 | 160 | 该部分内容位于`examples/train/train_config.json`中: 161 | 162 | * 句长参数`max_len`应适配于训练、测试文本的长度,过长的句长将占用较大的显存,且对于序列标注任务的收敛影响较大; 163 | * 在数据量较低的情况下,ALBERT模型比BERT模型更易训练,且效果与BERT模型相差不大; 164 | * `all_train_threshold`表示序列标注任务的验证精度达到该值时,同时训练分类任务和序列标注任务: 165 | * 该值过小将导致序列标注任务无法收敛,而分类任务易过拟合; 166 | * 该值过大将导致分类任务欠拟合; 167 | * 建议取值在0.9~0.98之间; 168 | * `clf_type`可取`textcnn`和`dense`: 169 | * 为`textcnn`时,其余参数为`dense_units`,`dropout_rate`,`filters`和`kernel_size`; 170 | * 为`dense`时,其余参数为`dense_units`; 171 | * `ner_type`可取`idcnn`和`bilstm`: 172 | * 为`idcnn`时,其余参数为`filters`,`kernel_size`和`blocks`; 173 | * 为`bilstm`时,其余参数为`units`,`num_hidden_layers`和`dropout_rate`。 174 | 175 | ### 执行流程 176 | 177 | ```bash 178 | python examples/data/build_upload.py # 生成examples/data/data中的所有文件 179 | bash examples/train/run_train.sh # 训练模型 180 | bash examples/deploy/run_deploy.sh # 使用模型 181 | ``` 182 | 183 | ### 模型使用 184 | 185 | 调用接口: 186 | 187 | ```python 188 | import requests 189 | 190 | r = requests.post( 191 | "http://your_ip:your_port/query", 192 | json={ 193 | "text": "大话西游之大圣娶亲是最近刚上的电影吗"}) 194 | 195 | print(r.text) 196 | ``` 197 | 198 | 接口返回结果: 199 | 200 | ```json 201 | { 202 | "text": "大话西游之大圣娶亲是最近刚上的", 203 | "predicate": "电影上映时间是什么", 204 | "subject": [ 205 | { 206 | "title": "大话西游之大圣娶亲" 207 | } 208 | ], 209 | "response": "2014" 210 | } 211 | ``` 212 | 213 | ## 未来工作 214 | 215 | * 优化训练难度,使模型更容易训练; 216 | * 尝试处理更为复杂的KBQA场景; 217 | * 项目细节完善; 218 | * 项目迁移至tensorflow 2.0; 219 | * 新增BERT系列改进模型,如Distill Bert,Tiny Bert。 220 | 221 | ## 一些常用的中文预训练模型 222 | 223 | > **BERT** 224 | * [Google_bert](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) 225 | * 
[HIT_bert_wwm_ext](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_wwm_ext_L-12_H-768_A-12.zip) 226 | 227 | > **ALBERT** 228 | * [Google_albert_base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz) 229 | * [Google_albert_large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz) 230 | * [Google_albert_xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz) 231 | * [Google_albert_xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz) 232 | * [Xuliang_albert_xlarge](https://storage.googleapis.com/albert_zh/albert_xlarge_zh_177k.zip) 233 | * [Xuliang_albert_large](https://storage.googleapis.com/albert_zh/albert_large_zh.zip) 234 | * [Xuliang_albert_base](https://storage.googleapis.com/albert_zh/albert_base_zh.zip) 235 | * [Xuliang_albert_base_ext](https://storage.googleapis.com/albert_zh/albert_base_zh_additional_36k_steps.zip) 236 | * [Xuliang_albert_small](https://storage.googleapis.com/albert_zh/albert_small_zh_google.zip) 237 | * [Xuliang_albert_tiny](https://storage.googleapis.com/albert_zh/albert_tiny_zh_google.zip) 238 | 239 | > **Roberta** 240 | * [roberta](https://storage.googleapis.com/chineseglue/pretrain_models/roeberta_zh_L-24_H-1024_A-16.zip) 241 | * [roberta_wwm_ext](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_roberta_wwm_ext_L-12_H-768_A-12.zip) 242 | * [roberta_wwm_ext_large](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.zip) 243 | 244 | ## 参考 245 | * 同源项目:[Keras-Bert-Ner](https://github.com/liushaoweihua/keras-bert-ner) 246 | * 项目的BERT代码参考:[bert4keras](https://github.com/bojone/bert4keras) 247 | * ALBERT中文预训练模型系列,更快的推理时间和较高的预测精度:[albert_zh](https://github.com/brightmart/albert_zh) 248 | * [BERT](https://github.com/google-research/bert), [ALBERT](https://github.com/google-research/albert), [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)。 249 | 250 | 感谢以上作者和项目的贡献! 
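
## 附录:MultiLoss 思路示意

「训练方式」一节提到的不确定性加权 multi-loss,其核心思想可以用下面这段最小草图说明(仅为帮助理解的示意代码,层名 `UncertaintyWeightedLoss` 及各变量名均为假设,并非本项目 `keras_bert_kbqa/utils/models.py` 中 `MultiLossLayer` 的原始实现):

```python
# 示意:Kendall et al. (2017) 提出的不确定性加权多任务损失的最小草图
# 注意:层名与变量名均为假设,并非本项目 MultiLossLayer 的实现
import keras.backend as K
from keras.layers import Layer


class UncertaintyWeightedLoss(Layer):
    """为每个任务学习一个可训练的 log(sigma^2),常见简化形式为:
    total_loss = sum_i exp(-log_var_i) * loss_i + log_var_i
    """
    def __init__(self, numb_tasks=2, **kwargs):
        super(UncertaintyWeightedLoss, self).__init__(**kwargs)
        self.numb_tasks = numb_tasks

    def build(self, input_shape):
        # 每个任务对应一个标量 log_var,初始为 0(即 sigma = 1)
        self.log_vars = [
            self.add_weight(name="log_var_%d" % i, shape=(1,),
                            initializer="zeros", trainable=True)
            for i in range(self.numb_tasks)]
        super(UncertaintyWeightedLoss, self).build(input_shape)

    def call(self, task_losses):
        # task_losses:各任务的标量损失张量列表,如 [clf_loss, crf_loss]
        total = 0.
        for loss, log_var in zip(task_losses, self.log_vars):
            total += K.exp(-log_var[0]) * loss + log_var[0]
        self.add_loss(total)
        return total
```

损失较大(不确定性较高)的任务会被自动降权,从而缓解分类任务与序列标注任务损失量级不一致的问题,训练时无需手工调节任务权重。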
-------------------------------------------------------------------------------- /keras_bert_kbqa/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: tokenizer.py 8 | @Time: 2020/3/3 10:37 AM 9 | """ 10 | 11 | # Codes come from : 12 | # Author: Jianlin Su 13 | # Github: https://github.com/bojone/bert4keras 14 | # Site: kexue.fm 15 | # Version: 0.2.5 16 | 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import re 24 | import six 25 | import codecs 26 | import unicodedata 27 | import numpy as np 28 | 29 | 30 | def is_string(s): 31 | if not six.PY2: 32 | return isinstance(s, str) 33 | 34 | 35 | def load_vocab(dict_path): 36 | """从Bert词典文件中读取词典 37 | """ 38 | token_dict = {} 39 | with codecs.open(dict_path, encoding="utf-8") as reader: 40 | for line in reader: 41 | token = line.strip() 42 | token_dict[token] = len(token_dict) 43 | 44 | return token_dict 45 | 46 | 47 | class BasicTokenizer(object): 48 | """分词器基类 49 | """ 50 | def __init__(self): 51 | """初始化 52 | """ 53 | self._token_pad = "[PAD]" 54 | self._token_cls = "[CLS]" 55 | self._token_sep = "[SEP]" 56 | self._token_unk = "[UNK]" 57 | self._token_mask = "[MASK]" 58 | 59 | def tokenize(self, text, add_cls=True, add_sep=True): 60 | """分词函数 61 | """ 62 | tokens = self._tokenize(text) 63 | if add_cls: 64 | tokens.insert(0, self._token_cls) 65 | if add_sep: 66 | tokens.append(self._token_sep) 67 | return tokens 68 | 69 | def token_to_id(self, token): 70 | """token转换为对应的id 71 | """ 72 | raise NotImplementedError 73 | 74 | def tokens_to_ids(self, tokens): 75 | """token序列转换为对应的id序列 76 | """ 77 | return [self.token_to_id(token) for token in tokens] 78 | 79 | def truncate_sequence(self, 80 | max_length, 81 | first_sequence, 82 | second_sequence=None): 83 | """截断总长度 84 | """ 85 | if second_sequence is None: 86 | second_sequence = [] 87 | 88 | while True: 89 | total_length = len(first_sequence) + len(second_sequence) 90 | if total_length <= max_length: 91 | break 92 | elif len(first_sequence) > len(second_sequence): 93 | first_sequence.pop() 94 | else: 95 | second_sequence.pop() 96 | 97 | def encode(self, 98 | first_text, 99 | second_text=None, 100 | max_length=None, 101 | first_length=None, 102 | second_length=None): 103 | """输出文本对应token id和segment id 104 | 如果传入first_length,则强行padding第一个句子到指定长度; 105 | 同理,如果传入second_length,则强行padding第二个句子到指定长度。 106 | """ 107 | first_tokens = self.tokenize(first_text, add_cls=False, add_sep=False) 108 | if second_text is None: 109 | if max_length is not None: 110 | first_tokens = first_tokens[:max_length - 2] 111 | else: 112 | second_tokens = self.tokenize(second_text, 113 | add_cls=False, 114 | add_sep=False) 115 | if max_length is not None: 116 | self.truncate_sequence(max_length - 3, first_tokens, second_tokens) 117 | 118 | first_tokens = [self._token_cls] + first_tokens + [self._token_sep] 119 | first_token_ids = self.tokens_to_ids(first_tokens) 120 | if first_length is not None: 121 | first_token_ids = first_token_ids[:first_length] 122 | first_token_ids.extend([self._token_pad_id] * 123 | (first_length - len(first_token_ids))) 124 | first_segment_ids = [0] * len(first_token_ids) 125 | 126 | if second_text is not None: 127 | second_tokens = second_tokens + [self._token_sep] 128 | second_token_ids = 
self.tokens_to_ids(second_tokens) 129 | if second_length is not None: 130 | second_token_ids = second_token_ids[:second_length] 131 | second_token_ids.extend( 132 | [self._token_pad_id] * 133 | (second_length - len(second_token_ids))) 134 | second_segment_ids = [1] * len(second_token_ids) 135 | 136 | first_token_ids.extend(second_token_ids) 137 | first_segment_ids.extend(second_segment_ids) 138 | 139 | return first_token_ids, first_segment_ids 140 | 141 | def id_to_token(self, i): 142 | """id序列为对应的token 143 | """ 144 | raise NotImplementedError 145 | 146 | def ids_to_tokens(self, ids): 147 | """id序列转换为对应的token序列 148 | """ 149 | return [self.id_to_token(i) for i in ids] 150 | 151 | def decode(self, ids): 152 | """转为可读文本 153 | """ 154 | raise NotImplementedError 155 | 156 | def _tokenize(self, text): 157 | """基本分词函数 158 | """ 159 | raise NotImplementedError 160 | 161 | 162 | class Tokenizer(BasicTokenizer): 163 | """Bert原生分词器 164 | 纯Python实现,代码修改自keras_bert的tokenizer实现 165 | """ 166 | def __init__(self, token_dict, case_sensitive=True): 167 | """初始化 168 | """ 169 | super(Tokenizer, self).__init__() 170 | if is_string(token_dict): 171 | token_dict = load_vocab(token_dict) 172 | 173 | self._token_dict = token_dict 174 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 175 | self._case_sensitive = case_sensitive 176 | for token in ["pad", "cls", "sep", "unk", "mask"]: 177 | try: 178 | _token_id = token_dict[getattr(self, "_token_%s" % token)] 179 | setattr(self, "_token_%s_id" % token, _token_id) 180 | except: 181 | pass 182 | self._vocab_size = len(token_dict) 183 | 184 | def token_to_id(self, token): 185 | """token转换为对应的id 186 | """ 187 | return self._token_dict.get(token, self._token_unk_id) 188 | 189 | def id_to_token(self, i): 190 | """id转换为对应的token 191 | """ 192 | return self._token_dict_inv[i] 193 | 194 | def decode(self, ids): 195 | """转为可读文本 196 | """ 197 | tokens = self.ids_to_tokens(ids) 198 | tokens = [token for token in tokens if not self._is_special(token)] 199 | 200 | text, flag = "", False 201 | for i, token in enumerate(tokens): 202 | if token[:2] == "##": 203 | text += token[2:] 204 | elif len(token) == 1 and self._is_cjk_character(token): 205 | text += token 206 | elif len(token) == 1 and self._is_punctuation(token): 207 | text += token 208 | text += " " 209 | elif i > 0 and self._is_cjk_character(text[-1]): 210 | text += token 211 | else: 212 | text += " " 213 | text += token 214 | 215 | text = re.sub(" +", " ", text) 216 | text = re.sub("\" (re|m|s|t|ve|d|ll) ", "\'\\1 ", text) 217 | punctuation = self._cjk_punctuation() + "+-/={(<[" 218 | punctuation_regex = "|".join([re.escape(p) for p in punctuation]) 219 | punctuation_regex = "(%s) " % punctuation_regex 220 | text = re.sub(punctuation_regex, "\\1", text) 221 | 222 | return text.strip() 223 | 224 | def _tokenize(self, text): 225 | """基本分词函数 226 | """ 227 | if not self._case_sensitive: 228 | text = unicodedata.normalize("NFD", text) 229 | text = "".join( 230 | [ch for ch in text if unicodedata.category(ch) != "Mn"]) 231 | text = text.lower() 232 | 233 | spaced = "" 234 | for ch in text: 235 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 236 | spaced += " " + ch + " " 237 | elif self._is_space(ch): 238 | spaced += " " 239 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 240 | continue 241 | else: 242 | spaced += ch 243 | 244 | tokens = [] 245 | for word in spaced.strip().split(): 246 | tokens.extend(self._word_piece_tokenize(word)) 247 | 248 | return tokens 249 | 250 | def 
_word_piece_tokenize(self, word): 251 | """word内分成subword 252 | """ 253 | if word in self._token_dict: 254 | return [word] 255 | 256 | tokens = [] 257 | start, stop = 0, 0 258 | while start < len(word): 259 | stop = len(word) 260 | while stop > start: 261 | sub = word[start:stop] 262 | if start > 0: 263 | sub = "##" + sub 264 | if sub in self._token_dict: 265 | break 266 | stop -= 1 267 | if start == stop: 268 | stop += 1 269 | tokens.append(sub) 270 | start = stop 271 | 272 | return tokens 273 | 274 | @staticmethod 275 | def _is_space(ch): 276 | """空格类字符判断 277 | """ 278 | return ch == " " or ch == "\n" or ch == "\r" or ch == "\t" or \ 279 | unicodedata.category(ch) == "Zs" 280 | 281 | @staticmethod 282 | def _is_punctuation(ch): 283 | """标点符号类字符判断(全/半角均在此内) 284 | """ 285 | code = ord(ch) 286 | return 33 <= code <= 47 or \ 287 | 58 <= code <= 64 or \ 288 | 91 <= code <= 96 or \ 289 | 123 <= code <= 126 or \ 290 | unicodedata.category(ch).startswith("P") 291 | 292 | @staticmethod 293 | def _cjk_punctuation(): 294 | return u"\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\xb7\uff01\uff1f\uff61\u3002" 295 | 296 | @staticmethod 297 | def _is_cjk_character(ch): 298 | """CJK类字符判断(包括中文字符也在此列) 299 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 300 | """ 301 | code = ord(ch) 302 | return 0x4E00 <= code <= 0x9FFF or \ 303 | 0x3400 <= code <= 0x4DBF or \ 304 | 0x20000 <= code <= 0x2A6DF or \ 305 | 0x2A700 <= code <= 0x2B73F or \ 306 | 0x2B740 <= code <= 0x2B81F or \ 307 | 0x2B820 <= code <= 0x2CEAF or \ 308 | 0xF900 <= code <= 0xFAFF or \ 309 | 0x2F800 <= code <= 0x2FA1F 310 | 311 | @staticmethod 312 | def _is_control(ch): 313 | """控制类字符判断 314 | """ 315 | return unicodedata.category(ch) in ("Cc", "Cf") 316 | 317 | @staticmethod 318 | def _is_special(ch): 319 | """判断是不是有特殊含义的符号 320 | """ 321 | return bool(ch) and (ch[0] == "[") and (ch[-1] == "]") -------------------------------------------------------------------------------- /keras_bert_kbqa/utils/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: models.py 8 | @Time: 2020/3/9 2:43 PM 9 | """ 10 | 11 | # Some codes come from : 12 | # Author: Jianlin Su 13 | # Github: https://github.com/bojone/bert4keras 14 | # Site: kexue.fm 15 | # Version: 0.2.5 16 | 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import keras 24 | import numpy as np 25 | import tensorflow as tf 26 | import keras.backend as K 27 | from keras import Model 28 | from keras.initializers import Constant 29 | from keras_contrib.layers import CRF 30 | from .bert import * 31 | from .metrics import CrfAcc, CrfLoss 32 | 33 | 34 | def set_gelu(version): 35 | """设置gelu版本 36 | """ 37 | version = version.lower() 38 | assert version in ["erf", "tanh"], "gelu version must be erf or tanh" 39 | if version == "erf": 40 | keras.utils.get_custom_objects()["gelu"] = gelu_erf 41 | else: 42 | 
keras.utils.get_custom_objects()["gelu"] = gelu_tanh 43 | 44 | 45 | def gelu_erf(x): 46 | """基于Erf直接计算的gelu函数 47 | """ 48 | return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0))) 49 | 50 | 51 | def gelu_tanh(x): 52 | """基于Tanh近似计算的gelu函数 53 | """ 54 | cdf = 0.5 * (1.0 + K.tanh( 55 | (np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3))))) 56 | return x * cdf 57 | 58 | 59 | set_gelu("tanh") 60 | 61 | 62 | class KbqaModel: 63 | """Bert Kbqa模型基础类 64 | """ 65 | def __init__(self, 66 | bert_config, 67 | bert_checkpoint, 68 | albert, 69 | clf_configs, 70 | ner_configs, 71 | max_len, 72 | numb_labels, 73 | numb_tags, 74 | tag_to_id, 75 | tag_padding): 76 | super(KbqaModel, self).__init__() 77 | self._build_bert_model(bert_config, bert_checkpoint, albert) 78 | self.clf_configs = clf_configs 79 | self.ner_configs = ner_configs 80 | self.max_len = max_len 81 | self.numb_labels = numb_labels 82 | self.numb_tags = numb_tags 83 | self.tag_to_id = tag_to_id 84 | self.tag_padding = tag_padding 85 | 86 | def _build_bert_model(self, bert_config, bert_checkpoint, albert): 87 | self.bert_model = build_bert_model( 88 | bert_config, 89 | bert_checkpoint, 90 | albert=albert) 91 | for layer in self.bert_model.layers: 92 | layer.trainable = True 93 | 94 | def train_all_tasks(self): 95 | """是否开启所有冻结的层进行训练 96 | """ 97 | for layer in self.full_train_model.layers: 98 | layer.trainable = True 99 | 100 | def build(self): 101 | """Kbqa模型 102 | """ 103 | # 1. Embeddings层建立 104 | x_in = Input(shape=(self.max_len,), name="Origin-Input-Token") 105 | s_in = Input(shape=(self.max_len,), name="Origin-Input-Segment") 106 | x = self.bert_model([x_in, s_in]) 107 | clf_x = Lambda(lambda X: X[:, 0], name="Clf-Embedding")(x) 108 | ner_x = Lambda(lambda X: X[:, 1:], name="Ner-Embedding")(x) 109 | 110 | # 2. 下游网络层建立 111 | # clf任务 112 | clf_o = self.clf(clf_x) 113 | # ner任务,添加了clf_o的先验信息 114 | ner_o = self.ner(ner_x, clf_o) 115 | 116 | # 3. 
返回训练模型与预测模型 117 | clf_true = Input(shape=(self.numb_labels,), name="Clf-True") 118 | ner_true = Input(shape=(self.max_len - 1, 1), name="Ner-True") 119 | clf_loss = K.categorical_crossentropy(clf_true, clf_o) 120 | clf_acc = keras.metrics.categorical_accuracy(clf_true, clf_o) 121 | ner_loss = CrfLoss(self.tag_to_id, self.tag_padding).crf_loss(ner_true, ner_o) 122 | ner_acc = CrfAcc(self.tag_to_id, self.tag_padding).crf_accuracy(ner_true, ner_o) 123 | train_o = MultiLossLayer(self.tag_to_id, self.tag_padding)([clf_true, ner_true, clf_o, ner_o]) 124 | # 较难任务训练模型 125 | self.hard_train_model = Model([x_in, s_in], ner_o) 126 | # 全训练模型 127 | self.full_train_model = Model([x_in, s_in, clf_true, ner_true], train_o) 128 | self.full_train_model.add_metric(clf_loss, name="clf_loss") 129 | self.full_train_model.add_metric(clf_acc, name="clf_acc") 130 | self.full_train_model.add_metric(ner_loss, name="ner_loss") 131 | self.full_train_model.add_metric(ner_acc, name="ner_acc") 132 | self.full_train_model.train_all_tasks = self.train_all_tasks 133 | # 预测模型 134 | self.pred_model = Model([x_in, s_in], [clf_o, ner_o]) 135 | 136 | def clf(self, clf_x): 137 | """分类任务模型,由于上游模型已经采用bert,所以下游模型尽可能简单 138 | 简单给了TextCNN和Dense两类样例 139 | """ 140 | # 配置解析 141 | clf_type = self.clf_configs.get("clf_type").lower() 142 | assert clf_type in ["textcnn", "dense"], "clf_type should be 'textcnn' or 'dense'" 143 | dropout_rate = self.clf_configs.get("dropout_rate") 144 | dense_units = self.clf_configs.get("dense_units") 145 | 146 | # clf模型定义 147 | def textcnn(clf_x): 148 | clf_x = Lambda(lambda X: K.expand_dims(X, axis=-1))(clf_x) 149 | clf_pool_output = [] 150 | for kernel_size in self.clf_configs.get("kernels"): 151 | clf_conv = Conv1D(filters=self.clf_configs.get("filters"), kernel_size=kernel_size, strides=1, padding="same", 152 | activation="relu", name="Clf-Conv-%s" % kernel_size, trainable=False)(clf_x) 153 | clf_pool = MaxPooling1D(name="Clf-MaxPooling-%s" % kernel_size)(clf_conv) 154 | clf_pool_output.append(clf_pool) 155 | clf_o = concatenate(clf_pool_output) 156 | clf_o = Dropout(self.clf_configs.get("dropout_rate"), name="Clf-Dropout", trainable=False)(clf_o) 157 | clf_o = Flatten(name="Clf-Flatten")(clf_o) 158 | clf_o = Dense(self.clf_configs.get("dense_units"), activation="relu", name="Clf-Dense-In", trainable=False)(clf_o) 159 | clf_o = Dense(self.numb_labels, activation="softmax", name="Clf-Dense-Out", trainable=False)(clf_o) 160 | return clf_o 161 | 162 | def dense(clf_x): 163 | clf_o = Dense(self.clf_configs.get("dense_units"), activation="relu", name="Clf-Dense-In", trainable=False)(clf_x) 164 | clf_o = Dense(self.numb_labels, activation="softmax", name="Clf-Dense-Out", trainable=False)(clf_o) 165 | return clf_o 166 | 167 | # 模型构建 168 | if clf_type == "textcnn": 169 | clf_o = textcnn(clf_x) 170 | else: 171 | clf_o = dense(clf_x) 172 | return clf_o 173 | 174 | def ner(self, ner_x, clf_o): 175 | """序列标注任务模型,需要添加clf_o的先验信息 176 | 给了Idcnn和Bilstm两类样例 177 | """ 178 | # 配置解析 179 | ner_type = self.ner_configs.get("ner_type").lower() 180 | assert ner_type in ["idcnn", "bilstm"], "ner_type should be 'idcnn' or 'bilstm'" 181 | 182 | # ner模型定义 183 | def idcnn(ner_x): 184 | def dilation_conv1d(dilation_rate, name): 185 | return Conv1D(self.ner_configs.get("filters"), self.ner_configs.get("kernel_size"), padding="same", dilation_rate=dilation_rate, name=name) 186 | 187 | def idcnn_block(name): 188 | return [dilation_conv1d(1, name + "1"), dilation_conv1d(1, name + "2"), dilation_conv1d(2, name + "3")] 189 | 190 | ner_o = [] 
191 | for layer_idx in range(self.ner_configs.get("blocks")): 192 | name = "Idcnn-Block-%s-Layer-" % layer_idx 193 | idcnns = idcnn_block(name) 194 | cnn = idcnns[0](ner_x) 195 | cnn = idcnns[1](cnn) 196 | cnn = idcnns[2](cnn) 197 | ner_o.append(cnn) 198 | ner_o = concatenate(ner_o, axis=-1) 199 | return ner_o 200 | 201 | def bilstm(ner_x): 202 | for layer_idx in range(self.ner_configs.get("num_hidden_layers")): 203 | name = "Bilstm-Layer-%s" % layer_idx 204 | ner_x = Bidirectional(LSTM(units=self.ner_configs.get("units"), return_sequences=True, recurrent_dropout=self.ner_configs.get("dropout_rate")), 205 | name=name)(ner_x) 206 | return ner_x 207 | 208 | # 模型构建 209 | clf_o = ExpandDims(self.max_len-1, name="Clf-Prior")(clf_o) 210 | ner_x = Concatenate(name="Ner-Clf-Joint")([ner_x, clf_o]) 211 | if ner_type == "idcnn": 212 | ner_o = idcnn(ner_x) 213 | else: 214 | ner_o = bilstm(ner_x) 215 | ner_o = CRF(self.numb_tags, sparse_target=True, name="Ner-CRF")(ner_o) 216 | return ner_o 217 | 218 | 219 | class ExpandDims(Layer): 220 | """需要写成自定义层的形式而不能用Lambda函数写,在保存时会出错 221 | """ 222 | def __init__(self, max_len, **kwargs): 223 | self.max_len = max_len 224 | super(ExpandDims, self).__init__(**kwargs) 225 | 226 | def call(self, inputs): 227 | return K.tile(K.expand_dims(inputs, 1), [1, self.max_len, 1]) 228 | 229 | def get_config(self): 230 | config = { 231 | "max_len": self.max_len} 232 | base_config = super(ExpandDims, self).get_config() 233 | return dict(list(base_config.items()) + list(config.items())) 234 | 235 | def compute_output_shape(self, input_shape): 236 | return (input_shape[0], self.max_len, input_shape[1]) 237 | 238 | 239 | class MultiLossLayer(Layer): 240 | """以下论文提出方法的魔改 241 | Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 242 | paper: http://openaccess.thecvf.com/content_cvpr_2018/html/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.html 243 | Keras code: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example.ipynb 244 | """ 245 | def __init__(self, tag_to_id, tag_padding): 246 | self.tag_to_id = tag_to_id 247 | self.tag_padding = tag_padding 248 | self.loss_funcs = [K.categorical_crossentropy, CrfLoss(tag_to_id, tag_padding).crf_loss] 249 | super(MultiLossLayer, self).__init__() 250 | 251 | def build(self, input_shape=None): 252 | self.log_vars = [] 253 | for i in range(2): 254 | self.log_vars += [ 255 | self.add_weight(name="log_var" + str(i), shape=(1,), initializer=Constant(0.), trainable=False)] 256 | super(MultiLossLayer, self).build(input_shape) 257 | 258 | def multi_loss(self, ys_true, ys_pred): 259 | loss = 0 260 | for y_true, y_pred, loss_func, log_var in zip(ys_true, ys_pred, self.loss_funcs, self.log_vars): 261 | precision = K.exp(-log_var[0]) 262 | loss += K.sum(precision * loss_func(y_true, y_pred) + log_var[0], -1) 263 | return K.mean(loss) 264 | 265 | def call(self, inputs): 266 | ys_true = inputs[:2] 267 | ys_pred = inputs[2:] 268 | loss = self.multi_loss(ys_true, ys_pred) 269 | self.add_loss(loss, inputs=inputs) 270 | return inputs 271 | 272 | def get_config(self): 273 | config = { 274 | "tag_to_id": self.tag_to_id, 275 | "tag_padding": self.tag_padding} 276 | base_config = super(MultiLossLayer, self).get_config() 277 | return dict(list(base_config.items()) + list(config.items())) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [English 
Version](https://github.com/liushaoweihua/keras-bert-kbqa/blob/master/README.md) | [中文版说明](https://github.com/liushaoweihua/keras-bert-kbqa/blob/master/README_ZH.md) 2 | 3 | # Keras-Bert-Kbqa 4 | 5 | Implementation of *Neural Network (NN)* models for *Knowledge Based Question Answering (KBQA)* with pre-trained language models: supports BERT/RoBERTa/ALBERT. 6 | 7 | ## Core Tasks of KBQA 8 | 9 | * **Build Knowledge System (KB)** 10 | * Organize the knowledge system based on business characteristics. 11 | * Extract triples `(Subject, Predicate, Object)` from unstructured input text and store them in a specific way (usually a graph database). 12 | * For example: "Chow SingChi's movie *Kung Fu Hustle* was released in 2004" contains two triples: `(Chow SingChi, filmed, Kung Fu Hustle)` and `(Kung Fu Hustle, release time, 2004)`. 13 | 14 | * **Standard Question Answering (QA)** 15 | * Relational entity extraction 16 | * Extracting `(Subject, Predicate)` from query statements. 17 | * For example: "In which year was *Kung Fu Hustle* released" includes `(Kung Fu Hustle, release time)`. 18 | * Entity disambiguation 19 | * Resolving the ambiguity caused by entities sharing the same name or appearing under different names. 20 | * For example: Chow SingChi and Xing Ye should correspond to the same entity. 21 | * Relationship linking 22 | * Linking the extracted entities and relationships to ensure that the linked entity-relationship pair is valid in the knowledge system. 23 | * For example, in the *Douban Movie Review* scenario, for the question "What is the name of Chow SingChi's mother", the extracted `(Chow SingChi, mother)` is invalid because this relationship does not exist in the knowledge system. 24 | * Response generation 25 | * Retrieve valid relation-entity pairs from the knowledge system and generate the output. 26 | 27 | ## What's Involved 28 | 29 | This project mainly focuses on the **relational entity extraction** part of the **standard question answering (QA)** task. Typical KBQA queries fall into the following categories: 30 | * One-hop derivation 31 | * For example: "In which year was *Kung Fu Hustle* released" includes `(Kung Fu Hustle, release time)`. 32 | * Comparison of derivation results 33 | * For example: "Was *Kung Fu Hustle* released earlier than *All for the Winner*" includes `((Kung Fu Hustle, release time) ~ (All for the Winner, release time))`; all results need to be retrieved for comparison. 34 | * Nested derivation 35 | * For example: "What is the age of Chow SingChi's mother" includes `((Chow SingChi, mother), age)`, which requires a nested query. 36 | * Comparison of nested derivation results 37 | * For example: "Is Chow SingChi's mother older than Ng Mang Tat" includes `(((Chow SingChi, mother), age) ~ (Ng Mang Tat, age))`. 38 | 39 | The project only deals with the first case, which is the most commonly used. **Relationships are obtained through classification over global semantics, while entities are obtained through sequence labeling** (see the sketch after this section). 40 | 41 | For the latter three cases, **first obtain the entities through sequence labeling, and then use global semantics together with the local semantics of the entities (prior information) to obtain multiple relationships**. At present, the difficulties lie in **how to accurately link multiple relationships with multiple entities** and **how to handle the loss amplification of multi-task learning**; these are hard to handle with *neural network* models alone.
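To make the one-hop target concrete, the following minimal sketch (illustrative only, not part of the project code) shows what the two outputs of the joint model look like for a single query and how a `(Subject, Predicate)` pair is assembled from them. The label and tag conventions follow the Data Format section below; `decode_one_hop` is a hypothetical helper written just for this example.

```python
# Illustrative only: one-hop extraction target, assuming the label/tag
# conventions shown in the "Data Format" section of this README.
text = "安东尼娅家族啥时候上映的呀"
predicate_label = "电影上映时间是什么"  # classification head: relation from global semantics
bio_tags = ("B-title I-title I-title I-title I-title I-title "
            "O O O O O O O").split()   # sequence labeling head: one tag per character


def decode_one_hop(text, predicate_label, bio_tags):
    """Hypothetical helper: collect B-/I- spans and pair them with the predicate."""
    subjects, buffer = [], ""
    for char, tag in zip(text, bio_tags):
        if tag.startswith("B-"):
            if buffer:
                subjects.append(buffer)
            buffer = char
        elif tag.startswith("I-"):
            buffer += char
        else:
            if buffer:
                subjects.append(buffer)
                buffer = ""
    if buffer:
        subjects.append(buffer)
    return [(subject, predicate_label) for subject in subjects]


print(decode_one_hop(text, predicate_label, bio_tags))
# [('安东尼娅家族', '电影上映时间是什么')]
```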
42 | 43 | ## Method 44 | 45 | ### Model Structure 46 | 47 | Methods for one-hop derivation: 48 | * Pipeline method: The relationship classification and entity extraction tasks are split into two separate models, and their losses are calculated independently without affecting each other. 49 | * Easy to train, as both of them are regular NLP tasks: classification and sequence labeling. 50 | * Slow at inference, since the input needs to be fed into both models. 51 | * Invalid results are prone to occur due to prediction errors, so an additional relationship linking step is required. 52 | * For example: `(Chow SingChi, release time)` 53 | * Joint method: The relationship classification and entity extraction tasks interact with each other, share the same embedding layer for semantic encoding, and the loss is calculated in a multi-task way. 54 | * Hard to train, since the losses and gradient descent speeds of the two tasks are not on the same order of magnitude. 55 | * Fast at inference, since the input is fed into a single model. 56 | * Invalid results rarely occur, because both tasks share the same semantic encoding, which avoids the relationship linking step. 57 | 58 | **The joint method is used in this project**. 59 | 60 | ### Training Method 61 | 62 | * Train the harder sequence labeling task first and freeze the downstream weights of the classification task, until the validation accuracy reaches a preset threshold (`all_train_threshold`). 63 | * The multi-loss proposed in [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115.pdf) is then modified to combine the losses of the classification task and the sequence labeling task. The optimizer is inherited from the previous stage, and both tasks are trained at the same time. 64 | 65 | ## Project Framework 66 | 67 | ```bash 68 | keras_bert_kbqa 69 | ├── helper.py # helper for training parameters 70 | ├── __init__.py 71 | ├── predict.py # load the trained model 72 | ├── train.py # train and save the model 73 | └── utils 74 |     ├── bert.py # keras implementation of the bert model 75 |     ├── callbacks.py # EarlyStopping, ReduceLROnPlateau, TaskSwitch 76 |     ├── decoder.py # Viterbi decoder of the sequence labeling task 77 |     ├── graph_builder.py # neo4j graph database processing functions, not used 78 |     ├── __init__.py 79 |     ├── metrics.py # Crf_accuracy and Crf_loss of the sequence labeling task, with mask support 80 |     ├── models.py # textcnn and dense for the classification task, idcnn-crf and bilstm-crf for the sequence labeling task 81 |     ├── processor.py # standardizes the training/validation dataset 82 |     └── tokenizer.py # tokenizer of the bert model 83 | ``` 84 | 85 | ## Dependencies 86 | 87 | * flask == 1.1.1 88 | * keras == 2.3.1 89 | * numpy == 1.18.1 90 | * loguru == 0.4.1 91 | * requests == 2.22.0 92 | * termcolor == 1.1.0 93 | * tensorflow == 1.15.2 94 | * keras_contrib == 2.0.8 95 | 96 | ## Case: Douban Movie Review 97 | 98 | * This project mainly focuses on the **relational entity extraction** part of the **standard question answering (QA)** task. The rest is implemented in a relatively crude way (not using a graph database). 99 | * This case mimics a real online scenario in engineering practice, which suffers from **a low amount of query data** and **differences between the data generated by the `template + filling` method and the actual data**. 100 | * Test results show that, in the case of low data volume: 101 |    * **Generalization error is large**; neural network models alone do not work well.
They should be used in combination with regular expressions and rules. 102 |    * **Models are difficult to train**. 103 | 104 | ### Case Framework 105 | 106 | ```bash 107 | examples 108 | ├── data 109 | │   ├── build_upload.py # generate training/validation data from raw data 110 | │   ├── data 111 | │   │   ├── database.txt # database generated from raw data for query result retrieval (not using a graph database) 112 | │   │   ├── dev_data.txt # validation data 113 | │   │   ├── prior_check.txt # double check for correcting entity errors made by the nn model 114 | │   │   └── train_data.txt # training data 115 | │   ├── origin_data 116 | │   │   └── douban_movies.txt # raw data 117 | │   └── templates 118 | │       ├── neo4j_config.txt # configs of the graph database, not used 119 | │       ├── text_templates.txt # templates for generating training/validation data 120 | │       └── utter_search.txt # query result retrieval instructions (crude implementation, not using a graph database) 121 | ├── deploy # deploy a trained model for use 122 | │   ├── run_deploy.py 123 | │   └── run_deploy.sh 124 | ├── models # model save path 125 | │   ├── ALBERT-IDCNN-CRF.h5 126 | │   ├── id_to_label.pkl 127 | │   ├── id_to_tag.pkl 128 | │   ├── label_to_id.pkl 129 | │   ├── model_configs.json 130 | │   └── tag_to_id.pkl 131 | └── train # train a new model 132 |     ├── run_train.py 133 |     ├── run_train.sh 134 |     └── train_config.json # train configs 135 | 136 | ``` 137 | 138 | ### Data Format 139 | 140 | Each training/validation sample has the form `[text information, category information, sequence labeling information]`, as shown below: 141 | 142 | ```json 143 | [ 144 | [ 145 | "骗中骗的评分高吗", 146 | "豆瓣评分", 147 | "B-title I-title I-title O O O O O" 148 | ], 149 | [ 150 | "安东尼娅家族啥时候上映的呀", 151 | "电影上映时间是什么", 152 | "B-title I-title I-title I-title I-title I-title O O O O O O O" 153 | ], 154 | ... 155 | ] 156 | ``` 157 | 158 | ### Some Tricks for Setting Training Parameters 159 | 160 | These parameters are located in `examples/train/train_config.json` (an illustrative sketch is shown after this list): 161 | 162 | * The sentence length parameter `max_len` should be adapted to the length of the training/validation text. An excessively large value occupies a lot of GPU memory and strongly affects the convergence of the sequence labeling task. 163 | * The ALBERT model is easier to train than BERT in low-data scenarios, and its performance shows no significant difference from BERT. 164 | * `all_train_threshold` means that once the validation accuracy of the sequence labeling task reaches this value, the classification task and the sequence labeling task are trained together: 165 | * If it is too small, the sequence labeling task cannot converge and the classification task is prone to over-fitting. 166 | * If it is too large, the classification task is prone to under-fitting. 167 | * The recommended value is between 0.9 and 0.98. 168 | * `clf_type` can be `textcnn` or `dense`: 169 | * When it is `textcnn`, the remaining parameters are `dense_units`, `dropout_rate`, `filters` and `kernel_size`. 170 | * When it is `dense`, the remaining parameter is `dense_units`. 171 | * `ner_type` can be `idcnn` or `bilstm`: 172 | * When it is `idcnn`, the remaining parameters are `filters`, `kernel_size` and `blocks`. 173 | * When it is `bilstm`, the remaining parameters are `units`, `num_hidden_layers` and `dropout_rate`.
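As referenced above, here is an illustrative sketch (written as a Python dict for readability) of how these knobs might fit together for an ALBERT + TextCNN + IDCNN-CRF setup. The numeric values are placeholders, the nesting under `clf_configs`/`ner_configs` is an assumption based on how `keras_bert_kbqa/utils/models.py` reads its configuration, and the authoritative schema is the `examples/train/train_config.json` shipped with the case.

```python
# Illustrative only: placeholder values for the parameters discussed above.
illustrative_config = {
    "max_len": 32,                 # adapt to the actual query length
    "all_train_threshold": 0.95,   # recommended range: 0.9 - 0.98
    "clf_configs": {               # nesting assumed from utils/models.py
        "clf_type": "textcnn",
        "dense_units": 128,
        "dropout_rate": 0.3,
        "filters": 64,
        "kernels": [2, 3, 4],      # the `kernel_size` knob above; utils/models.py reads it as "kernels"
    },
    "ner_configs": {               # nesting assumed from utils/models.py
        "ner_type": "idcnn",
        "filters": 64,
        "kernel_size": 3,
        "blocks": 4,
    },
}
```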
174 | 175 | ### Implementation Process 176 | 177 | ```bash 178 | python examples/data/build_upload.py # generate all files in examples/data/data 179 | bash examples/train/run_train.sh # train a new model 180 | bash examples/deploy/run_deploy.sh # deploy a trained model for use 181 | ``` 182 | 183 | ### Usage 184 | 185 | Send a request to the API: 186 | 187 | ```python 188 | import requests 189 | 190 | r = requests.post( 191 | "http://your_ip:your_port/query", 192 | json={ 193 | "text": "大话西游之大圣娶亲是最近刚上的电影吗"}) 194 | 195 | print(r.text) 196 | ``` 197 | 198 | Returns: 199 | 200 | ```json 201 | { 202 | "text": "大话西游之大圣娶亲是最近刚上的", 203 | "predicate": "电影上映时间是什么", 204 | "subject": [ 205 | { 206 | "title": "大话西游之大圣娶亲" 207 | } 208 | ], 209 | "response": "2014" 210 | } 211 | ``` 212 | 213 | ## Future Work 214 | 215 | * Optimize the model structure to make it easier to train. 216 | * Try to handle more complex KBQA scenarios. 217 | * Improve some details. 218 | * Migrate to tensorflow 2.0. 219 | * Add other BERT models, like Distill_Bert and Tiny_Bert. 220 | 221 | ## Some Chinese Pretrained Language Models 222 | 223 | > **BERT** 224 | * [Google_bert](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) 225 | * [HIT_bert_wwm_ext](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_wwm_ext_L-12_H-768_A-12.zip) 226 | 227 | > **ALBERT** 228 | * [Google_albert_base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz) 229 | * [Google_albert_large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz) 230 | * [Google_albert_xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz) 231 | * [Google_albert_xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz) 232 | * [Xuliang_albert_xlarge](https://storage.googleapis.com/albert_zh/albert_xlarge_zh_177k.zip) 233 | * [Xuliang_albert_large](https://storage.googleapis.com/albert_zh/albert_large_zh.zip) 234 | * [Xuliang_albert_base](https://storage.googleapis.com/albert_zh/albert_base_zh.zip) 235 | * [Xuliang_albert_base_ext](https://storage.googleapis.com/albert_zh/albert_base_zh_additional_36k_steps.zip) 236 | * [Xuliang_albert_small](https://storage.googleapis.com/albert_zh/albert_small_zh_google.zip) 237 | * [Xuliang_albert_tiny](https://storage.googleapis.com/albert_zh/albert_tiny_zh_google.zip) 238 | 239 | > **Roberta** 240 | * [roberta](https://storage.googleapis.com/chineseglue/pretrain_models/roeberta_zh_L-24_H-1024_A-16.zip) 241 | * [roberta_wwm_ext](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_roberta_wwm_ext_L-12_H-768_A-12.zip) 242 | * [roberta_wwm_ext_large](https://storage.googleapis.com/chineseglue/pretrain_models/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.zip) 243 | 244 | ## Reference 245 | 246 | * [Keras-Bert-Ner](https://github.com/liushaoweihua/keras-bert-ner) 247 | * [bert4keras](https://github.com/bojone/bert4keras) 248 | * [albert_zh](https://github.com/brightmart/albert_zh) 249 | * [BERT](https://github.com/google-research/bert), [ALBERT](https://github.com/google-research/albert), [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta). 250 | 251 | Thanks to all of these wonderful works!
-------------------------------------------------------------------------------- /keras_bert_kbqa/utils/bert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Shaoweihua.Liu 5 | @Contact: liushaoweihua@126.com 6 | @Site: github.com/liushaoweihua 7 | @File: bert.py 8 | @Time: 2020/3/3 10:37 AM 9 | """ 10 | 11 | # Codes come from : 12 | # Author: Jianlin Su 13 | # Github: https://github.com/bojone/bert4keras 14 | # Site: kexue.fm 15 | # Version: 0.2.5 16 | 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import json 24 | import keras 25 | import numpy as np 26 | import tensorflow as tf 27 | import keras.backend as K 28 | from keras import initializers, activations 29 | from keras.layers import * 30 | from keras.models import Model 31 | from .tokenizer import is_string 32 | 33 | 34 | def sequence_masking(x, mask, mode=0, axis=None, heads=1): 35 | """为序列条件mask的函数 36 | mask: 形如(batch_size, sequence)的0-1矩阵; 37 | mode: 如果是0,则直接乘以mask; 38 | 如果是1,则在padding部分减去一个大正数。 39 | axis: 序列所在轴,默认为1; 40 | heads: 相当于batch这一维要被重复的次数。 41 | """ 42 | if mask is None or mode not in [0, 1]: 43 | return x 44 | else: 45 | if heads is not 1: 46 | mask = K.expand_dims(mask, 1) 47 | mask = K.tile(mask, (1, heads, 1)) 48 | mask = K.reshape(mask, (-1, K.shape(mask)[2])) 49 | if axis is None: 50 | axis = 1 51 | if axis == -1: 52 | axis = K.ndim(x) - 1 53 | assert axis > 0, "axis must be greater than 0" 54 | for _ in range(axis - 1): 55 | mask = K.expand_dims(mask, 1) 56 | for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1): 57 | mask = K.expand_dims(mask, K.ndim(mask)) 58 | if mode == 0: 59 | return x * mask 60 | else: 61 | return x - (1 - mask) * 1e12 62 | 63 | 64 | class MultiHeadAttention(Layer): 65 | """多头注意力机制 66 | """ 67 | 68 | def __init__(self, 69 | heads, 70 | head_size, 71 | key_size=None, 72 | kernel_initializer="glorot_uniform", 73 | **kwargs): 74 | super(MultiHeadAttention, self).__init__(**kwargs) 75 | self.heads = heads 76 | self.head_size = head_size 77 | self.out_dim = heads * head_size 78 | self.key_size = key_size if key_size else head_size 79 | self.kernel_initializer = initializers.get(kernel_initializer) 80 | 81 | def build(self, input_shape): 82 | super(MultiHeadAttention, self).build(input_shape) 83 | self.q_dense = Dense(units=self.key_size * self.heads, 84 | kernel_initializer=self.kernel_initializer) 85 | self.k_dense = Dense(units=self.key_size * self.heads, 86 | kernel_initializer=self.kernel_initializer) 87 | self.v_dense = Dense(units=self.out_dim, 88 | kernel_initializer=self.kernel_initializer) 89 | self.o_dense = Dense(units=self.out_dim, 90 | kernel_initializer=self.kernel_initializer) 91 | 92 | def call(self, inputs, q_mask=False, v_mask=False, a_mask=False): 93 | """实现多头注意力 94 | q_mask: 对输入的query序列的mask。 95 | 主要是将输出结果的padding部分置0。 96 | v_mask: 对输入的value序列的mask。 97 | 主要是防止attention读取到padding信息。 98 | a_mask: 对attention矩阵的mask。 99 | 不同的attention mask对应不同的应用。 100 | """ 101 | q, k, v = inputs[:3] 102 | # 处理mask 103 | idx = 3 104 | if q_mask: 105 | q_mask = inputs[idx] 106 | idx += 1 107 | else: 108 | q_mask = None 109 | if v_mask: 110 | v_mask = inputs[idx] 111 | idx += 1 112 | else: 113 | v_mask = None 114 | if a_mask: 115 | if len(inputs) > idx: 116 | a_mask = inputs[idx] 117 | else: 118 | a_mask = "history_only" 119 | else: 120 | a_mask = None 121 | # 线性变换 122 | qw = self.q_dense(q) 123 | kw = self.k_dense(k) 
124 | vw = self.v_dense(v) 125 | # 形状变换 126 | qw = K.reshape(qw, (-1, K.shape(q)[1], self.heads, self.key_size)) 127 | kw = K.reshape(kw, (-1, K.shape(k)[1], self.heads, self.key_size)) 128 | vw = K.reshape(vw, (-1, K.shape(v)[1], self.heads, self.head_size)) 129 | # 维度置换 130 | qw = K.permute_dimensions(qw, (0, 2, 1, 3)) 131 | kw = K.permute_dimensions(kw, (0, 2, 1, 3)) 132 | vw = K.permute_dimensions(vw, (0, 2, 1, 3)) 133 | # 转为三阶张量 134 | qw = K.reshape(qw, (-1, K.shape(q)[1], self.key_size)) 135 | kw = K.reshape(kw, (-1, K.shape(k)[1], self.key_size)) 136 | vw = K.reshape(vw, (-1, K.shape(v)[1], self.head_size)) 137 | # Attention 138 | a = K.batch_dot(qw, kw, [2, 2]) / self.key_size ** 0.5 139 | a = sequence_masking(a, v_mask, 1, -1, self.heads) 140 | if a_mask is not None: 141 | if is_string(a_mask) and a_mask == "history_only": 142 | ones = K.ones_like(a[:1]) 143 | a_mask = (ones - tf.linalg.band_part(ones, -1, 0)) * 1e12 144 | a = a - a_mask 145 | else: 146 | a = a - (1 - a_mask) * 1e12 147 | a = K.softmax(a) 148 | # 完成输出 149 | o = K.batch_dot(a, vw, [2, 1]) 150 | o = K.reshape(o, (-1, self.heads, K.shape(q)[1], self.head_size)) 151 | o = K.permute_dimensions(o, (0, 2, 1, 3)) 152 | o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim)) 153 | o = self.o_dense(o) 154 | o = sequence_masking(o, q_mask, 0) 155 | return o 156 | 157 | def compute_output_shape(self, input_shape): 158 | return (input_shape[0][0], input_shape[0][1], self.out_dim) 159 | 160 | def get_config(self): 161 | config = { 162 | "heads": self.heads, 163 | "head_size": self.head_size, 164 | "key_size": self.key_size, 165 | "kernel_initializer": initializers.serialize(self.kernel_initializer), 166 | } 167 | base_config = super(MultiHeadAttention, self).get_config() 168 | return dict(list(base_config.items()) + list(config.items())) 169 | 170 | 171 | class LayerNormalization(Layer): 172 | """实现基本的Layer Norm,只保留核心运算部分 173 | """ 174 | 175 | def __init__(self, **kwargs): 176 | super(LayerNormalization, self).__init__(**kwargs) 177 | self.epsilon = K.epsilon() * K.epsilon() 178 | 179 | def build(self, input_shape): 180 | super(LayerNormalization, self).build(input_shape) 181 | shape = (input_shape[-1],) 182 | self.gamma = self.add_weight(shape=shape, 183 | initializer="ones", 184 | name="gamma") 185 | self.beta = self.add_weight(shape=shape, 186 | initializer="zeros", 187 | name="beta") 188 | 189 | def call(self, inputs): 190 | mean = K.mean(inputs, axis=-1, keepdims=True) 191 | variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True) 192 | std = K.sqrt(variance + self.epsilon) 193 | outputs = (inputs - mean) / std 194 | outputs *= self.gamma 195 | outputs += self.beta 196 | return outputs 197 | 198 | 199 | class PositionEmbedding(Layer): 200 | """定义位置Embedding,这里的Embedding是可训练的。 201 | """ 202 | 203 | def __init__(self, 204 | input_dim, 205 | output_dim, 206 | merge_mode="add", 207 | embeddings_initializer="zeros", 208 | **kwargs): 209 | super(PositionEmbedding, self).__init__(**kwargs) 210 | self.input_dim = input_dim 211 | self.output_dim = output_dim 212 | self.merge_mode = merge_mode 213 | self.embeddings_initializer = initializers.get(embeddings_initializer) 214 | 215 | def build(self, input_shape): 216 | super(PositionEmbedding, self).build(input_shape) 217 | self.embeddings = self.add_weight( 218 | name="embeddings", 219 | shape=(self.input_dim, self.output_dim), 220 | initializer=self.embeddings_initializer, 221 | ) 222 | 223 | def call(self, inputs): 224 | input_shape = K.shape(inputs) 225 | batch_size, seq_len 
= input_shape[0], input_shape[1] 226 | pos_embeddings = self.embeddings[:seq_len] 227 | pos_embeddings = K.expand_dims(pos_embeddings, 0) 228 | pos_embeddings = K.tile(pos_embeddings, [batch_size, 1, 1]) 229 | if self.merge_mode == "add": 230 | return inputs + pos_embeddings 231 | else: 232 | return K.concatenate([inputs, pos_embeddings]) 233 | 234 | def compute_output_shape(self, input_shape): 235 | if self.merge_mode == "add": 236 | return input_shape 237 | else: 238 | return input_shape[:2] + (input_shape[2] + self.v_dim,) 239 | 240 | def get_config(self): 241 | config = { 242 | "input_dim": self.input_dim, 243 | "output_dim": self.output_dim, 244 | "merge_mode": self.merge_mode, 245 | "embeddings_initializer": initializers.serialize(self.embeddings_initializer), 246 | } 247 | base_config = super(PositionEmbedding, self).get_config() 248 | return dict(list(base_config.items()) + list(config.items())) 249 | 250 | 251 | class FeedForward(Layer): 252 | """FeedForward层,其实就是两个Dense层的叠加 253 | """ 254 | 255 | def __init__(self, 256 | units, 257 | activation="relu", 258 | kernel_initializer="glorot_uniform", 259 | **kwargs): 260 | super(FeedForward, self).__init__(**kwargs) 261 | self.units = units 262 | self.activation = activations.get(activation) 263 | self.kernel_initializer = initializers.get(kernel_initializer) 264 | 265 | def build(self, input_shape): 266 | super(FeedForward, self).build(input_shape) 267 | output_dim = input_shape[-1] 268 | self.dense_1 = Dense(units=self.units, 269 | activation=self.activation, 270 | kernel_initializer=self.kernel_initializer) 271 | self.dense_2 = Dense(units=output_dim, 272 | kernel_initializer=self.kernel_initializer) 273 | 274 | def call(self, inputs): 275 | x = self.dense_1(inputs) 276 | x = self.dense_2(x) 277 | return x 278 | 279 | def get_config(self): 280 | config = { 281 | "units": self.units, 282 | "activation": activations.serialize(self.activation), 283 | "kernel_initializer": initializers.serialize(self.kernel_initializer), 284 | } 285 | base_config = super(FeedForward, self).get_config() 286 | return dict(list(base_config.items()) + list(config.items())) 287 | 288 | 289 | class EmbeddingDense(Layer): 290 | """运算跟Dense一致,但kernel用Embedding层的embeddings矩阵。 291 | 根据Embedding层的名字来搜索定位Embedding层。 292 | """ 293 | 294 | def __init__(self, embedding_name, activation="softmax", **kwargs): 295 | super(EmbeddingDense, self).__init__(**kwargs) 296 | self.embedding_name = embedding_name 297 | self.activation = activations.get(activation) 298 | 299 | def call(self, inputs): 300 | if not hasattr(self, "kernel"): 301 | embedding_layer = inputs._keras_history[0] 302 | 303 | if embedding_layer.name != self.embedding_name: 304 | 305 | def recursive_search(layer): 306 | """递归向上搜索,根据名字找Embedding层 307 | """ 308 | last_layer = layer._inbound_nodes[0].inbound_layers 309 | if isinstance(last_layer, list): 310 | if len(last_layer) == 0: 311 | return None 312 | else: 313 | last_layer = last_layer[0] 314 | if last_layer.name == self.embedding_name: 315 | return last_layer 316 | else: 317 | return recursive_search(last_layer) 318 | 319 | embedding_layer = recursive_search(embedding_layer) 320 | if embedding_layer is None: 321 | raise Exception("Embedding layer not found") 322 | 323 | self.kernel = K.transpose(embedding_layer.embeddings) 324 | self.units = K.int_shape(self.kernel)[1] 325 | self.bias = self.add_weight(name="bias", 326 | shape=(self.units,), 327 | initializer="zeros") 328 | 329 | outputs = K.dot(inputs, self.kernel) 330 | outputs = K.bias_add(outputs, 
self.bias) 331 | outputs = self.activation(outputs) 332 | return outputs 333 | 334 | def compute_output_shape(self, input_shape): 335 | return input_shape[:-1] + (self.units,) 336 | 337 | def get_config(self): 338 | config = { 339 | "embedding_name": self.embedding_name, 340 | "activation": activations.serialize(self.activation), 341 | } 342 | base_config = super(EmbeddingDense, self).get_config() 343 | return dict(list(base_config.items()) + list(config.items())) 344 | 345 | 346 | class BertModel: 347 | """构建跟Bert一样结构的Transformer-based模型 348 | 这是一个比较多接口的基础类,然后通过这个基础类衍生出更复杂的模型 349 | """ 350 | 351 | def __init__(self, 352 | vocab_size, # 词表大小 353 | max_position_embeddings, # 序列最大长度 354 | hidden_size, # 编码维度 355 | num_hidden_layers, # Transformer总层数 356 | num_attention_heads, # Attention的头数 357 | intermediate_size, # FeedForward的隐层维度 358 | hidden_act, # FeedForward隐层的激活函数 359 | dropout_rate, # Dropout比例 360 | initializer_range=None, # 权重初始化方差 361 | embedding_size=None, # 是否指定embedding_size 362 | with_pool=False, # 是否包含Pool部分 363 | with_nsp=False, # 是否包含NSP部分 364 | with_mlm=False, # 是否包含MLM部分 365 | keep_words=None, # 要保留的词ID列表 366 | block_sharing=False, # 是否共享同一个transformer block 367 | ): 368 | if keep_words is None: 369 | self.vocab_size = vocab_size 370 | else: 371 | self.vocab_size = len(keep_words) 372 | self.max_position_embeddings = max_position_embeddings 373 | self.hidden_size = hidden_size 374 | self.num_hidden_layers = num_hidden_layers 375 | self.num_attention_heads = num_attention_heads 376 | self.attention_head_size = hidden_size // num_attention_heads 377 | self.intermediate_size = intermediate_size 378 | self.dropout_rate = dropout_rate 379 | if initializer_range: 380 | self.initializer_range = initializer_range 381 | else: 382 | self.initializer_range = 0.02 383 | if embedding_size: 384 | self.embedding_size = embedding_size 385 | else: 386 | self.embedding_size = hidden_size 387 | self.with_pool = with_pool 388 | self.with_nsp = with_nsp 389 | self.with_mlm = with_mlm 390 | self.hidden_act = hidden_act 391 | self.keep_words = keep_words 392 | self.block_sharing = block_sharing 393 | self.additional_outputs = [] 394 | 395 | def build(self): 396 | """Bert模型构建函数 397 | """ 398 | x_in = Input(shape=(None,), name="Input-Token") 399 | s_in = Input(shape=(None,), name="Input-Segment") 400 | x, s = x_in, s_in 401 | 402 | # 自行构建Mask 403 | sequence_mask = Lambda(lambda x: K.cast(K.greater(x, 0), "float32"), 404 | name="Sequence-Mask")(x) 405 | 406 | # Embedding部分 407 | x = Embedding(input_dim=self.vocab_size, 408 | output_dim=self.embedding_size, 409 | embeddings_initializer=self.initializer, 410 | name="Embedding-Token")(x) 411 | s = Embedding(input_dim=2, 412 | output_dim=self.embedding_size, 413 | embeddings_initializer=self.initializer, 414 | name="Embedding-Segment")(s) 415 | x = Add(name="Embedding-Token-Segment")([x, s]) 416 | x = PositionEmbedding(input_dim=self.max_position_embeddings, 417 | output_dim=self.embedding_size, 418 | merge_mode="add", 419 | embeddings_initializer=self.initializer, 420 | name="Embedding-Position")(x) 421 | x = LayerNormalization(name="Embedding-Norm")(x) 422 | if self.dropout_rate > 0: 423 | x = Dropout(rate=self.dropout_rate, name="Embedding-Dropout")(x) 424 | if self.embedding_size != self.hidden_size: 425 | x = Dense(units=self.hidden_size, 426 | kernel_initializer=self.initializer, 427 | name="Embedding-Mapping")(x) 428 | 429 | # 主要Transformer部分 430 | layers = None 431 | for i in range(self.num_hidden_layers): 432 | attention_name = 
"Encoder-%d-MultiHeadSelfAttention" % (i + 1) 433 | feed_forward_name = "Encoder-%d-FeedForward" % (i + 1) 434 | x, layers = self.transformer_block( 435 | inputs=x, 436 | sequence_mask=sequence_mask, 437 | attention_mask=self.compute_attention_mask(i, s_in), 438 | attention_name=attention_name, 439 | feed_forward_name=feed_forward_name, 440 | input_layers=layers) 441 | x = self.post_processing(i, x) 442 | if not self.block_sharing: 443 | layers = None 444 | 445 | outputs = [x] 446 | 447 | if self.with_pool: 448 | # Pooler部分(提取CLS向量) 449 | x = outputs[0] 450 | x = Lambda(lambda x: x[:, 0], name="Pooler")(x) 451 | x = Dense(units=self.hidden_size, 452 | activation="tanh", 453 | kernel_initializer=self.initializer, 454 | name="Pooler-Dense")(x) 455 | if self.with_nsp: 456 | # Next Sentence Prediction 部分 457 | x = Dense(units=2, 458 | activation="softmax", 459 | kernel_initializer=self.initializer, 460 | name="NSP-Proba")(x) 461 | outputs.append(x) 462 | 463 | if self.with_mlm: 464 | # Masked Language Model 部分 465 | x = outputs[0] 466 | x = Dense(units=self.embedding_size, 467 | activation=self.hidden_act, 468 | kernel_initializer=self.initializer, 469 | name="MLM-Dense")(x) 470 | x = LayerNormalization(name="MLM-Norm")(x) 471 | x = EmbeddingDense(embedding_name="Embedding-Token", 472 | name="MLM-Proba")(x) 473 | outputs.append(x) 474 | 475 | outputs += self.additional_outputs 476 | if len(outputs) == 1: 477 | outputs = outputs[0] 478 | elif len(outputs) == 2: 479 | outputs = outputs[1] 480 | else: 481 | outputs = outputs[1:] 482 | 483 | self.model = Model([x_in, s_in], outputs) 484 | 485 | def transformer_block(self, 486 | inputs, 487 | sequence_mask, 488 | attention_mask=None, 489 | attention_name="attention", 490 | feed_forward_name="feed-forword", 491 | input_layers=None): 492 | """构建单个Transformer Block 493 | 如果没有传入input_layers则新建层;如果传入则重用旧层""" 494 | x = inputs 495 | if input_layers is None: 496 | layers = [ 497 | MultiHeadAttention(heads=self.num_attention_heads, 498 | head_size=self.attention_head_size, 499 | kernel_initializer=self.initializer, 500 | name=attention_name), 501 | Dropout(rate=self.dropout_rate, 502 | name="%s-Dropout" % attention_name), 503 | Add(name="%s-Add" % attention_name), 504 | LayerNormalization(name="%s-Norm" % attention_name), 505 | FeedForward(units=self.intermediate_size, 506 | activation=self.hidden_act, 507 | kernel_initializer=self.initializer, 508 | name=feed_forward_name), 509 | Dropout(rate=self.dropout_rate, 510 | name="%s-Dropout" % feed_forward_name), 511 | Add(name="%s-Add" % feed_forward_name), 512 | LayerNormalization(name="%s-Norm" % feed_forward_name) 513 | ] 514 | else: 515 | layers = input_layers 516 | # Self Attention 517 | xi = x 518 | if attention_mask is None: 519 | x = layers[0]([x, x, x, sequence_mask], v_mask=True) 520 | elif attention_mask is "history_only": 521 | x = layers[0]([x, x, x, sequence_mask], v_mask=True) 522 | else: 523 | x = layers[0]([x, x, x, sequence_mask, attention_mask], 524 | v_mask=True, 525 | a_mask=True) 526 | if self.dropout_rate > 0: 527 | x = layers[1](x) 528 | x = layers[2]([xi, x]) 529 | x = layers[3](x) 530 | # Feed Forward 531 | xi = xx = layers[4](x) 532 | if self.dropout_rate > 0: 533 | x = layers[5](x) 534 | x = layers[6]([xi, x]) 535 | x = layers[7](x) 536 | return x, layers 537 | 538 | def compute_attention_mask(self, layer_id, segment_ids): 539 | """定义每一层的Attention Mask,来实现不同的功能 540 | """ 541 | return None 542 | 543 | def post_processing(self, layer_id, inputs): 544 | """自定义每一个block的后处理操作 545 | """ 546 
| return inputs 547 | 548 | @property 549 | def initializer(self): 550 | """默认使用截断正态分布初始化 551 | """ 552 | return keras.initializers.TruncatedNormal( 553 | stddev=self.initializer_range) 554 | 555 | def load_weights_from_checkpoint(self, checkpoint_file): 556 | """从预训练好的Bert的checkpoint中加载权重 557 | 为了简化写法,对变量名的匹配引入了一定的模糊匹配能力。 558 | """ 559 | model = self.model 560 | load_variable = lambda name: tf.train.load_variable(checkpoint_file, name) 561 | variable_names = [n[0] for n in tf.train.list_variables(checkpoint_file)] 562 | variable_names = [n for n in variable_names if "adam" not in n] 563 | 564 | def similarity(a, b, n=4): 565 | # 基于n-grams的jaccard相似度 566 | a = set([a[i: i + n] for i in range(len(a) - n)]) 567 | b = set([b[i: i + n] for i in range(len(b) - n)]) 568 | a_and_b = a & b 569 | if not a_and_b: 570 | return 0. 571 | a_or_b = a | b 572 | return 1. * len(a_and_b) / len(a_or_b) 573 | 574 | def loader(name): 575 | sims = [similarity(name, n) for n in variable_names] 576 | found_name = variable_names.pop(np.argmax(sims)) 577 | print("==> searching: %s, found name: %s" % (name, found_name)) 578 | return load_variable(found_name) 579 | 580 | if self.keep_words is None: 581 | keep_words = slice(0, None) 582 | else: 583 | keep_words = self.keep_words 584 | 585 | model.get_layer(name="Embedding-Token").set_weights( 586 | [loader("bert/embeddings/word_embeddings")[keep_words]]) 587 | model.get_layer(name="Embedding-Position").set_weights( 588 | [loader("bert/embeddings/position_embeddings")]) 589 | model.get_layer(name="Embedding-Segment").set_weights( 590 | [loader("bert/embeddings/token_type_embeddings")]) 591 | model.get_layer(name="Embedding-Norm").set_weights( 592 | [loader("bert/embeddings/LayerNorm/gamma"), 593 | loader("bert/embeddings/LayerNorm/beta")]) 594 | if self.embedding_size != self.hidden_size: 595 | model.get_layer(name="Embedding-Mapping").set_weights( 596 | [loader("bert/encoder/embedding_hidden_mapping_in/kernel"), 597 | loader("bert/encoder/embedding_hidden_mapping_in/bias")]) 598 | 599 | for i in range(self.num_hidden_layers): 600 | try: 601 | model.get_layer(name="Encoder-%d-MultiHeadSelfAttention" % (i + 1)) 602 | except ValueError: 603 | continue 604 | if ("bert/encoder/layer_%d/attention/self/query/kernel" % i) in variable_names: 605 | layer_name = "layer_%d" % i 606 | else: 607 | layer_name = "transformer/group_0/inner_group_0" 608 | model.get_layer(name="Encoder-%d-MultiHeadSelfAttention" % (i + 1)).set_weights( 609 | [loader("bert/encoder/%s/attention/self/query/kernel" % layer_name), 610 | loader("bert/encoder/%s/attention/self/query/bias" % layer_name), 611 | loader("bert/encoder/%s/attention/self/key/kernel" % layer_name), 612 | loader("bert/encoder/%s/attention/self/key/bias" % layer_name), 613 | loader("bert/encoder/%s/attention/self/value/kernel" % layer_name), 614 | loader("bert/encoder/%s/attention/self/value/bias" % layer_name), 615 | loader("bert/encoder/%s/attention/output/dense/kernel" % layer_name), 616 | loader("bert/encoder/%s/attention/output/dense/bias" % layer_name)]) 617 | model.get_layer(name="Encoder-%d-MultiHeadSelfAttention-Norm" % (i + 1)).set_weights( 618 | [loader("bert/encoder/%s/attention/output/LayerNorm/gamma" % layer_name), 619 | loader("bert/encoder/%s/attention/output/LayerNorm/beta" % layer_name)]) 620 | model.get_layer(name="Encoder-%d-FeedForward" % (i + 1)).set_weights( 621 | [loader("bert/encoder/%s/intermediate/dense/kernel" % layer_name), 622 | loader("bert/encoder/%s/intermediate/dense/bias" % layer_name), 623 | 
loader("bert/encoder/%s/output/dense/kernel" % layer_name), 624 | loader("bert/encoder/%s/output/dense/bias" % layer_name)]) 625 | model.get_layer(name="Encoder-%d-FeedForward-Norm" % (i + 1)).set_weights( 626 | [loader("bert/encoder/%s/output/LayerNorm/gamma" % layer_name), 627 | loader("bert/encoder/%s/output/LayerNorm/beta" % layer_name)]) 628 | 629 | if self.with_pool: 630 | model.get_layer(name="Pooler-Dense").set_weights( 631 | [loader("bert/pooler/dense/kernel"), 632 | loader("bert/pooler/dense/bias")]) 633 | if self.with_nsp: 634 | model.get_layer(name="NSP-Proba").set_weights( 635 | [loader("cls/seq_relationship/output_weights").T, 636 | loader("cls/seq_relationship/output_bias")]) 637 | 638 | if self.with_mlm: 639 | model.get_layer(name="MLM-Dense").set_weights( 640 | [loader("cls/predictions/transform/dense/kernel"), 641 | loader("cls/predictions/transform/dense/bias")]) 642 | model.get_layer(name="MLM-Norm").set_weights( 643 | [loader("cls/predictions/transform/LayerNorm/gamma"), 644 | loader("cls/predictions/transform/LayerNorm/beta")]) 645 | model.get_layer(name="MLM-Proba").set_weights( 646 | [loader("cls/predictions/output_bias")[keep_words]]) 647 | 648 | 649 | class Bert4Seq2seq(BertModel): 650 | """用来做seq2seq任务的Bert 651 | """ 652 | 653 | def __init__(self, *args, **kwargs): 654 | super(Bert4Seq2seq, self).__init__(*args, **kwargs) 655 | self.with_pool = False 656 | self.with_nsp = False 657 | self.with_mlm = True 658 | self.attention_mask = None 659 | 660 | def compute_attention_mask(self, layer_id, segment_ids): 661 | """为seq2seq采用特定的attention mask 662 | """ 663 | if self.attention_mask is None: 664 | def seq2seq_attention_mask(s, repeats=1): 665 | seq_len = K.shape(s)[1] 666 | ones = K.ones((1, repeats, seq_len, seq_len)) 667 | a_mask = tf.linalg.band_part(ones, -1, 0) 668 | s_ex12 = K.expand_dims(K.expand_dims(s, 1), 2) 669 | s_ex13 = K.expand_dims(K.expand_dims(s, 1), 3) 670 | a_mask = (1 - s_ex13) * (1 - s_ex12) + s_ex13 * a_mask 671 | a_mask = K.reshape(a_mask, (-1, seq_len, seq_len)) 672 | return a_mask 673 | 674 | self.attention_mask = Lambda( 675 | seq2seq_attention_mask, 676 | arguments={"repeats": self.num_attention_heads}, 677 | name="Attention-Mask")(segment_ids) 678 | 679 | return self.attention_mask 680 | 681 | 682 | class Bert4LM(BertModel): 683 | """用来做语言模型任务的Bert 684 | """ 685 | 686 | def __init__(self, *args, **kwargs): 687 | super(Bert4LM, self).__init__(*args, **kwargs) 688 | self.with_pool = False 689 | self.with_nsp = False 690 | self.with_mlm = True 691 | self.attention_mask = "history_only" 692 | 693 | def compute_attention_mask(self, layer_id, segment_ids): 694 | return self.attention_mask 695 | 696 | 697 | def build_bert_model(config_path, 698 | checkpoint_path=None, 699 | with_pool=False, 700 | with_nsp=False, 701 | with_mlm=False, 702 | application="encoder", 703 | keep_words=None, 704 | albert=False, 705 | return_keras_model=True): 706 | """根据配置文件构建bert模型,可选加载checkpoint权重 707 | """ 708 | config = json.load(open(config_path)) 709 | mapping = { 710 | "encoder": BertModel, 711 | "seq2seq": Bert4Seq2seq, 712 | "lm": Bert4LM 713 | } 714 | 715 | assert application in mapping, "application must be one of %s" % list(mapping.keys()) 716 | Bert = mapping[application] 717 | 718 | bert = Bert(vocab_size=config["vocab_size"], 719 | max_position_embeddings=config["max_position_embeddings"], 720 | hidden_size=config["hidden_size"], 721 | num_hidden_layers=config["num_hidden_layers"], 722 | num_attention_heads=config["num_attention_heads"], 723 | 
intermediate_size=config["intermediate_size"], 724 | hidden_act=config["hidden_act"], 725 | dropout_rate=config["hidden_dropout_prob"], 726 | initializer_range=config.get("initializer_range"), 727 | embedding_size=config.get("embedding_size"), 728 | with_pool=with_pool, 729 | with_nsp=with_nsp, 730 | with_mlm=with_mlm, 731 | keep_words=keep_words, 732 | block_sharing=albert) 733 | 734 | bert.build() 735 | 736 | if checkpoint_path is not None: 737 | bert.load_weights_from_checkpoint(checkpoint_path) 738 | 739 | if return_keras_model: 740 | return bert.model 741 | else: 742 | return bert --------------------------------------------------------------------------------
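Closing note (illustrative, not one of the repository files above): a minimal sketch of how the `build_bert_model` and `Tokenizer` utilities defined in this repository can be combined to encode a query, assuming a standard Google-format Chinese BERT checkpoint has been downloaded from one of the links in the README; all paths are placeholders.

```python
# Illustrative only: encode a query with the Bert encoder from this repository.
import numpy as np

from keras_bert_kbqa.utils.bert import build_bert_model
from keras_bert_kbqa.utils.tokenizer import Tokenizer

# Placeholder paths for a downloaded chinese_L-12_H-768_A-12 checkpoint.
CONFIG_PATH = "chinese_L-12_H-768_A-12/bert_config.json"
CHECKPOINT_PATH = "chinese_L-12_H-768_A-12/bert_model.ckpt"
VOCAB_PATH = "chinese_L-12_H-768_A-12/vocab.txt"

# build_bert_model returns a Keras model; set albert=True for ALBERT checkpoints.
model = build_bert_model(CONFIG_PATH, CHECKPOINT_PATH, albert=False)

tokenizer = Tokenizer(VOCAB_PATH)
token_ids, segment_ids = tokenizer.encode("功夫是哪一年上映的", max_length=32)

# The encoder outputs one hidden vector per token: (1, sequence_length, hidden_size).
features = model.predict([np.array([token_ids]), np.array([segment_ids])])
print(features.shape)
```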