├── LICENSE
├── README.md
├── basic_extract_features
│   ├── __init__.py
│   └── extract_features.py
├── basis_framework
│   ├── __init__.py
│   └── basis_graph.py
├── classifier
│   ├── __init__.py
│   ├── bert_classifier.py
│   └── bert_cnn_classifier.py
├── configs
│   ├── __init__.py
│   └── path_config.py
├── entity_relationship_extraction
│   ├── __init__.py
│   └── triplet_extraction.py
├── named_entity
│   ├── __init__.py
│   ├── entity_by_rules.py
│   └── ner_handler.py
├── text_generation
│   ├── __init__.py
│   └── gpt2_ml.py
└── utils
    ├── __init__.py
    ├── classifier_data_process.py
    ├── common_tools.py
    ├── dynamic_data_cache
    │   ├── __init__.py
    │   ├── entity.csv
    │   ├── keyword_dao.py
    │   ├── keyword_update.py
    │   └── trie_tree.py
    ├── logger.py
    ├── ner_data_process.py
    └── triplet_data_process.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# bert4keras4nlp

NLP components built on top of bert4keras; for now it covers text classification, named entity recognition, and entity relationship (triplet) extraction.
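## Quick start

A minimal usage sketch for the text classifier (the corpus paths and hyper-parameters below are placeholders — point them at your own data; the pretrained BERT weights are expected under `models/chinese_L-12_H-768_A-12`, see `configs/path_config.py`):

```python
from classifier.bert_classifier import BertGraph

# Training: the params dict mirrors the __main__ example in classifier/bert_classifier.py
params = {
    'model_code': 'thuc_news_bert',              # sub-directory of models/ for weights and params.json
    'train_data_path': 'corpus/thuc_news/train.txt',
    'valid_data_path': 'corpus/thuc_news/dev.txt',
    'batch_size': 64,
    'max_len': 30,
    'epoch': 10,
    'learning_rate': 1e-5,
}
model = BertGraph(params, Train=True)
model.train()

# Inference: reuse the same model_code so the saved params and weights are found
model = BertGraph({'model_code': 'thuc_news_bert'})
print(model.predict('这是一条待分类的新闻标题'))
```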
--------------------------------------------------------------------------------
/basic_extract_features/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: liwfeng
# datetime: 2021/1/4 11:33
# ide: PyCharm
--------------------------------------------------------------------------------
/basic_extract_features/extract_features.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: liwfeng
# datetime: 2021/1/4 11:33
# ide: PyCharm
import os

from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.snippets import to_array

from basis_framework.basis_graph import BasisGraph
from utils.common_tools import load_json, save_json


class ExtractFeature(BasisGraph):
    def __init__(self, params={}, Train=False):
        super().__init__(params, Train)

    def save_params(self):
        self.params['max_len'] = self.max_len
        save_json(jsons=self.params, json_path=self.params_path)

    def load_params(self):
        # params.json may not exist for plain feature extraction; only read it if present
        if os.path.exists(self.params_path):
            load_params = load_json(self.params_path)
            self.max_len = load_params.get('max_len')

    def _set_gpu_id(self):
        """Set the GPU id used by this process."""
        if self.gpu_id:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

    def data_process(self):
        """Data preprocessing (not needed for plain feature extraction)."""
        raise NotImplementedError

    def build_model(self):
        """Build the transformer encoder used to extract features."""
        self.model = build_transformer_model(self.bert_config_path, self.bert_checkpoint_path)

    def compile_model(self):
        """Model compilation (not needed for plain feature extraction)."""
        raise NotImplementedError

    def extract_features(self, text: str):
        """Encode a piece of text and print the extracted token features."""
        token_ids, segment_ids = self.tokenizer.encode(u'{}'.format(text))
        token_ids, segment_ids = to_array([token_ids], [segment_ids])
        print("\n === features === \n")
        print(self.model.predict([token_ids, segment_ids]))

    def save_model(self, model_path='test.model'):
        self.model.save(model_path)
        del self.model  # free memory

    def load_model(self, model_path='test.model'):
        # self.model = keras.models.load_model(model_path)
        # self.extract_features('语言模型')
        pass
--------------------------------------------------------------------------------
/basis_framework/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: liwfeng
# datetime: 2020/12/1 11:05
# ide: PyCharm
--------------------------------------------------------------------------------
/basis_framework/basis_graph.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: liwfeng
# datetime: 2020/12/11 17:17
# ide: PyCharm
import os

from bert4keras.tokenizers import Tokenizer
from configs.path_config import BERT_MODEL_PATH, MODEL_ROOT_PATH
from utils.common_tools import load_json, save_json


class BasisGraph():
    def __init__(self, params={}, Train=False):
        self.bert_config_path = os.path.join(BERT_MODEL_PATH, "bert_config.json")
        self.bert_checkpoint_path = os.path.join(BERT_MODEL_PATH, "bert_model.ckpt")
        self.bert_vocab_path = os.path.join(BERT_MODEL_PATH, "vocab.txt")
        self.tokenizer = Tokenizer(self.bert_vocab_path)
        self.model_code = params.get('model_code')
        if not self.model_code:
            raise Exception("No model code! params must contain a 'model_code'")
        self.MODEL_ROOT_PATH = os.path.join(MODEL_ROOT_PATH, self.model_code)
        if not os.path.exists(self.MODEL_ROOT_PATH):
            os.makedirs(self.MODEL_ROOT_PATH, exist_ok=True)
        self.params_path = os.path.join(self.MODEL_ROOT_PATH, 'params.json')
        self.model_path = os.path.join(self.MODEL_ROOT_PATH, 'best_model.weights')
        self.tensorboard_path = os.path.join(self.MODEL_ROOT_PATH, 'logs')
        self.max_len = params.get('max_len', 128)
        self.batch_size = params.get('batch_size', 32)
        self.patience = params.get('patience', 3)
        self.train_data_path = params.get('train_data_path')
        if Train and not self.train_data_path:
            raise Exception("No training data!")
        self.valid_data_path = params.get('valid_data_path')
        self.test_data_path = params.get('test_data_path')
        self.epoch = params.get('epoch', 10)
        self.learning_rate = params.get('learning_rate', 1e-5)  # the fewer bert_layers, the larger the learning rate should be
        self.bert_layers = params.get('bert_layers', 12)
        self.crf_lr_multiplier = params.get('crf_lr_multiplier', 1000)  # enlarge the CRF layer's learning rate when necessary
        self.gpu_id = params.get("gpu_id", None)
        self.activation = params.get('activation', 'softmax')  # classification activation, softmax or sigmoid
        self.loss = params.get('loss', 'sparse_categorical_crossentropy')
        # self.loss = params.get('loss','categorical_crossentropy')
        self.metrics = params.get('metrics', ['accuracy'])
        self.split = params.get('split', 0.8)  # train/valid split ratio
        self.dropout = params.get('dropout', 0.5)  # dropout rate
        self.params = params
        self._set_gpu_id()  # set the GPU id used for training
        if Train:
            self.data_process()
            self.save_params()
            self.build_model()
            self.compile_model()
        else:
            self.load_params()
            self.build_model()
            self.load_model()

    def save_params(self):
        self.params['num_classes'] = self.num_classes
        self.params['labels'] = self.labels
        self.params['index2label'] = self.index2label
        self.params['label2index'] = self.label2index
        self.params['max_len'] = self.max_len
        save_json(jsons=self.params, json_path=self.params_path)

    def load_params(self):
        load_params = load_json(self.params_path)
        self.max_len = load_params.get('max_len')
        self.labels = load_params.get('labels')
        self.num_classes = load_params.get('num_classes')
        self.label2index = load_params.get('label2index')
        self.index2label = load_params.get('index2label')

    def _set_gpu_id(self):
        """Set the GPU id used by this process."""
        if self.gpu_id:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

    def data_process(self):
        """Data preprocessing; implemented by subclasses."""
        raise NotImplementedError

    def build_model(self):
        """Model construction; implemented by subclasses."""
        raise NotImplementedError

    def compile_model(self):
        """Model compilation; implemented by subclasses."""
        raise NotImplementedError

    def train(self):
        """Model training; implemented by subclasses."""
        raise NotImplementedError

    def predict(self, text):
        """Model inference; implemented by subclasses."""
        raise NotImplementedError

    def load_model(self):
        self.model.load_weights(self.model_path)
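Every task-specific model in this repository (classifier, NER, triplet extraction) subclasses BasisGraph and fills in the hooks above. A minimal sketch of the contract a new subclass is expected to honour (the class name and comments below are illustrative only, not part of the code base):

from basis_framework.basis_graph import BasisGraph


class MyTaskGraph(BasisGraph):
    def data_process(self):
        # read self.train_data_path / self.valid_data_path, build data generators,
        # and set everything save_params() persists (e.g. num_classes, labels)
        ...

    def build_model(self):
        # assemble a Keras model on top of the BERT weights and assign it to self.model
        ...

    def compile_model(self):
        # compile self.model with self.loss / self.metrics / self.learning_rate
        ...

    def train(self):
        # fit self.model with the generators built in data_process()
        ...

    def predict(self, text):
        # encode `text` with self.tokenizer and decode self.model's output
        ...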
--------------------------------------------------------------------------------
/classifier/__init__.py:
--------------------------------------------------------------------------------
# _*_coding:utf-8_*_
# author leewfeng
# 2020/12/11 21:14
--------------------------------------------------------------------------------
/classifier/bert_classifier.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: liwfeng
# datetime: 2020/12/1 15:13
# ide: PyCharm

from __future__ import print_function, division

import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import extend_with_piecewise_linear_lr
from keras.layers import Dense, Lambda
from keras.models import Model
from keras.optimizers import Adam

from basis_framework.basis_graph import BasisGraph
from configs.path_config import CORPUS_ROOT_PATH
from utils.classifier_data_process import Data_Generator, Evaluator
from utils.common_tools import data2csv, data_preprocess, split


class BertGraph(BasisGraph):
    def __init__(self, params={}, Train=False):
        if not params.get('model_code'):
            params['model_code'] = 'bert_classifier'
        super().__init__(params, Train)

    def data_process(self, sep='\t'):
        """Load the corpus, build the label maps and the train/valid/test generators."""
        if '.csv' not in self.train_data_path:
            self.train_data_path = data2csv(self.train_data_path, sep)
        self.index2label, self.label2index, self.labels, train_data = data_preprocess(self.train_data_path)
        self.num_classes = len(self.index2label)
        if self.valid_data_path:
            if '.csv' not in self.valid_data_path:
                self.valid_data_path = data2csv(self.valid_data_path, sep)
            _, _, _, valid_data = data_preprocess(self.valid_data_path)
        else:
            train_data, valid_data = split(train_data, self.split)
        if self.test_data_path:
            if '.csv' not in self.test_data_path:
                self.test_data_path = data2csv(self.test_data_path, sep)
            _, _, _, test_data = data_preprocess(self.test_data_path)
        else:
            test_data = []
        self.train_generator = Data_Generator(train_data, self.label2index, self.tokenizer, self.batch_size,
                                              self.max_len)
        self.valid_generator = Data_Generator(valid_data, self.label2index, self.tokenizer, self.batch_size,
                                              self.max_len)
        self.test_generator = Data_Generator(test_data, self.label2index, self.tokenizer, self.batch_size,
                                             self.max_len)

    def build_model(self):
        bert = build_transformer_model(
            config_path=self.bert_config_path,
            checkpoint_path=self.bert_checkpoint_path,
            return_keras_model=False,
        )
        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)  # take the [CLS] vector for classification
        output = Dense(self.num_classes, activation=self.activation, kernel_initializer=bert.initializer)(
            output)  # dense classification head
        self.model = Model(bert.model.input, output)
        print(self.model.summary(150))

    def predict(self, text):
        token_ids, segment_ids = self.tokenizer.encode(text)
        pre = self.model.predict([[token_ids], [segment_ids]])
        res = self.index2label.get(str(np.argmax(pre[0])))
        return res

    def compile_model(self):
        # 派生为带分段线性学习率的优化器。
        # 
其中name参数可选,但最好填入,以区分不同的派生优化器。 77 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR') 78 | self.model.compile(loss=self.loss, 79 | optimizer=AdamLR(lr=self.learning_rate, lr_schedule={ 80 | 1000: 1, 81 | 2000: 0.1 82 | }), 83 | metrics=self.metrics, ) 84 | 85 | def train(self): 86 | # 保存超参数 87 | evaluator = Evaluator(self.model, self.model_path, self.valid_generator, self.test_generator) 88 | 89 | # 模型训练 90 | self.model.fit_generator( 91 | self.train_generator.forfit(), 92 | steps_per_epoch=len(self.train_generator), 93 | epochs=self.epoch, 94 | callbacks=[evaluator], 95 | ) 96 | 97 | 98 | if __name__ == '__main__': 99 | params = { 100 | 'model_code': 'thuc_news_bert', 101 | 'train_data_path': CORPUS_ROOT_PATH + '/thuc_news/train.txt', 102 | 'valid_data_path': CORPUS_ROOT_PATH + '/thuc_news/dev.txt', 103 | 'test_data_path': CORPUS_ROOT_PATH + '/thuc_news/test.txt', 104 | 'batch_size': 128, 105 | 'max_len': 30, 106 | 'epoch': 10, 107 | 'learning_rate': 1e-5, 108 | 'gpu_id': 1, 109 | } 110 | bertModel = BertGraph(params, Train=True) 111 | bertModel.train() 112 | else: 113 | params = { 114 | 'model_code': 'thuc_news_bert', # 此处与训练时code保持一致 115 | 'gpu_id': 1, 116 | } 117 | bertModel = BertGraph(params) 118 | -------------------------------------------------------------------------------- /classifier/bert_cnn_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 15:13 5 | # ide: PyCharm 6 | 7 | from __future__ import print_function, division 8 | 9 | import numpy as np 10 | from bert4keras.models import build_transformer_model 11 | from bert4keras.optimizers import extend_with_piecewise_linear_lr 12 | from keras.layers import Dense, Dropout, Flatten, MaxPooling1D, concatenate, Conv1D 13 | from keras.models import Model 14 | from keras.optimizers import Adam 15 | 16 | from basis_framework.basis_graph import BasisGraph 17 | from configs.path_config import CORPUS_ROOT_PATH 18 | from utils.classifier_data_process import Data_Generator, Evaluator 19 | from utils.common_tools import data2csv, data_preprocess, split 20 | 21 | 22 | class BertGraph(BasisGraph): 23 | def __init__(self, params={}, Train=False): 24 | if not params.get('model_code'): 25 | params['model_code'] = 'bertcnn_classifier' 26 | self.filters = params.get('filters', [3, 4, 5]) # 卷积核大小 27 | self.filters_num = params.get('filters_num', 300) # 核数 28 | super().__init__(params, Train) 29 | 30 | def data_process(self, sep='\t'): 31 | """ 32 | 数据处理 33 | :return: 34 | """ 35 | if '.csv' not in self.train_data_path: 36 | self.train_data_path = data2csv(self.train_data_path, sep) 37 | self.index2label, self.label2index, self.labels, train_data = data_preprocess(self.train_data_path) 38 | self.num_classes = len(self.index2label) 39 | if self.valid_data_path: 40 | if '.csv' not in self.valid_data_path: 41 | self.valid_data_path = data2csv(self.valid_data_path, sep) 42 | _, _, _, valid_data = data_preprocess(self.valid_data_path) 43 | else: 44 | train_data, valid_data = split(train_data, self.split) 45 | if self.test_data_path: 46 | if '.csv' not in self.test_data_path: 47 | self.test_data_path = data2csv(self.test_data_path, sep) 48 | _, _, _, test_data = data_preprocess(self.test_data_path) 49 | else: 50 | test_data = [] 51 | self.train_generator = Data_Generator(train_data, self.label2index, self.tokenizer, self.batch_size, 52 | self.max_len) 53 | self.valid_generator = 
Data_Generator(valid_data, self.label2index, self.tokenizer, self.batch_size, 54 | self.max_len) 55 | self.test_generator = Data_Generator(test_data, self.label2index, self.tokenizer, self.batch_size, 56 | self.max_len) 57 | 58 | def build_model(self): 59 | bert = build_transformer_model( 60 | config_path=self.bert_config_path, 61 | checkpoint_path=self.bert_checkpoint_path, 62 | return_keras_model=False, 63 | ) 64 | print(bert.model.output.shape) 65 | # output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output) # 取出[cls]层对应的向量来做分类 66 | # output = Dense(self.num_classes, activation=self.activation, kernel_initializer=bert.initializer)( 67 | # output) # 全连接层激活函数分类 68 | conv_pools = [] 69 | # 词窗大小分别为3,4,5 70 | for filter in self.filters: 71 | cnn = Conv1D(self.filters_num, filter, padding='same', strides=1, activation='relu')(bert.model.output) 72 | cnn = MaxPooling1D(pool_size=self.max_len - filter + 1)(cnn) 73 | conv_pools.append(cnn) 74 | # 合并三个模型的输出向量 75 | cnn = concatenate(conv_pools, axis=-1) 76 | flat = Flatten()(cnn) 77 | drop = Dropout(self.dropout)(flat) 78 | output = Dense(self.num_classes, activation=self.activation)(drop) 79 | self.model = Model(bert.model.input, output) 80 | print(self.model.summary(150)) 81 | 82 | def predict(self, text): 83 | token_ids, segment_ids = self.tokenizer.encode(text) 84 | pre = self.model.predict([[token_ids], [segment_ids]]) 85 | res = self.index2label.get(str(np.argmax(pre[0]))) 86 | return res 87 | 88 | def compile_model(self): 89 | # 派生为带分段线性学习率的优化器。 90 | # 其中name参数可选,但最好填入,以区分不同的派生优化器。 91 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR') 92 | self.model.compile(loss=self.loss, 93 | optimizer=AdamLR(lr=self.learning_rate, lr_schedule={ 94 | 1000: 1, 95 | 2000: 0.1 96 | }), 97 | metrics=self.metrics, ) 98 | 99 | def train(self): 100 | # 保存超参数 101 | evaluator = Evaluator(self.model, self.model_path, self.valid_generator, self.test_generator) 102 | 103 | # 模型训练 104 | self.model.fit_generator( 105 | self.train_generator.forfit(), 106 | steps_per_epoch=len(self.train_generator), 107 | epochs=self.epoch, 108 | callbacks=[evaluator], 109 | ) 110 | 111 | 112 | if __name__ == '__main__': 113 | params = { 114 | 'model_code': 'thuc_news_bert', 115 | 'train_data_path': CORPUS_ROOT_PATH + '/thuc_news/train.txt', 116 | 'valid_data_path': CORPUS_ROOT_PATH + '/thuc_news/dev.txt', 117 | 'test_data_path': CORPUS_ROOT_PATH + '/thuc_news/test.txt', 118 | 'batch_size': 128, 119 | 'max_len': 30, 120 | 'epoch': 10, 121 | 'learning_rate': 1e-4, 122 | 'gpu_id': 1, 123 | } 124 | bertModel = BertGraph(params, Train=True) 125 | bertModel.train() 126 | else: 127 | params = { 128 | 'model_code': 'thuc_news_bert', # 此处与训练时code保持一致 129 | 'gpu_id': 1, 130 | } 131 | bertModel = BertGraph(params) 132 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 14:09 5 | # ide: PyCharm 6 | -------------------------------------------------------------------------------- /configs/path_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 14:10 5 | # ide: PyCharm 6 | import os 7 | # 项目的根目录 8 | path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) 9 | path_root = 
path_root.replace('\\', '/')
# print(path_root)

# logging configuration
LOG_PATH = os.path.join(path_root, "logs")
LOG_NAME = "classification.log"

# model file paths
MODEL_ROOT_PATH = os.path.join(path_root, 'models')
BERT_MODEL_PATH = os.path.join(MODEL_ROOT_PATH, 'chinese_L-12_H-768_A-12')
GPT2_MODEL_PATH = os.path.join(MODEL_ROOT_PATH, 'gpt2_ml')

# training corpus path
CORPUS_ROOT_PATH = os.path.join(path_root, 'corpus')

# entity dictionary path (no leading slash, otherwise os.path.join would discard path_root)
ENTITY_DICT = os.path.join(path_root, 'utils/dynamic_data_cache/entity.csv')
--------------------------------------------------------------------------------
/entity_relationship_extraction/__init__.py:
--------------------------------------------------------------------------------
# _*_coding:utf-8_*_
# author leewfeng
# 2020/12/12 20:02
--------------------------------------------------------------------------------
/entity_relationship_extraction/triplet_extraction.py:
--------------------------------------------------------------------------------
# _*_coding:utf-8_*_
# author leewfeng
# 2020/12/12 20:04
import os

from basis_framework.basis_graph import BasisGraph
from utils.common_tools import load_json, save_json
from utils.triplet_data_process import Data_Generator, data_process, Evaluator

rootPath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

from bert4keras.backend import K, batch_gather
from bert4keras.layers import LayerNormalization
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from keras.layers import Input, Dense, Lambda, Reshape
from keras.models import Model
import numpy as np

model_root_path = rootPath + '/model/'
corpus_root_path = rootPath + '/corpus/'


class ReextractBertHandler(BasisGraph):
    def __init__(self, params={}, Train=False):
        if not params.get('model_code'):
            params['model_code'] = 'triplet_extraction'
        self.memory_fraction = params.get('memory_fraction')  # optional GPU memory fraction, read in build_model
        super().__init__(params, Train)

    def load_params(self):
        load_params = load_json(self.params_path)
        self.max_len = load_params.get('max_len')
        self.num_classes = load_params.get('num_classes')
        self.p2s_dict = load_params.get('p2s_dict')
        self.i2p_dict = load_params.get('i2p_dict')
        self.p2o_dict = load_params.get('p2o_dict')

    def save_params(self):
        self.params['num_classes'] = self.num_classes
        self.params['p2s_dict'] = self.p2s_dict
        self.params['i2p_dict'] = self.i2p_dict
        self.params['p2o_dict'] = self.p2o_dict
        self.params['max_len'] = self.max_len
        save_json(jsons=self.params, json_path=self.params_path)

    def data_process(self):
        train_data, self.valid_data, self.p2s_dict, self.p2o_dict, self.i2p_dict, self.p2i_dict = data_process(
            self.train_data_path, self.valid_data_path, self.max_len, self.params_path)
        self.num_classes = len(self.i2p_dict)
        self.train_generator = Data_Generator(train_data, self.batch_size, self.tokenizer, self.p2i_dict,
                                              self.max_len)

    def extrac_subject(self, inputs):
        """Gather the subject span representation from `output` according to subject_ids."""
        output, subject_ids = inputs
        subject_ids = K.cast(subject_ids, 'int32')
        start = batch_gather(output, subject_ids[:, :1])
        end = batch_gather(output, subject_ids[:, 1:])
        subject = K.concatenate([start, end], 2)
        return subject[:, 0]

    def build_model(self):
        import 
tensorflow as tf 65 | from keras.backend.tensorflow_backend import set_session 66 | config = tf.ConfigProto() 67 | config.gpu_options.allocator_type = 'BFC' # A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc. 68 | if self.memory_fraction: 69 | config.gpu_options.per_process_gpu_memory_fraction = self.memory_fraction 70 | config.gpu_options.allow_growth = False 71 | else: 72 | config.gpu_options.allow_growth = True 73 | set_session(tf.Session(config=config)) 74 | 75 | # 补充输入 76 | subject_labels = Input(shape=(None, 2), name='Subject-Labels') 77 | subject_ids = Input(shape=(2,), name='Subject-Ids') 78 | object_labels = Input(shape=(None, self.num_classes, 2), name='Object-Labels') 79 | # 加载预训练模型 80 | bert = build_transformer_model( 81 | config_path=self.bert_config_path, 82 | checkpoint_path=self.bert_checkpoint_path, 83 | return_keras_model=False, 84 | ) 85 | # 预测subject 86 | output = Dense(units=2, 87 | activation='sigmoid', 88 | kernel_initializer=bert.initializer)(bert.model.output) 89 | subject_preds = Lambda(lambda x: x ** 2)(output) 90 | self.subject_model = Model(bert.model.inputs, subject_preds) 91 | # 传入subject,预测object 92 | # 通过Conditional Layer Normalization将subject融入到object的预测中 93 | output = bert.model.layers[-2].get_output_at(-1) 94 | subject = Lambda(self.extrac_subject)([output, subject_ids]) 95 | output = LayerNormalization(conditional=True)([output, subject]) 96 | output = Dense(units=self.num_classes * 2, 97 | activation='sigmoid', 98 | kernel_initializer=bert.initializer)(output) 99 | output = Lambda(lambda x: x ** 4)(output) 100 | object_preds = Reshape((-1, self.num_classes, 2))(output) 101 | self.object_model = Model(bert.model.inputs + [subject_ids], object_preds) 102 | # 训练模型 103 | self.model = Model(bert.model.inputs + [subject_labels, subject_ids, object_labels], 104 | [subject_preds, object_preds]) 105 | 106 | mask = bert.model.get_layer('Embedding-Token').output_mask 107 | mask = K.cast(mask, K.floatx()) 108 | subject_loss = K.binary_crossentropy(subject_labels, subject_preds) 109 | subject_loss = K.mean(subject_loss, 2) 110 | subject_loss = K.sum(subject_loss * mask) / K.sum(mask) 111 | object_loss = K.binary_crossentropy(object_labels, object_preds) 112 | object_loss = K.sum(K.mean(object_loss, 3), 2) 113 | object_loss = K.sum(object_loss * mask) / K.sum(mask) 114 | self.model.add_loss(subject_loss + object_loss) 115 | AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA') 116 | self.optimizer = AdamEMA(lr=1e-4) 117 | 118 | def compile_model(self): 119 | self.model.compile(optimizer=self.optimizer) 120 | 121 | def predict(self, text): 122 | """ 123 | 抽取输入text所包含的三元组 124 | text:str(<离开>是由张宇谱曲,演唱) 125 | """ 126 | tokens = self.tokenizer.tokenize(text, max_length=self.max_len) 127 | token_ids, segment_ids = self.tokenizer.encode(text, max_length=self.max_len) 128 | # 抽取subject 129 | subject_preds = self.subject_model.predict([[token_ids], [segment_ids]]) 130 | start = np.where(subject_preds[0, :, 0] > 0.6)[0] 131 | end = np.where(subject_preds[0, :, 1] > 0.5)[0] 132 | subjects = [] 133 | for i in start: 134 | j = end[end >= i] 135 | if len(j) > 0: 136 | j = j[0] 137 | subjects.append((i, j)) 138 | if subjects: 139 | spoes = [] 140 | token_ids = np.repeat([token_ids], len(subjects), 0) 141 | segment_ids = np.repeat([segment_ids], len(subjects), 0) 142 | subjects = np.array(subjects) 143 | # 传入subject,抽取object和predicate 144 | object_preds = self.object_model.predict([token_ids, segment_ids, subjects]) 145 | for subject, 
object_pred in zip(subjects, object_preds):
                start = np.where(object_pred[:, :, 0] > 0.6)
                end = np.where(object_pred[:, :, 1] > 0.5)
                for _start, predicate1 in zip(*start):
                    for _end, predicate2 in zip(*end):
                        if _start <= _end and predicate1 == predicate2:
                            spoes.append((subject, predicate1, (_start, _end)))
                            break
            return [
                (
                    [self.tokenizer.decode(token_ids[0, s[0]:s[1] + 1], tokens[s[0]:s[1] + 1]),
                     self.p2s_dict[self.i2p_dict[p]]],
                    self.i2p_dict[p],
                    [self.tokenizer.decode(token_ids[0, o[0]:o[1] + 1], tokens[o[0]:o[1] + 1]),
                     self.p2o_dict[self.i2p_dict[p]]],
                    (s[0], s[1] + 1),
                    (o[0], o[1] + 1)
                ) for s, p, o in spoes
            ]
        else:
            return []

    def train(self):
        evaluator = Evaluator(self.model, self.model_path, self.tokenizer, self.predict, self.optimizer,
                              self.valid_data)

        self.model.fit_generator(self.train_generator.forfit(),
                                 steps_per_epoch=len(self.train_generator),
                                 epochs=self.epoch,
                                 callbacks=[evaluator])


if __name__ == '__main__':
    params = {
        "max_len": 128,
        "batch_size": 32,
        "epoch": 1,
        "train_data_path": rootPath + "/data/train_data.json",
        "valid_data_path": rootPath + "/data/valid_data.json",  # key must be 'valid_data_path', which is what BasisGraph reads
    }

    model = ReextractBertHandler(params, Train=True)

    model.train()
    text = "马志舟,1907年出生,陕西三原人,汉族,中国共产党,任红四团第一连连长,1933年逝世"
    print(model.predict(text))
--------------------------------------------------------------------------------
/named_entity/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8
# __project__ = ""
# __author__ = "Rick Hou"
# __time__ = 2020/11/6 14:58

--------------------------------------------------------------------------------
/named_entity/entity_by_rules.py:
--------------------------------------------------------------------------------
import copy
import re

from named_entity.ner_handler import NerHandler

from utils.dynamic_data_cache.keyword_dao import keywordinital


def entity_rule(sentence: str) -> list:
    """
    Rule-based (keyword-dictionary) entity recognition.
    :param sentence: text to scan
    :return: [{'word': 'XXX', 'start_pos': 1, 'end_pos': 3, 'entity_type': 'relation_entity'}]
    """
    entity_list = []
    if sentence:
        for entity_type, entity_tree in keywordinital.entity_trees.items():
            tranfer_list = entity_tree.extract_keyword(sentence)
            if tranfer_list:
                tranfer_list.sort(key=lambda i: len(i), reverse=True)
                for entity_single in tranfer_list:
                    # escape the keyword so regex metacharacters are matched literally
                    position = re.finditer(re.escape(entity_single), sentence)
                    for posi in position:
                        entity_dict = {'word': entity_single, 'start_pos': posi.start(), 'end_pos': posi.end(),
                                       'entity_type': entity_type}
                        if entity_dict not in entity_list:
                            entity_list.append(entity_dict)
    return entity_list


def ner_by_rule(sentence_list):
    entities = []
    for sentence in sentence_list:
        entities_list = entity_rule(sentence)
        entities_list = [i for i in entities_list if i != {}]
        entities_list.sort(key=lambda i: i['start_pos'])
        entities.append(entities_list)
    return entities


def ner_model_rule_syn(ner_rules, ner_models):
    """
    Merge rule-based and model-based recognition results.
    :param ner_rules: rule-based entities, one list per sentence
    :param ner_models: model-based entities, one dict per sentence ({'entities': [...]})
    :return: merged entity list
    """
    res_entity = []
    for index, ner_rule in 
enumerate(ner_rules): 50 | ner_model_list = [] 51 | if ner_models[index]: 52 | ner_model_list = ner_models[index].get('entities') 53 | _ner_model_list = copy.deepcopy(ner_model_list) 54 | for elem_rule in ner_rule: 55 | for elem_model in ner_model_list: 56 | if elem_rule.get('start_pos') <= elem_model.get('start_pos') and elem_rule.get( 57 | 'end_pos') >= elem_model.get('start_pos'): 58 | if elem_model in _ner_model_list: 59 | _ner_model_list.remove(elem_model) 60 | elif elem_rule.get('start_pos') <= elem_model.get('end_pos') and elem_rule.get( 61 | 'end_pos') >= elem_model.get( 62 | 'end_pos'): 63 | if elem_model in _ner_model_list: 64 | _ner_model_list.remove(elem_model) 65 | ner_rule.extend(_ner_model_list) 66 | ner_rule.sort(key=lambda i: i['start_pos']) 67 | res_entity.append({'entities': ner_rule}) 68 | return res_entity 69 | 70 | 71 | if __name__ == '__main__': 72 | # while True: 73 | # print('input: ') 74 | # sentence = input() 75 | # ret = ner_by_rule([sentence]) 76 | # print(ret) 77 | nerModel = NerHandler() 78 | texts = ['这次海钓的地点在厦门和深圳之间的海域,中国建设银行金融科技中心在这里举办活动', '日俄两国国内政局都充满了变数'] 79 | res = nerModel.predict(texts) 80 | ret = ner_by_rule(texts) 81 | last = ner_model_rule_syn(ret, res) 82 | print(last) 83 | -------------------------------------------------------------------------------- /named_entity/ner_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/10 9:28 5 | # ide: PyCharm 6 | """支持多gpu训练""" 7 | import os 8 | import sys 9 | 10 | rootPath = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(rootPath) 12 | from bert4keras.backend import K 13 | from bert4keras.layers import ConditionalRandomField 14 | from bert4keras.models import build_transformer_model 15 | from bert4keras.optimizers import Adam 16 | from keras.layers import Dense 17 | from keras.models import Model 18 | 19 | from basis_framework.basis_graph import BasisGraph 20 | from configs.path_config import CORPUS_ROOT_PATH 21 | from utils.common_tools import split 22 | from utils.logger import logger 23 | from utils.ner_data_process import data_process, Data_Generator, NamedEntityRecognizer, Evaluator 24 | from keras.utils import multi_gpu_model 25 | 26 | class NerHandler(BasisGraph): 27 | def __init__(self, params={}, Train=False): 28 | if not params.get('model_code'): 29 | params['model_code'] = 'ner' 30 | super().__init__(params, Train) 31 | self.build_ViterbiDecoder() 32 | logger.info('init ner_handler done') 33 | 34 | def build_model(self): 35 | model = build_transformer_model( 36 | self.bert_config_path, 37 | self.bert_checkpoint_path, 38 | ) 39 | output_layer = 'Transformer-%s-FeedForward-Norm' % (self.bert_layers - 1) 40 | output = model.get_layer(output_layer).output 41 | output = Dense(self.num_classes)(output) 42 | self.CRF = ConditionalRandomField(lr_multiplier=self.crf_lr_multiplier) 43 | output = self.CRF(output) 44 | self.model = Model(model.input, output) 45 | self.model_ = multi_gpu_model(self.model, gpus=2) 46 | self.model.summary(120) 47 | logger.info('build model done') 48 | 49 | def data_process(self): 50 | labels, train_data = data_process(self.train_data_path) 51 | if self.valid_data_path: 52 | _, self.valid_data = data_process(self.valid_data_path) 53 | else: 54 | train_data, self.valid_data = split(train_data, self.split) 55 | if self.test_data_path: _, self.test_data = data_process(self.test_data_path) 56 | self.index2label 
= dict(enumerate(labels)) 57 | self.label2index = {j: i for i, j in self.index2label.items()} 58 | self.num_classes = len(labels) * 2 + 1 59 | self.labels = labels 60 | self.train_generator = Data_Generator(train_data, self.batch_size, self.tokenizer, self.label2index, 61 | self.max_len) 62 | logger.info('data process done') 63 | 64 | def build_ViterbiDecoder(self): 65 | self.NER = NamedEntityRecognizer(trans=K.eval(self.CRF.trans), tokenizer=self.tokenizer, model=self.model, 66 | id2label=self.index2label, 67 | starts=[0], ends=[0]) 68 | 69 | def compile_model(self): 70 | self.model_.compile( 71 | # self.model.compile( 72 | loss=self.CRF.sparse_loss, 73 | optimizer=Adam(self.learning_rate), 74 | metrics=[self.CRF.sparse_accuracy] 75 | ) 76 | logger.info('compile model done') 77 | 78 | def recognize(self, text): 79 | tokens = self.NER.recognize(text) 80 | return tokens 81 | 82 | def predict(self, sentences): 83 | """ 84 | :param sentences: 85 | :return: 86 | """ 87 | res00 = [] 88 | text = [t for t in sentences if t] 89 | if text: 90 | tmp_res = self.NER.batch_recognize(sentences) 91 | for res in tmp_res: 92 | entities = [] 93 | for item in res: 94 | dics = {} 95 | dics['word'] = item[0] 96 | dics['start_pos'] = item[2] 97 | dics['end_pos'] = item[2] + len(item[0]) 98 | dics['entity_type'] = item[1] 99 | entities.append(dics) 100 | entities = sorted(entities, key=lambda x: x['start_pos']) 101 | res00.append({'entities': entities}) 102 | return res00 103 | 104 | def train(self): 105 | evaluator = Evaluator(self.model, self.model_path, self.CRF, self.NER, self.recognize, self.label2index, 106 | self.valid_data, self.test_data) 107 | 108 | self.model_.fit_generator(self.train_generator.forfit(), 109 | # self.model.fit_generator(self.train_generator.forfit(), 110 | steps_per_epoch=len(self.train_generator), 111 | epochs=self.epoch, 112 | callbacks=[evaluator]) 113 | 114 | 115 | if __name__ == '__main__': 116 | params = { 117 | 'train_data_path': CORPUS_ROOT_PATH + '/28_baidu/train.txt', 118 | 'valid_data_path': CORPUS_ROOT_PATH + '/28_baidu/dev.txt', 119 | 'test_data_path': CORPUS_ROOT_PATH + '/28_baidu/test.txt', 120 | 'epoch': 1, 121 | 'batch_size': 64, 122 | # 'gpu_id': 0, 123 | } 124 | nerModel = NerHandler(params, Train=True) 125 | nerModel.train() 126 | texts = ['这次海钓的地点在厦门和深圳之间的海域,中国建设银行金融科技中心在这里举办活动', '日俄两国国内政局都充满了变数'] 127 | res = nerModel.predict(texts) 128 | print(res) 129 | else: 130 | nerModel = NerHandler() 131 | -------------------------------------------------------------------------------- /text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8_*_ 2 | # author leewfeng 3 | # 2021/1/4 19:24 -------------------------------------------------------------------------------- /text_generation/gpt2_ml.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8_*_ 2 | # author leewfeng 3 | # 2021/1/4 19:25 4 | # 基本测试:中文GPT2_ML模型 5 | # 介绍链接:https://kexue.fm/archives/7292 6 | 7 | 8 | import os 9 | 10 | import numpy as np 11 | from bert4keras.models import build_transformer_model 12 | from bert4keras.snippets import AutoRegressiveDecoder 13 | from bert4keras.tokenizers import Tokenizer 14 | 15 | from basis_framework.basis_graph import BasisGraph 16 | from configs.path_config import GPT2_MODEL_PATH 17 | from utils.common_tools import load_json, save_json 18 | 19 | 20 | class ExtractFeature(BasisGraph): 21 | def __init__(self, params={}, Train=False): 22 | 
super().__init__(params, Train) 23 | self.config_path = os.path.join(GPT2_MODEL_PATH + "/config.json") 24 | self.checkpoint_path = os.path.join(GPT2_MODEL_PATH + "/model.ckpt-100000") 25 | self.vocab_path = os.path.join(GPT2_MODEL_PATH + "/vocab.txt") 26 | self.tokenizer = Tokenizer(self.vocab_path, token_start=None, token_end=None, do_lower_case=True) 27 | 28 | def save_params(self): 29 | self.params['max_len'] = self.max_len 30 | save_json(jsons=self.params, json_path=self.params_path) 31 | 32 | def load_params(self): 33 | load_params = load_json(self.params_path) 34 | self.max_len = load_params.get('max_len') 35 | 36 | def _set_gpu_id(self): 37 | """指定使用的GPU显卡id""" 38 | if self.gpu_id: 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id) 40 | 41 | def data_process(self): 42 | """ 43 | 模型框架搭建 44 | :return: 45 | """ 46 | raise NotImplementedError 47 | 48 | def build_model(self): 49 | """ 50 | 模型框架搭建 51 | :return: 52 | """ 53 | model = build_transformer_model(self.config_path, self.checkpoint_path, model='gpt2_ml') 54 | 55 | class ArticleCompletion(AutoRegressiveDecoder): 56 | """ 57 | 基于随机采样的文章续写 58 | """ 59 | 60 | def __init__(self, start_id, end_id, maxlen, minlen=None, model=None, tokenizer=None): 61 | self.tokenizer = tokenizer 62 | super().__init__(start_id, end_id, maxlen, minlen=None) 63 | 64 | @AutoRegressiveDecoder.wraps(default_rtype='probas') 65 | def predict(self, inputs, output_ids, step): 66 | token_ids = np.concatenate([inputs[0], output_ids], 1) 67 | return self.last_token(model).predict(token_ids) 68 | 69 | def generate(self, text, n=1, topp=0.95): 70 | token_ids, _ = self.tokenizer.encode(text) 71 | results = self.random_sample([token_ids], n, topp=topp) 72 | return [text + self.tokenizer.decode(ids) for ids in results] 73 | 74 | self.article_completion = ArticleCompletion(start_id=None, 75 | end_id=511, # 511是中文句号 76 | maxlen=256, 77 | minlen=128) 78 | 79 | def compile_model(self): 80 | """ 81 | 模型框架搭建 82 | :return: 83 | """ 84 | raise NotImplementedError 85 | 86 | def extract_features(self, text: str): 87 | """ 88 | 编码测试 89 | :return: 90 | """ 91 | print(self.article_completion.generate(u'今天天气不错')) 92 | 93 | def save_model(self, model_path='test.model'): 94 | self.model.save(model_path) 95 | del self.model # 释放内存 96 | 97 | def load_model(self, model_path='test.model'): 98 | # self.model = keras.models.load_model(model_path) 99 | # self.extract_features('语言模型') 100 | pass 101 | 102 | def load_params(self): 103 | pass 104 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 14:09 5 | # ide: PyCharm 6 | -------------------------------------------------------------------------------- /utils/classifier_data_process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 17:22 5 | # ide: PyCharm 6 | 7 | import keras 8 | from bert4keras.snippets import DataGenerator, sequence_padding 9 | from tqdm import tqdm 10 | 11 | 12 | # DataGenerator只是一种为了节约内存的数据方式 13 | class Data_Generator(DataGenerator): 14 | def __init__(self, data, l2i, tokenizer, batch_size, maxlen=128): 15 | super().__init__(data, batch_size=batch_size) 16 | self.l2i = l2i 17 | self.maxlen = maxlen 18 | self.tokenizer = tokenizer 19 | 20 | def 
__iter__(self, random=False): 21 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 22 | for is_end, (label, text) in self.sample(random): 23 | token_ids, segment_ids = self.tokenizer.encode(text, max_length=self.maxlen) 24 | batch_token_ids.append(token_ids) 25 | batch_segment_ids.append(segment_ids) 26 | batch_labels.append([self.l2i.get(str(label))]) 27 | if len(batch_token_ids) == self.batch_size or is_end: 28 | batch_token_ids = sequence_padding(batch_token_ids) 29 | batch_segment_ids = sequence_padding(batch_segment_ids) 30 | batch_labels = sequence_padding(batch_labels) 31 | yield [batch_token_ids, batch_segment_ids], batch_labels 32 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 33 | 34 | 35 | def evaluate(data, predict): 36 | total, right = 0., 0. 37 | for x_true, y_true in tqdm(data): 38 | # for x_true, y_true in data: 39 | y_pred = predict(x_true).argmax(axis=1) 40 | y_true = y_true[:, 0] 41 | total += len(y_true) 42 | right += (y_true == y_pred).sum() 43 | return right / total 44 | 45 | 46 | class Evaluator(keras.callbacks.Callback): 47 | """评估与保存 48 | """ 49 | 50 | def __init__(self, model, model_path, valid_generator, test_generator): 51 | self.best_val_acc = 0. 52 | self.model = model 53 | self.model_path = model_path 54 | self.valid_generator = valid_generator 55 | self.test_generator = test_generator 56 | 57 | def on_epoch_end(self, epoch, logs=None): 58 | val_acc = evaluate(self.valid_generator, self.model.predict) 59 | if val_acc > self.best_val_acc: 60 | self.best_val_acc = val_acc 61 | self.model.save_weights(self.model_path) 62 | print( 63 | u'val_acc: %.5f, best_val_acc: %.5f\n' % 64 | (val_acc, self.best_val_acc) 65 | ) 66 | 67 | def on_train_end(self, logs=None): 68 | test_acc = evaluate(self.test_generator, self.model.predict) 69 | print( 70 | u'best_val_acc: %.5f, test_acc: %.5f\n' % 71 | (self.best_val_acc, test_acc) 72 | ) 73 | -------------------------------------------------------------------------------- /utils/common_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 16:52 5 | # ide: PyCharm 6 | import codecs 7 | import json 8 | import os 9 | import random 10 | import re 11 | 12 | import jieba 13 | import pandas as pd 14 | 15 | 16 | def search(pattern, sequence): 17 | """从sequence中寻找子串pattern 18 | 如果找到,返回第一个下标;否则返回-1。 19 | """ 20 | n = len(pattern) 21 | for i in range(len(sequence)): 22 | if sequence[i:i + n] == pattern: 23 | return i 24 | return -1 25 | 26 | 27 | def data2csv(data_path, sep): 28 | # 训练数据、测试数据和标签转化为模型输入格式 29 | label = [] 30 | content = [] 31 | with open(data_path, 'r', encoding='utf8') as f: 32 | contents = f.readlines() 33 | for line in contents: 34 | line = line.split(sep) 35 | label.append(line[1]) 36 | content.append(line[0]) 37 | data = {} 38 | data['label'] = label 39 | data['content'] = content 40 | data_path = ''.join(data_path.split('.')[:-1]) + '.csv' 41 | pd.DataFrame(data).to_csv(data_path, index=False) 42 | return data_path 43 | 44 | 45 | def txt_read(file_path, encode_type='utf-8'): 46 | """ 47 | 读取txt文件,默认utf8格式 48 | :param file_path: str, 文件路径 49 | :param encode_type: str, 编码格式 50 | :return: list 51 | """ 52 | list_line = [] 53 | try: 54 | file = open(file_path, 'r', encoding=encode_type) 55 | while True: 56 | line = file.readline() 57 | line = line.strip() 58 | if not line: 59 | break 60 | list_line.append(line) 61 | file.close() 62 | except Exception as 
e:
        print(str(e))
    finally:
        return list_line


def txt_write(list_line, file_path, type='w', encode_type='utf-8'):
    """
    Write a list of lines to a txt file.
    :param list_line: list, each item should already end with "\n"
    :param file_path: str, path of the output file
    :param type: str, write mode, e.g. 'w' or 'a'
    :param encode_type: str, encoding
    :return: None
    """
    try:
        file = open(file_path, type, encoding=encode_type)
        file.writelines(list_line)
        file.close()

    except Exception as e:
        print(str(e))


def extract_chinese(text):
    """
    Keep only Chinese characters, letters, digits and @ . _
    :param text: str, input sentence
    :return: str
    """
    chinese_extract = ''.join(re.findall(u"([\u4e00-\u9fa5A-Za-z0-9@._])", text))
    return chinese_extract


def read_and_process(path):
    """
    Read the csv and return the cleaned label and question columns.
    :param path: str
    :return: (labels, questions)
    """

    data = pd.read_csv(path)
    ques = data["ques"].values.tolist()
    labels = data["label"].values.tolist()
    line_x = [extract_chinese(str(line).upper()) for line in labels]
    line_y = [extract_chinese(str(line).upper()) for line in ques]
    return line_x, line_y


def preprocess_label_ques(path):
    x, y, x_y = [], [], []
    x_y.append('label,ques\n')
    with open(path, 'r', encoding='utf-8') as f:
        while True:
            line = f.readline()
            try:
                line_json = json.loads(line)
            except:
                break
            ques = line_json['title']
            label = line_json['category'][0:2]
            line_x = " ".join(
                [extract_chinese(word) for word in list(jieba.cut(ques, cut_all=False, HMM=True))]).strip().replace(
                '  ', ' ')
            line_y = extract_chinese(label)
            x_y.append(line_y + ',' + line_x + '\n')
    return x_y


def save_json(jsons, json_path):
    """
    Save an object as json.
    :param jsons: json-serializable object
    :param json_path: str
    :return: None
    """
    with open(json_path, 'w', encoding='utf-8') as fj:
        fj.write(json.dumps(jsons, ensure_ascii=False))
    fj.close()


def load_json(path):
    """
    Load json from a file (only the first line is read).
    :param path: str
    :return: json
    """
    with open(path, 'r', encoding='utf-8') as fj:
        model_json = json.loads(fj.readlines()[0])
    return model_json


def delete_file(path):
    """
    Recursively delete the .h5 and .json files under a directory.
    :param path: str, dir path
    :return: None
    """
    for i in os.listdir(path):
        # absolute path of the file or sub-directory
        path_children = os.path.join(path, i)
        if os.path.isfile(path_children):
            if path_children.endswith(".h5") or path_children.endswith(".json"):
                os.remove(path_children)
        else:  # recurse into sub-directories
            delete_file(path_children)


def token_process(vocab_path):
    """
    Build the token->id dict from a BERT vocab file.
    :return: dict
    """
    # convert the vocabulary into a dict
    token_dict = {}
    with codecs.open(vocab_path, 'r', 'utf8') as reader:
        for line in reader:
            token = line.strip()
            token_dict[token] = len(token_dict)
    return token_dict


def data_preprocess(data_path, label='label', usecols=['label', 'content']):
    """
    Read the csv and return the label<->index mappings, the label list and the raw data.
    :param data_path: str
    :return: (index2label, label2index, labels, data)
    """
    df = pd.read_csv(data_path, usecols=usecols).dropna()
    label_unique = df[label].unique().tolist()
    data = df.values.tolist()
    i2l = {i: str(v) for i, v in enumerate(label_unique)}
    l2i = {str(v): i for i, v in enumerate(label_unique)}
    return i2l, l2i, label_unique, data


def split(train_data, sep=0.8):
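    """Randomly shuffle train_data and split it into (train, valid) lists; `sep` is the fraction kept for training (default 0.8)."""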
data_len = len(train_data) 200 | indexs = list(range(data_len)) 201 | random.shuffle(indexs) 202 | sep = int(data_len * sep) 203 | train_data, valid_data = [train_data[i] for i in indexs[:sep]], [train_data[i] for i in 204 | indexs[sep:]] 205 | # self.train_data ,self.valid_data = [self.train_data[i] for i in indexs[:sep]],[self.train_data[i] for i in indexs[sep:]] 206 | return train_data, valid_data 207 | 208 | 209 | if __name__ == '__main__': 210 | data_preprocess('E:/lwf_practice/Text_Classification/corpus/baidu_qa_2019/baike_qa_train.csv') 211 | -------------------------------------------------------------------------------- /utils/dynamic_data_cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leefsir/bert4keras4nlp/92935a2358213677e9e9a6fee439ef642741ed6d/utils/dynamic_data_cache/__init__.py -------------------------------------------------------------------------------- /utils/dynamic_data_cache/entity.csv: -------------------------------------------------------------------------------- 1 | value,entity_type 2 | 企业法人, client 3 | 非法人企业, client 4 | 个体工商户, client 5 | 存款人, client 6 | 企业, client 7 | 金融机构客户, client 8 | 资产托管人, client 9 | 中国建设银行金融科技中心, organization 10 | 金融机构, organization 11 | 机关, organization 12 | 事业单位, organization 13 | 预算管理, organization 14 | 营业网点, organization 15 | 运营管理部, organization 16 | 总行, organization 17 | 管理部门, organization 18 | 金融机构部, organization 19 | 法律合规部, organization 20 | 财务会计部, organization 21 | 办公室, organization 22 | 网络金融部, organization 23 | 信息科技部, organization 24 | 分行, organization 25 | 作业中心, organization 26 | 营业网点, organization 27 | 法律合规部, organization 28 | 资财部, organization 29 | 中国人民银行, organization 30 | 财政部门, organization 31 | 上海浦东发展银行同业存放业务管理办法 , system 32 | 关于落实中国人民银行加强银行业金融机构人民币同业银行结算账户管理规定的通知, system 33 | 上海浦东发展银行资产托管业务托管账户管理办法, system 34 | 企业银行账户, business 35 | 银行结算账户业务, business 36 | 企业银行结算账户, business 37 | 基金, business 38 | 信托, business 39 | 资管计划, business 40 | 理财产品, business 41 | 账户业务, business 42 | 银行结算账户, business 43 | 基本存款账户, business 44 | 一般存款账户, business 45 | 专用存款账户, business 46 | 临时存款账户, business 47 | 银行结算账户, business 48 | 基本存款账户, business -------------------------------------------------------------------------------- /utils/dynamic_data_cache/keyword_dao.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from configs.path_config import ENTITY_DICT 4 | from utils.dynamic_data_cache.trie_tree import get_trie_tree 5 | 6 | 7 | class KeywordInitial: 8 | def __init__(self): 9 | # 从数据库获取所有关键词,并按 ',' 分割。返回(KEY, VALUE, TYPE, ID) 10 | self.get_entity() 11 | self.get_entity_tree() 12 | 13 | def get_entity_tree(self): 14 | self.entity_trees = {} 15 | for entity_type, entity_lists in self.entitys.items(): 16 | self.entity_trees[entity_type] = get_trie_tree(entity_lists) 17 | 18 | def delete(self, keyword_tuple: [(), ]): 19 | 20 | for value, entity_type in keyword_tuple: 21 | if self.entitys.get(entity_type) and value in self.entitys.get(entity_type): 22 | self.entitys.get(entity_type).remove(value) 23 | self.get_entity_tree() 24 | 25 | def add(self, keyword_tuple: [(), ]): 26 | ''' 27 | :param keyword_tuple: [(value, entity_type),] 28 | :return: 29 | ''' 30 | for value, entity_type in keyword_tuple: 31 | if self.entitys.get(entity_type) and value not in self.entitys.get(entity_type): 32 | self.entitys.get(entity_type).append(value) 33 | elif not self.entitys.get(entity_type): 34 | 
self.entitys[entity_type] = [value] 35 | self.get_entity_tree() 36 | 37 | def update_all(self): 38 | self.get_entity() 39 | self.get_entity_tree() 40 | 41 | def get_entity(self, entity_path=ENTITY_DICT): 42 | entity_df = pd.read_csv(entity_path) 43 | if list(entity_df) != ['value', 'entity_type']: 44 | raise Exception("Incorrect format! The column name is ['value', 'entity_type'] not {}".format(list(entity_df))) 45 | entity_df = entity_df.drop_duplicates(subset='value', keep='first') 46 | self.entitys = {} 47 | for value, entity_type in entity_df.values.tolist(): 48 | if self.entitys.get(entity_type): 49 | self.entitys.get(entity_type).append(value) 50 | else: 51 | self.entitys[entity_type] = [value] 52 | 53 | 54 | # 项目启动就加载所有关键字 55 | keywordinital = KeywordInitial() 56 | -------------------------------------------------------------------------------- /utils/dynamic_data_cache/keyword_update.py: -------------------------------------------------------------------------------- 1 | from utils.dynamic_data_cache.keyword_dao import keywordinital 2 | from utils.logger import logger 3 | 4 | 5 | def keyword_operate(operate_code, keyword_tuple_list): 6 | if operate_code == -1: # 删除 7 | logger.info(keyword_tuple_list) 8 | keywordinital.delete(keyword_tuple_list) 9 | 10 | 11 | elif operate_code == 1: # 新增 12 | logger.info(keyword_tuple_list) 13 | keywordinital.add(keyword_tuple_list) 14 | 15 | elif operate_code == 2: # 更新所有关键词 16 | logger.info(keyword_tuple_list) 17 | keywordinital.update_all() 18 | 19 | else: 20 | raise NotImplementedError('{} is not a vaild ' 21 | 'operate code.'.format(operate_code)) 22 | 23 | -------------------------------------------------------------------------------- /utils/dynamic_data_cache/trie_tree.py: -------------------------------------------------------------------------------- 1 | class TrieNode: 2 | """ 3 | 前缀树节点-链表格式 4 | """ 5 | def __init__(self): 6 | self.child = {} 7 | # 可以加一个判断条件,但是人名提取用不到 8 | # self.flag = 0 9 | 10 | 11 | class TrieTree: 12 | """ 13 | 前缀树构建、新增关键词、关键词词语查找等 14 | """ 15 | def __init__(self): 16 | self.root = TrieNode() 17 | 18 | def add_keyword_one(self, keyword): 19 | """ 20 | 新增一个关键词 21 | :param keyword: str,构建的关键词 22 | :return: None 23 | """ 24 | node_curr = self.root 25 | for word in keyword: 26 | if node_curr.child.get(word) is None: 27 | node_next = TrieNode() 28 | node_curr.child[word] = node_next 29 | node_curr = node_curr.child[word] 30 | # 每个关键词词后边,加入end标志位 31 | if node_curr.child.get('end') is None: 32 | node_next = TrieNode() 33 | node_curr.child['end'] = node_next 34 | node_curr = node_curr.child['end'] 35 | 36 | def add_keyword_list(self, keywords): 37 | """ 38 | 新增关键词s, 格式为list 39 | :param keyword: list, 构建的关键词 40 | :return: None 41 | """ 42 | for keyword in keywords: 43 | self.add_keyword_one(keyword) 44 | 45 | def extract_keyword(self, sentence): 46 | """ 47 | 从句子中提取关键词,取得大于2个的,例如有人名"大漠帝国",那么"大漠帝"也取得 48 | :param sentence: str, 输入的句子 49 | :return: list, 提取到的关键词 50 | """ 51 | if not sentence: 52 | return [] 53 | node_curr = self.root # 关键词的第一位, 每次遍历完一个后重新初始化 54 | word_last = sentence[-1] 55 | name_list = [] 56 | name = '' 57 | for word in sentence: 58 | if node_curr.child.get(word) is None: # 查看有无后缀 59 | if name: # 提取到的关键词(也可能是前面的几位) 60 | if node_curr.child.get('end') is not None: # 取以end结尾的关键词, 或者len(name) > 2 61 | name_list.append(name) 62 | node_curr = self.root # 重新初始化 63 | if self.root.child.get(word): 64 | name = word 65 | node_curr = node_curr.child[word] 66 | else: 67 | name = '' 68 | else: # 有缀就加到name里边 69 | name = name + 
word 70 | node_curr = node_curr.child[word] 71 | if word == word_last: # 实体结尾的情况 72 | if node_curr.child.get('end') is not None: 73 | name_list.append(name) 74 | return name_list 75 | 76 | 77 | def get_trie_tree(keywords): 78 | """ 79 | 根据list关键词,初始化trie树 80 | :param keywords: list, input 81 | :return: objext, 返回实例化的trie 82 | """ 83 | trie = TrieTree() 84 | trie.add_keyword_list(keywords) 85 | return trie 86 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/1 14:09 5 | # ide: PyCharm 6 | 7 | import logging 8 | import os 9 | import re 10 | from logging.handlers import TimedRotatingFileHandler 11 | 12 | from configs.path_config import LOG_PATH, LOG_NAME 13 | 14 | LOGGER_LEVEL = logging.INFO 15 | # 日志保存个数 16 | BACKUP_COUNT = 30 17 | 18 | if not os.path.exists(LOG_PATH): 19 | os.makedirs(LOG_PATH, exist_ok=True) 20 | 21 | 22 | def setup_log(log_path, log_name): 23 | logging.basicConfig(level=logging.ERROR) 24 | # 创建logger对象。传入logger名字 25 | logger = logging.getLogger(log_name) 26 | log_path = os.path.join(log_path, log_name) 27 | # 设置日志记录等级 28 | logger.setLevel(LOGGER_LEVEL) 29 | # interval 滚动周期, 30 | # when="MIDNIGHT", interval=1 表示每天0点为更新点,每天生成一个文件 31 | file_handler = TimedRotatingFileHandler( 32 | filename=log_path, when="MIDNIGHT", interval=1, backupCount=BACKUP_COUNT 33 | ) 34 | # 设置时间 35 | file_handler.suffix = "%Y-%m-%d.log" 36 | # extMatch是编译好正则表达式,用于匹配日志文件名后缀 37 | # 需要注意的是suffix和extMatch一定要匹配的上,如果不匹配,过期日志不会被删除。 38 | file_handler.extMatch = re.compile(r"^\d{4}-\d{2}-\d{2}.log$") 39 | # 定义日志输出格式 40 | file_handler.setFormatter( 41 | logging.Formatter( 42 | "[%(asctime)s] [%(process)d] [%(levelname)s] - %(module)s.%(funcName)s (%(filename)s:%(lineno)d) - %(message)s" 43 | ) 44 | ) 45 | logger.addHandler(file_handler) 46 | return logger 47 | 48 | 49 | logger = setup_log(LOG_PATH, LOG_NAME) 50 | -------------------------------------------------------------------------------- /utils/ner_data_process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # author: liwfeng 4 | # datetime: 2020/12/10 9:56 5 | # ide: PyCharm 6 | import keras 7 | from bert4keras.backend import K 8 | from bert4keras.snippets import DataGenerator, sequence_padding, ViterbiDecoder 9 | from tqdm import tqdm 10 | 11 | from utils.common_tools import search 12 | from utils.logger import logger 13 | 14 | 15 | def data_process(filename): 16 | D = [] 17 | flags = [] 18 | with open(filename, encoding='utf-8') as f: 19 | f = f.read() 20 | for l in f.split('\n\n'): 21 | if not l: 22 | continue 23 | d, last_flag = [], '' 24 | for c in l.split('\n'): 25 | c = c.strip() 26 | if len(c.split(' ')) == 1: 27 | continue 28 | char, this_flag = c.split(' ') 29 | flags.append(this_flag) 30 | if this_flag == 'O' and last_flag == 'O': 31 | d[-1][0] += char 32 | elif this_flag == 'O' and last_flag != 'O': 33 | d.append([char, 'O']) 34 | elif this_flag[:1] == 'B': 35 | d.append([char, this_flag[2:]]) 36 | else: 37 | d[-1][0] += char 38 | last_flag = this_flag 39 | D.append(d) 40 | flags = list(set(flags)) 41 | flags = list(set([item.split('-')[1] for item in flags if len(item.split('-')) > 1])) 42 | return flags, D 43 | 44 | 45 | class Data_Generator(DataGenerator): 46 | """数据生成器 47 | """ 48 | 49 | def __init__(self, data, 
batch_size, tokenizer, label2id, maxlen): 50 | super().__init__(data, batch_size=batch_size) 51 | self.tokenizer = tokenizer 52 | self.label2id = label2id 53 | self.maxlen = maxlen 54 | 55 | def __iter__(self, random=False): 56 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 57 | for is_end, item in self.sample(random): 58 | token_ids, labels = [self.tokenizer._token_start_id], [0] 59 | for w, l in item: 60 | w_token_ids = self.tokenizer.encode(w)[0][1:-1] # 只获取token_ids 不包含cls和sep 61 | if len(token_ids) + len(w_token_ids) < self.maxlen: 62 | token_ids += w_token_ids 63 | if l == 'O': 64 | labels += [0] * len(w_token_ids) 65 | else: 66 | B = self.label2id[l] * 2 + 1 # 防止当某个标签为0时和‘O’的标签id冲突故而使除O外的标签id>0 67 | I = self.label2id[l] * 2 + 2 68 | labels += ([B] + [I] * (len(w_token_ids) - 1)) 69 | else: 70 | break 71 | token_ids += [self.tokenizer._token_end_id] # ['[CLS]']+[ids] +['[SEP]'] 72 | labels += [0] 73 | segment_ids = [0] * len(token_ids) 74 | batch_token_ids.append(token_ids) 75 | batch_segment_ids.append(segment_ids) 76 | batch_labels.append(labels) 77 | if len(batch_token_ids) == self.batch_size or is_end: 78 | batch_token_ids = sequence_padding(batch_token_ids) 79 | batch_segment_ids = sequence_padding(batch_segment_ids) 80 | batch_labels = sequence_padding(batch_labels) 81 | yield [batch_token_ids, batch_segment_ids], batch_labels 82 | batch_token_ids, batch_segment_ids, batch_labels = [], [], [] 83 | 84 | 85 | def evaluate(data, recognize): 86 | """评测函数 87 | """ 88 | X, Y, Z = 1e-10, 1e-10, 1e-10 89 | for d in tqdm(data): 90 | text = ''.join([i[0] for i in d]) 91 | R = set(recognize(text)) 92 | T = set([tuple(i) for i in d if i[1] != 'O']) 93 | X += len(R & T) 94 | Y += len(R) 95 | Z += len(T) 96 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 97 | return f1, precision, recall 98 | 99 | 100 | def test_evaluate(data, recognize, label2id): 101 | """评测函数 102 | """ 103 | X, Y, Z = 1e-10, 1e-10, 1e-10 104 | xyz_dict = {label: {'X': 0, 'Y': 1e-10, 'Z': 1e-10} for label in list(label2id.keys())} 105 | for d in tqdm(data): 106 | text = ''.join([i[0] for i in d]) 107 | R_p = recognize(text) 108 | R = set(R_p) 109 | T_t = [tuple(i) for i in d if i[1] != 'O'] 110 | T = set(T_t) 111 | X += len(R & T) 112 | Y += len(R) 113 | Z += len(T) 114 | # 按标签统计 115 | for t in [tuple(i) for i in T_t if i[1] != 'O']: 116 | if t in R_p: 117 | R_p.remove(t) 118 | xyz_dict[t[1]]['X'] += 1 # 标签label预测正确TP 119 | xyz_dict[t[1]]['Y'] += 1 # 标签label预测正确TP 120 | xyz_dict[t[1]]['Z'] += 1 # 标签label真实数量TP+FN 121 | for p in R_p: 122 | xyz_dict[p[1]]['Y'] += 1 # 标签label预测伪真FP 123 | label_fpr = { 124 | label: {'f1': 2 * xyz['X'] / (xyz['Y'] + xyz['Z']), 'precision': xyz['X'] / xyz['Y'], 125 | 'recall': xyz['X'] / xyz['Z'], 'total': xyz['X']} 126 | for label, xyz in xyz_dict.items()} 127 | for label, fpr in label_fpr.items(): 128 | logger.info( 129 | '%s: f1: %.5f, precision: %.5f, recall: %.5f total: %d\n' % 130 | (label, fpr['f1'], fpr['precision'], fpr['recall'], fpr['total']) 131 | ) 132 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 133 | return f1, precision, recall 134 | 135 | 136 | class Evaluator(keras.callbacks.Callback): 137 | def __init__(self, model, model_path, CRF, NER, recognize, label2id, valid_data, test_data=None): 138 | self.best_val_f1 = 0 139 | self.model = model 140 | self.model_path = model_path 141 | self.CRF = CRF 142 | self.NER = NER 143 | self.recognize = recognize 144 | self.valid_data = valid_data 145 | self.test_data = test_data 146 | self.label2id = 
label2id 147 | 148 | def on_epoch_end(self, epoch, logs=None): 149 | trans = K.eval(self.CRF.trans) 150 | self.NER.trans = trans 151 | # print(self.NER.trans) 152 | f1, precision, recall = evaluate(self.valid_data, self.recognize) 153 | # 保存最优 154 | if f1 >= self.best_val_f1: 155 | self.best_val_f1 = f1 156 | self.model.save_weights(self.model_path) 157 | print( 158 | 'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f \n' % 159 | (f1, precision, recall, self.best_val_f1) 160 | ) 161 | 162 | def on_train_end(self, logs=None): 163 | if self.test_data: 164 | f1, precision, recall = test_evaluate(self.test_data, self.recognize, self.label2id) 165 | print( 166 | 'all_test: f1: %.5f, precision: %.5f, recall: %.5f\n' % 167 | (f1, precision, recall) 168 | ) 169 | else: 170 | print('Done!') 171 | 172 | 173 | class NamedEntityRecognizer(ViterbiDecoder): 174 | """命名实体识别器 175 | """ 176 | 177 | def __init__(self, trans, tokenizer=None, model=None, id2label=None, starts=None, ends=None): 178 | self.tokenizer = tokenizer 179 | self.model = model 180 | self.id2label = id2label 181 | super().__init__(trans, starts, ends) 182 | 183 | def recognize(self, text): 184 | tokens = self.tokenizer.tokenize(text) 185 | while len(tokens) > 512: 186 | tokens.pop(-2) 187 | mapping = self.tokenizer.rematch(text, tokens) 188 | token_ids = self.tokenizer.tokens_to_ids(tokens) 189 | segment_ids = [0] * len(token_ids) 190 | nodes = self.model.predict([[token_ids], [segment_ids]])[0] 191 | labels = self.decode(nodes) 192 | entities, starting = [], False 193 | for i, label in enumerate(labels): 194 | if label > 0: 195 | if label % 2 == 1: 196 | starting = True 197 | entities.append([[i], self.id2label[(label - 1) // 2]]) # [[B_index],label] 198 | elif starting: 199 | entities[-1][0].append(i) # [[B_index,I_index,...I_index],label] 200 | else: 201 | starting = False 202 | else: 203 | starting = False 204 | 205 | return [(text[mapping[w[0]][0]:mapping[w[-1]][-1] + 1], l) # [ ('string',label),...] 
206 | for w, l in entities] 207 | 208 | def batch_recognize(self, text: [], maxlen=None): 209 | ret = [] 210 | batch_token_ids, batch_segment_ids, batch_token = [], [], [] 211 | for sentence in text: 212 | tokens = self.tokenizer.tokenize(sentence, max_length=maxlen) 213 | while len(tokens) > 512: 214 | tokens.pop(-2) 215 | batch_token.append(tokens) 216 | token_ids, segment_ids = self.tokenizer.encode(sentence, max_length=maxlen) 217 | batch_token_ids.append(token_ids) 218 | batch_segment_ids.append(segment_ids) 219 | batch_token_ids = sequence_padding(batch_token_ids) 220 | batch_segment_ids = sequence_padding(batch_segment_ids) 221 | nodes = self.model.predict([batch_token_ids, batch_segment_ids]) 222 | for index, node in enumerate(nodes): 223 | pre_dict = [] 224 | labels = self.decode(node) 225 | arguments, starting = [], False 226 | for i, label in enumerate(labels): 227 | if label > 0: 228 | if label % 2 == 1: 229 | starting = True 230 | arguments.append([[i], self.id2label[str((label - 1) // 2)]]) 231 | elif starting: 232 | arguments[-1][0].append(i) 233 | else: 234 | starting = False 235 | else: 236 | starting = False 237 | pre_ = [ 238 | (self.tokenizer.decode(batch_token_ids[index, w[0]:w[-1] + 1], batch_token[index][w[0]:w[-1] + 1]), l, 239 | search( 240 | self.tokenizer.decode(batch_token_ids[index, w[0]:w[-1] + 1], batch_token[index][w[0]:w[-1] + 1]), 241 | text[index])) 242 | for w, l in arguments] 243 | 244 | ret.append(pre_) 245 | return ret 246 | 247 | 248 | -------------------------------------------------------------------------------- /utils/triplet_data_process.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8_*_ 2 | # author leewfeng 3 | # 2020/12/12 20:05 4 | import json 5 | 6 | import keras 7 | import numpy as np 8 | from bert4keras.snippets import DataGenerator, sequence_padding 9 | from tqdm import tqdm 10 | 11 | from utils.common_tools import search 12 | 13 | 14 | def data_process(train_data_file_path, valid_data_file_path, max_len, params_path): 15 | train_data = json.load(open(train_data_file_path, encoding='utf-8')) 16 | 17 | if valid_data_file_path: 18 | train_data_ret = train_data 19 | valid_data_ret = json.load(open(valid_data_file_path, encoding='utf-8')) 20 | else: 21 | split = int(len(train_data) * 0.8) 22 | train_data_ret, valid_data_ret = train_data[:split], train_data[split:] 23 | p2s_dict = {} 24 | p2o_dict = {} 25 | predicate = [] 26 | 27 | for content in train_data: 28 | for spo in content.get('new_spo_list'): 29 | s_type = spo.get('s').get('type') 30 | p_key = spo.get('p').get('entity') 31 | o_type = spo.get('o').get('type') 32 | if p_key not in p2s_dict: 33 | p2s_dict[p_key] = s_type 34 | if p_key not in p2o_dict: 35 | p2o_dict[p_key] = o_type 36 | if p_key not in predicate: 37 | predicate.append(p_key) 38 | i2p_dict = {i: key for i, key in enumerate(predicate)} 39 | p2i_dict = {key: i for i, key in enumerate(predicate)} 40 | save_params = {} 41 | save_params['p2s_dict'] = p2s_dict 42 | save_params['i2p_dict'] = i2p_dict 43 | save_params['p2o_dict'] = p2o_dict 44 | save_params['maxlen'] = max_len 45 | save_params['num_classes'] = len(i2p_dict) 46 | # 数据保存 47 | json.dump(save_params, 48 | open(params_path, 'w', encoding='utf-8'), 49 | ensure_ascii=False, indent=4) 50 | return train_data_ret, valid_data_ret, p2s_dict, p2o_dict, i2p_dict, p2i_dict 51 | 52 | 53 | class Data_Generator(DataGenerator): 54 | """数据生成器 55 | """ 56 | 57 | def __init__(self, data, batch_size, tokenizer, p2i_dict, 
maxlen): 58 | super().__init__(data, batch_size=batch_size) 59 | self.tokenizer = tokenizer 60 | self.p2i_dict = p2i_dict 61 | self.maxlen = maxlen 62 | 63 | def __iter__(self, random=False): 64 | batch_token_ids, batch_segment_ids = [], [] 65 | batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [] 66 | for is_end, d in self.sample(random): 67 | token_ids, segment_ids = self.tokenizer.encode(d['text'], max_length=self.maxlen) 68 | # collect triples as {(s_start, s_end): [(o_start, o_end, p_id)]} 69 | spoes = {} 70 | for spo in d['new_spo_list']: 71 | s = spo['s'] 72 | p = spo['p'] 73 | o = spo['o'] 74 | s_token = self.tokenizer.encode(s['entity'])[0][1:-1] 75 | p = self.p2i_dict[p['entity']] 76 | o_token = self.tokenizer.encode(o['entity'])[0][1:-1] 77 | s_idx = search(s_token, token_ids) # start index of the subject in token_ids 78 | o_idx = search(o_token, token_ids) # start index of the object in token_ids 79 | if s_idx != -1 and o_idx != -1: 80 | s = (s_idx, s_idx + len(s_token) - 1) # subject span (start, end) 81 | o = (o_idx, o_idx + len(o_token) - 1, p) # object span (start, end) plus predicate id 82 | if s not in spoes: 83 | spoes[s] = [] 84 | spoes[s].append(o) 85 | if spoes: 86 | # subject labels: two columns mark the start and end positions of every subject 87 | subject_labels = np.zeros((len(token_ids), 2)) 88 | for s in spoes: 89 | subject_labels[s[0], 0] = 1 90 | subject_labels[s[1], 1] = 1 91 | # randomly pick one subject for this sample 92 | start, end = np.array(list(spoes.keys())).T 93 | start = np.random.choice(start) 94 | end = np.random.choice(end[end >= start]) 95 | subject_ids = (start, end) 96 | # object labels for the chosen subject 97 | object_labels = np.zeros((len(token_ids), len(self.p2i_dict), 2)) 98 | for o in spoes.get(subject_ids, []): 99 | object_labels[o[0], o[2], 0] = 1 100 | object_labels[o[1], o[2], 1] = 1 101 | # build the batch 102 | batch_token_ids.append(token_ids) 103 | batch_segment_ids.append(segment_ids) 104 | batch_subject_labels.append(subject_labels) 105 | batch_subject_ids.append(subject_ids) 106 | batch_object_labels.append(object_labels) 107 | if len(batch_token_ids) == self.batch_size or is_end: 108 | batch_token_ids = sequence_padding(batch_token_ids) 109 | batch_segment_ids = sequence_padding(batch_segment_ids) 110 | batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2)) 111 | batch_subject_ids = np.array(batch_subject_ids) 112 | batch_object_labels = sequence_padding(batch_object_labels, 113 | padding=np.zeros((len(self.p2i_dict), 2))) 114 | yield [ 115 | batch_token_ids, batch_segment_ids, 116 | batch_subject_labels, batch_subject_ids, batch_object_labels 117 | 118 | ], None 119 | batch_token_ids, batch_segment_ids = [], [] 120 | batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [] 121 | 122 | 123 | def evaluate(tokenizer, data, predict): 124 | """Evaluation function: computes f1, precision and recall 125 | """ 126 | X, Y, Z = 1e-10, 1e-10, 1e-10 127 | f = open('dev_pred.json', 'w', encoding='utf-8') 128 | pbar = tqdm() 129 | 130 | class SPO(tuple): 131 | """Container class for a triple. 132 | Behaves like a plain tuple, but overrides __hash__ and __eq__ 133 | so that comparing two triples for equivalence is more tolerant. 134 | """ 135 | 136 | def __init__(self, spo): 137 | self.spox = ( 138 | tuple(spo[0]), 139 | spo[1], 140 | tuple(spo[2]), 141 | ) 142 | 143 | def __hash__(self): 144 | return self.spox.__hash__() 145 | 146 | def __eq__(self, spo): 147 | return self.spox == spo.spox 148 | 149 | for d in data: 150 | R = set([SPO(spo) for spo in 151 | [[tokenizer.tokenize(spo_str[0][0]), spo_str[1], tokenizer.tokenize(spo_str[2][0])] for 152 | spo_str 153 | in predict(d['text'])]]) 154 | T = set([SPO(spo) for spo in 155 | [[tokenizer.tokenize(spo_str['s']['entity']), spo_str['p']['entity'], 156
| tokenizer.tokenize(spo_str['o']['entity'])] for spo_str 157 | in d['new_spo_list']]) 158 | X += len(R & T) 159 | Y += len(R) 160 | Z += len(T) 161 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 162 | pbar.update() 163 | pbar.set_description('f1: %.5f, precision: %.5f, recall: %.5f' % 164 | (f1, precision, recall)) 165 | s = json.dumps( 166 | { 167 | 'text': d['text'], 168 | 'spo_list': list(T), 169 | 'spo_list_pred': list(R), 170 | 'new': list(R - T), 171 | 'lack': list(T - R), 172 | }, 173 | ensure_ascii=False, 174 | indent=4) 175 | f.write(s + '\n') 176 | pbar.close() 177 | f.close() 178 | return f1, precision, recall 179 | 180 | 181 | class Evaluator(keras.callbacks.Callback): 182 | """Evaluate on the validation set and save the best model 183 | """ 184 | 185 | def __init__(self, model, model_path, tokenizer, predict, optimizer, valid_data): 186 | self.EMAer = optimizer 187 | self.best_val_f1 = 0. 188 | self.model = model 189 | self.model_path = model_path 190 | self.tokenizer = tokenizer 191 | self.predict = predict 192 | self.valid_data = valid_data 193 | 194 | def on_epoch_end(self, epoch, logs=None): 195 | self.EMAer.apply_ema_weights() 196 | f1, precision, recall = evaluate(self.tokenizer, self.valid_data, self.predict) 197 | if f1 >= self.best_val_f1: 198 | self.best_val_f1 = f1 199 | self.model.save_weights(self.model_path) 200 | self.EMAer.reset_old_weights() 201 | print('f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % 202 | (f1, precision, recall, self.best_val_f1)) 203 | --------------------------------------------------------------------------------
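The files above are plain library code with no usage notes, so here is a brief, non-authoritative sketch of how the keyword cache and trie are meant to be wired together. The keyword list, the sentence and the added/removed tuples are invented purely for illustration; the import paths follow the tree at the top of this dump (utils/dynamic_data_cache/...), and importing keyword_update builds the process-wide cache from entity.csv as keyword_dao.py does on import.

from utils.dynamic_data_cache.trie_tree import get_trie_tree
from utils.dynamic_data_cache.keyword_update import keyword_operate

# Stand-alone trie: build it from any keyword list and scan a sentence.
trie = get_trie_tree(["企业法人", "营业网点"])
print(trie.extract_keyword("请到营业网点为企业法人开立账户"))  # ['营业网点', '企业法人']

# Process-wide cache: operate_code 1 adds keywords, -1 deletes them, 2 reloads from entity.csv.
keyword_operate(1, [("测试机构", "organization")])
keyword_operate(-1, [("测试机构", "organization")])
keyword_operate(2, [])

Likewise, a minimal sketch of the data helpers shown earlier, assuming they live in utils/common_tools.py (the file whose tail opens this section, going by the tree above) and using a placeholder csv path; the csv is expected to carry the 'label' and 'content' columns that data_preprocess reads by default.

from utils.common_tools import data_preprocess, split, save_json, load_json

# Build the label<->id mappings, split the rows 80/20, and persist the mapping.
i2l, l2i, label_unique, data = data_preprocess('corpus/train.csv')
train_data, valid_data = split(data, sep=0.8)
save_json(l2i, 'label2id.json')
print(len(train_data), len(valid_data), load_json('label2id.json'))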