├── .gitignore ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── _config.yml ├── docs ├── Text classification.png ├── cluster_train_seg_samples.png ├── img.png ├── logo.png ├── logo.svg └── wechat.jpeg ├── examples ├── albert_classification_zh_demo.py ├── baidu_extract_2020_train.csv ├── bert_classification_en_demo.py ├── bert_classification_tnews_demo.py ├── bert_classification_zh_demo.py ├── bert_hierarchical_classification_zh_demo.py ├── bert_multilabel_classification_en_demo.py ├── bert_multilabel_classification_zh_demo.py ├── cluster_demo.py ├── fasttext_classification_demo.py ├── lr_classification_demo.py ├── lr_en_classification_demo.py ├── multilabel_jd_comments.csv ├── my_vectorizer_demo.py ├── onnx_predict_demo.py ├── onnx_xlnet_predict_demo.py ├── random_forest_classification_demo.py ├── textcnn_classification_demo.py ├── textrnn_classification_demo.py ├── thucnews_train_10w.txt ├── thucnews_train_1w.txt └── visual_feature_importance.ipynb ├── pytextclassifier ├── __init__.py ├── base_classifier.py ├── bert_classfication_utils.py ├── bert_classification_model.py ├── bert_classifier.py ├── bert_multi_label_classification_model.py ├── classic_classifier.py ├── data_helper.py ├── fasttext_classifier.py ├── stopwords.txt ├── textcluster.py ├── textcnn_classifier.py ├── textrnn_classifier.py ├── time_util.py └── tokenizer.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── test_bert_onnx_bs_qps.py ├── test_bert_onnx_speed.py ├── test_fasttext.py ├── test_lr_classification.py └── test_lr_vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | *.pyc 91 | .idea 92 | .idea/ 93 | *.iml 94 | .vscode 95 | *.bin -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Xu" 5 | given-names: "Ming" 6 | title: "Pytextclassifier: Text classifier toolkit for NLP" 7 | url: "https://github.com/shibing624/pytextclassifier" 8 | data-released: 2021-10-29 9 | version: 1.1.4 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We are happy to accept your contributions to make `pytextclassifier` better and more awesome! To avoid unnecessary work on either 4 | side, please stick to the following process: 5 | 6 | 1. Check if there is already [an issue](https://github.com/shibing624/pytextclassifier/issues) for your concern. 7 | 2. If there is not, open a new one to start a discussion. We hate to close finished PRs! 8 | 3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the 9 | commit guidelines below. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Logo 4 | 5 |
6 |
7 | -----------------
8 |
9 | # PyTextClassifier: Python Text Classifier
10 | [![PyPI version](https://badge.fury.io/py/pytextclassifier.svg)](https://badge.fury.io/py/pytextclassifier)
11 | [![Downloads](https://static.pepy.tech/badge/pytextclassifier)](https://pepy.tech/project/pytextclassifier)
12 | [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
13 | [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
14 | [![python_version](https://img.shields.io/badge/Python-3.5%2B-green.svg)](requirements.txt)
15 | [![GitHub issues](https://img.shields.io/github/issues/shibing624/pytextclassifier.svg)](https://github.com/shibing624/pytextclassifier/issues)
16 | [![Wechat Group](https://img.shields.io/badge/wechat-group-green.svg?logo=wechat)](#Contact)
17 |
18 |
19 | ## Introduction
20 | PyTextClassifier: Python Text Classifier. It can be applied to fields such as sentiment polarity analysis and text risk classification,
21 | and it supports multiple classification and clustering algorithms.
22 |
23 | **pytextclassifier** is a Python open-source toolkit for text classification. The goal is to implement
24 | text analysis algorithms that are ready for use in production environments.
25 |
26 | 文本分类器,提供多种文本分类和聚类算法,支持句子和文档级的文本分类任务,支持二分类、多分类、多标签分类、多层级分类和Kmeans聚类,开箱即用。python3开发。
27 |
28 | **Guide**
29 |
30 | - [Feature](#Feature)
31 | - [Install](#install)
32 | - [Usage](#usage)
33 | - [Dataset](#Dataset)
34 | - [Contact](#Contact)
35 | - [Citation](#Citation)
36 | - [Reference](#reference)
37 |
38 | ## Feature
39 |
40 | **pytextclassifier** is characterized by clear algorithms,
41 | high performance and customizable corpora.
42 |
43 | Functions:
44 | ### Classifier
45 | - [x] LogisticRegression
46 | - [x] Random Forest
47 | - [x] Decision Tree
48 | - [x] K-Nearest Neighbours
49 | - [x] Naive Bayes
50 | - [x] XGBoost
51 | - [x] Support Vector Machine (SVM)
52 | - [x] TextCNN
53 | - [x] TextRNN
54 | - [x] FastText
55 | - [x] BERT
56 |
57 | ### Cluster
58 | - [x] MiniBatchKMeans
59 |
60 | While providing rich functionality, the internal modules of **pytextclassifier** stay loosely coupled, models are loaded lazily, dictionaries are bundled with the package, and the API is easy to use.
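All of the classic models listed above are selected by name through the same `ClassicClassifier` interface. A minimal sketch, assuming the `ClassicClassifier` arguments shown in the Usage section below (see also `examples/random_forest_classification_demo.py`):

```python
from pytextclassifier import ClassicClassifier

# Any of the classic models can be selected by name:
# lr, random_forest, decision_tree, knn, bayes, svm, xgboost
m = ClassicClassifier(output_dir='models/rf-toy', model_name_or_model='random_forest')
m.train([
    ('education', 'Chinese education for TV experiment'),
    ('sports', 'Middle East and Asia boost investment in top level sports'),
])
m.load_model()
print(m.predict(['Abbott government spends $8 million on higher education']))
```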
61 |
62 | ## Install
63 |
64 | - Requirements and Installation
65 |
66 | ```
67 | pip3 install torch # conda install pytorch
68 | pip3 install pytextclassifier
69 | ```
70 |
71 | or
72 |
73 | ```
74 | git clone https://github.com/shibing624/pytextclassifier.git
75 | cd pytextclassifier
76 | python3 setup.py install
77 | ```
78 |
79 |
80 | ## Usage
81 | ### Text Classifier
82 |
83 | ### English Text Classifier
84 |
85 | This covers model training, saving, prediction and evaluation, for example [examples/lr_en_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_en_classification_demo.py):
86 |
87 | ```python
88 | import sys
89 |
90 | sys.path.append('..')
91 | from pytextclassifier import ClassicClassifier
92 |
93 | if __name__ == '__main__':
94 |     m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
95 |     # ClassicClassifier supports model_name: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
96 |     print(m)
97 |     data = [
98 |         ('education', 'Student debt to cost Britain billions within decades'),
99 |         ('education', 'Chinese education for TV experiment'),
100 |         ('sports', 'Middle East and Asia boost investment in top level sports'),
101 |         ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
102 |     ]
103 |     # train and save best model
104 |     m.train(data)
105 |     # load best model from output_dir
106 |     m.load_model()
107 |     predict_label, predict_proba = m.predict([
108 |         'Abbott government spends $8 million on higher education media blitz'])
109 |     print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
110 |
111 |     test_data = [
112 |         ('education', 'Abbott government spends $8 million on higher education media blitz'),
113 |         ('sports', 'Middle East and Asia boost investment in top level sports'),
114 |     ]
115 |     acc_score = m.evaluate_model(test_data)
116 |     print(f'acc_score: {acc_score}')
117 | ```
118 |
119 | output:
120 |
121 | ```
122 | ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
123 | predict_label: ['education'], predict_proba: [0.5378236358492112]
124 | acc_score: 1.0
125 | ```
126 |
127 | ### Chinese Text Classifier(中文文本分类)
128 |
129 | Text classification compatible with Chinese and English corpora.
130 | 131 | example [examples/lr_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_classification_demo.py) 132 | 133 | ```python 134 | import sys 135 | 136 | sys.path.append('..') 137 | from pytextclassifier import ClassicClassifier 138 | 139 | if __name__ == '__main__': 140 | m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr') 141 | # 经典分类方法,支持的模型包括:lr, random_forest, decision_tree, knn, bayes, svm, xgboost 142 | data = [ 143 | ('education', '名师指导托福语法技巧:名词的复数形式'), 144 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 145 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 146 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 147 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 148 | ('sports', '米兰客场8战不败国米10年连胜'), 149 | ] 150 | m.train(data) 151 | print(m) 152 | # load best model from model_dir 153 | m.load_model() 154 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 155 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 156 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 157 | 158 | test_data = [ 159 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 160 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 161 | ] 162 | acc_score = m.evaluate_model(test_data) 163 | print(f'acc_score: {acc_score}') # 1.0 164 | 165 | #### train model with 1w data 166 | print('-' * 42) 167 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr') 168 | data_file = 'thucnews_train_1w.txt' 169 | m.train(data_file) 170 | m.load_model() 171 | predict_label, predict_proba = m.predict( 172 | ['顺义北京苏活88平米起精装房在售', 173 | '美EB-5项目“15日快速移民”将推迟']) 174 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 175 | ``` 176 | 177 | output: 178 | 179 | ``` 180 | ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438) 181 | predict_label: ['education' 'sports'], predict_proba: [0.5, 0.598941806741534] 182 | acc_score: 1.0 183 | ------------------------------------------ 184 | predict_label: ['realty' 'education'], predict_proba: [0.7302956923617372, 0.2565005445322923] 185 | ``` 186 | 187 | ### Visual Feature Importance 188 | 189 | Show feature weights of model, and prediction word weight, for example [examples/visual_feature_importance.ipynb](https://github.com/shibing624/pytextclassifier/blob/master/examples/visual_feature_importance.ipynb) 190 | 191 | ```python 192 | import sys 193 | 194 | sys.path.append('..') 195 | from pytextclassifier import ClassicClassifier 196 | import jieba 197 | 198 | tc = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr') 199 | data = [ 200 | ('education', '名师指导托福语法技巧:名词的复数形式'), 201 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 202 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 203 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 204 | ('sports', '米兰客场8战不败国米10年连胜') 205 | ] 206 | tc.train(data) 207 | import eli5 208 | 209 | infer_data = ['高考指导托福语法技巧国际认可', 210 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'] 211 | eli5.show_weights(tc.model, vec=tc.feature) 212 | seg_infer_data = [' '.join(jieba.lcut(i)) for i in infer_data] 213 | eli5.show_prediction(tc.model, seg_infer_data[0], vec=tc.feature, 214 | target_names=['education', 'sports']) 215 | ``` 216 | 217 | output: 218 | 219 | ![img.png](docs/img.png) 220 | 221 | ### Deep Classification model 222 | 223 | 本项目支持以下深度分类模型:FastText、TextCNN、TextRNN、Bert模型,`import`模型对应的方法来调用: 224 | ```python 225 | from pytextclassifier import FastTextClassifier, TextCNNClassifier, TextRNNClassifier, BertClassifier 226 | ``` 227 | 228 | 下面以FastText模型为示例,其他模型的使用方法类似。 229 | 230 | 
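例如,TextCNN 的调用流程与下文的 FastText 示例基本一致(以下仅为最小示例,`output_dir`、`num_epochs` 等参数为假设,具体以 `examples/textcnn_classification_demo.py` 为准):

```python
import sys

sys.path.append('..')
from pytextclassifier import TextCNNClassifier

if __name__ == '__main__':
    m = TextCNNClassifier(output_dir='models/textcnn-toy')
    data = [
        ('education', '名师指导托福语法技巧:名词的复数形式'),
        ('sports', '米兰客场8战不败保持连胜'),
    ]
    m.train(data, num_epochs=3)  # num_epochs 参数与 FastText 示例保持一致(此处为假设)
    m.load_model()
    print(m.predict(['福建春季公务员考试报名18日截止 2月6日考试']))
```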
### FastText 模型 231 | 232 | 训练和预测`FastText`模型示例[examples/fasttext_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/fasttext_classification_demo.py) 233 | 234 | ```python 235 | import sys 236 | 237 | sys.path.append('..') 238 | from pytextclassifier import FastTextClassifier, load_data 239 | 240 | if __name__ == '__main__': 241 | m = FastTextClassifier(output_dir='models/fasttext-toy') 242 | data = [ 243 | ('education', '名师指导托福语法技巧:名词的复数形式'), 244 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 245 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 246 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 247 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 248 | ('sports', '米兰客场8战不败保持连胜'), 249 | ] 250 | m.train(data, num_epochs=3) 251 | print(m) 252 | # load trained best model 253 | m.load_model() 254 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 255 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 256 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 257 | test_data = [ 258 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 259 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 260 | ] 261 | acc_score = m.evaluate_model(test_data) 262 | print(f'acc_score: {acc_score}') # 1.0 263 | 264 | #### train model with 1w data 265 | print('-' * 42) 266 | data_file = 'thucnews_train_1w.txt' 267 | m = FastTextClassifier(output_dir='models/fasttext') 268 | m.train(data_file, names=('labels', 'text'), num_epochs=3) 269 | # load best trained model from model_dir 270 | m.load_model() 271 | predict_label, predict_proba = m.predict( 272 | ['顺义北京苏活88平米起精装房在售', 273 | '美EB-5项目“15日快速移民”将推迟'] 274 | ) 275 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 276 | x, y, df = load_data(data_file) 277 | test_data = df[:100] 278 | acc_score = m.evaluate_model(test_data) 279 | print(f'acc_score: {acc_score}') 280 | ``` 281 | 282 | ### BERT 类模型 283 | 284 | #### 多分类模型 285 | 训练和预测`BERT`多分类模型,示例[examples/bert_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_classification_zh_demo.py) 286 | 287 | ```python 288 | import sys 289 | 290 | sys.path.append('..') 291 | from pytextclassifier import BertClassifier 292 | 293 | if __name__ == '__main__': 294 | m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2, 295 | model_type='bert', model_name='bert-base-chinese', num_epochs=2) 296 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 297 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 
298 | data = [ 299 | ('education', '名师指导托福语法技巧:名词的复数形式'), 300 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 301 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 302 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 303 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 304 | ('sports', '米兰客场8战不败国米10年连胜'), 305 | ] 306 | m.train(data) 307 | print(m) 308 | # load trained best model from model_dir 309 | m.load_model() 310 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 311 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 312 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 313 | 314 | test_data = [ 315 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 316 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 317 | ] 318 | acc_score = m.evaluate_model(test_data) 319 | print(f'acc_score: {acc_score}') 320 | 321 | # train model with 1w data file and 10 classes 322 | print('-' * 42) 323 | m = BertClassifier(output_dir='models/bert-chinese', num_classes=10, 324 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, 325 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, }) 326 | data_file = 'thucnews_train_1w.txt' 327 | # 如果训练数据超过百万条,建议使用lazy_loading模式,减少内存占用 328 | m.train(data_file, test_size=0, names=('labels', 'text')) 329 | m.load_model() 330 | predict_label, predict_proba = m.predict( 331 | ['顺义北京苏活88平米起精装房在售', 332 | '美EB-5项目“15日快速移民”将推迟', 333 | '恒生AH溢指收平 A股对H股折价1.95%']) 334 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 335 | ``` 336 | PS:如果训练数据超过百万条,建议使用lazy_loading模式,减少内存占用 337 | 338 | #### 多标签分类模型 339 | 分类可以分为多分类和多标签分类。多分类的标签是排他的,而多标签分类的所有标签是不排他的。 340 | 341 | 多标签分类比较直观的理解是,一个样本可以同时拥有几个类别标签, 342 | 比如一首歌的标签可以是流行、轻快,一部电影的标签可以是动作、喜剧、搞笑等,这都是多标签分类的情况。 343 | 344 | 训练和预测`BERT`多标签分类模型,示例[examples/bert_multilabel_classification_zh_demo.py.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_multilabel_classification_zh_demo.py) 345 | 346 | ```python 347 | import sys 348 | import pandas as pd 349 | 350 | sys.path.append('..') 351 | from pytextclassifier import BertClassifier 352 | 353 | 354 | def load_jd_data(file_path): 355 | """ 356 | Load jd data from file. 357 | @param file_path: 358 | format: content,其他,互联互通,产品功耗,滑轮提手,声音,APP操控性,呼吸灯,外观,底座,制热范围,遥控器电池,味道,制热效果,衣物烘干,体积大小 359 | @return: 360 | """ 361 | data = [] 362 | with open(file_path, 'r', encoding='utf-8') as f: 363 | for line in f: 364 | line = line.strip() 365 | if line.startswith('#'): 366 | continue 367 | if not line: 368 | continue 369 | terms = line.split(',') 370 | if len(terms) != 16: 371 | continue 372 | val = [int(i) for i in terms[1:]] 373 | data.append([terms[0], val]) 374 | return data 375 | 376 | 377 | if __name__ == '__main__': 378 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 379 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 380 | m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15, 381 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True) 382 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists. 
383 | train_data = [ 384 | ["一个小时房间仍然没暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], 385 | ["耗电情况:这个没有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 386 | ] 387 | data = load_jd_data('multilabel_jd_comments.csv') 388 | train_data.extend(data) 389 | print(train_data[:5]) 390 | train_df = pd.DataFrame(train_data, columns=["text", "labels"]) 391 | 392 | print(train_df.head()) 393 | m.train(train_df) 394 | print(m) 395 | # Evaluate the model 396 | acc_score = m.evaluate_model(train_df[:20]) 397 | print(f'acc_score: {acc_score}') 398 | 399 | # load trained best model from model_dir 400 | m.load_model() 401 | predict_label, predict_proba = m.predict(['一个小时房间仍然没暖和', '耗电情况:这个没有注意']) 402 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 403 | ``` 404 | 405 | #### 多层级分类模型 406 | **多层级标签分类任务**,如行业分类(一级行业下分二级子行业,再分三级)、产品分类,可以使用多标签分类模型,将多层级标签转换为多标签形式, 407 | 示例[examples/bert_hierarchical_classification_zh_demo.py.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_hierarchical_classification_zh_demo.py) 408 | 409 | 410 | #### ONNX推理加速 411 | 412 | 支持将训练好的模型导出为ONNX格式,以便推理加速,或者在其他环境如C++部署模型调用。 413 | 414 | - GPU环境下导出ONNX模型,用ONNX模型推理,可以获得10倍以上的推理加速,需要安装`onnxruntime-gpu`库:`pip install onnxruntime-gpu` 415 | - CPU环境下导出ONNX模型,用ONNX模型推理,可以获得6倍以上的推理加速,需要安装`onnxruntime`库:`pip install onnxruntime` 416 | 417 | 示例[examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py) 418 | 419 | ```python 420 | import os 421 | import shutil 422 | import sys 423 | import time 424 | 425 | import torch 426 | 427 | sys.path.append('..') 428 | from pytextclassifier import BertClassifier 429 | 430 | m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2, 431 | model_type='bert', model_name='bert-base-chinese', num_epochs=1) 432 | data = [ 433 | ('education', '名师指导托福语法技巧:名词的复数形式'), 434 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 435 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 436 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 437 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 438 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 439 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 440 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 441 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 442 | ('sports', '米兰客场8战不败国米10年连胜1'), 443 | ('sports', '米兰客场8战不败国米10年连胜2'), 444 | ('sports', '米兰客场8战不败国米10年连胜3'), 445 | ('sports', '米兰客场8战不败国米10年连胜4'), 446 | ('sports', '米兰客场8战不败国米10年连胜5'), 447 | ] 448 | m.train(data * 10) 449 | m.load_model() 450 | 451 | samples = ['名师指导托福语法技巧', 452 | '米兰客场8战不败', 453 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100 454 | 455 | start_time = time.time() 456 | predict_label_bert, predict_proba_bert = m.predict(samples) 457 | print(f'predict_label_bert size: {len(predict_label_bert)}') 458 | end_time = time.time() 459 | elapsed_time_bert = end_time - start_time 460 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds') 461 | 462 | # convert to onnx, and load onnx model to predict, speed up 10x 463 | save_onnx_dir = 'models/bert-chinese-v1/onnx' 464 | m.model.convert_to_onnx(save_onnx_dir) 465 | # copy label_vocab.json to save_onnx_dir 466 | if os.path.exists(m.label_vocab_path): 467 | shutil.copy(m.label_vocab_path, save_onnx_dir) 468 | 469 | # Manually delete the model and clear CUDA cache 470 | del m 471 | torch.cuda.empty_cache() 472 | 473 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir, 474 | args={"onnx": True}) 475 | m.load_model() 476 | start_time = time.time() 477 | 
predict_label_bert, predict_proba_bert = m.predict(samples) 478 | print(f'predict_label_bert size: {len(predict_label_bert)}') 479 | end_time = time.time() 480 | elapsed_time_onnx = end_time - start_time 481 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds') 482 | ``` 483 | 484 | #### 推理耗时评测 485 | 评测脚本:[examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/tests/test_bert_onnx_bs_qps.py) 486 | 487 | 488 | ##### GPU (Tesla T4) 489 | 490 | | Model Type | Batch Size | Average QPS | Average Latency (s) | 491 | |--------------------|-------------|-------------|---------------------| 492 | | **Standard BERT** | 1 | 9.67 | 0.1034 | 493 | | | 8 | 34.85 | 0.0287 | 494 | | | 16 | 42.23 | 0.0237 | 495 | | | 32 | 46.79 | 0.0214 | 496 | | | 64 | 48.79 | 0.0205 | 497 | | | 128 | **50.15** | 0.0199 | 498 | | **ONNX Model** | 1 | 121.89 | 0.0082 | 499 | | | 8 | 123.38 | 0.0081 | 500 | | | 16 | 132.26 | 0.0076 | 501 | | | 32 | 128.33 | 0.0078 | 502 | | | 64 | **134.59** | 0.0074 | 503 | | | 128 | 128.94 | 0.0078 | 504 | 505 | ##### CPU (10核 Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz) 506 | 507 | | Model Type | Batch Size | Average QPS | Average Latency (s) | 508 | |--------------------|-------------|-------------|---------------------| 509 | | **Standard BERT** | 1 | 4.87 | 0.2053 | 510 | | | 8 | **9.21** | 0.1086 | 511 | | | 16 | 7.59 | 0.1318 | 512 | | | 32 | 7.48 | 0.1337 | 513 | | | 64 | 7.01 | 0.1426 | 514 | | | 128 | 6.34 | 0.1576 | 515 | | **ONNX Model** | 1 | **65.25** | 0.0153 | 516 | | | 8 | 52.93 | 0.0189 | 517 | | | 16 | 56.99 | 0.0175 | 518 | | | 32 | 55.03 | 0.0182 | 519 | | | 64 | 56.23 | 0.0178 | 520 | | | 128 | 46.22 | 0.0216 | 521 | 522 | 523 | ## Evaluation 524 | 525 | ### Dataset 526 | 527 | 1. THUCNews中文文本数据集(1.56GB):官方[下载地址](http://thuctc.thunlp.org/),抽样了10万条THUCNews中文文本10分类数据集(6MB),地址:[examples/thucnews_train_10w.txt](https://github.com/shibing624/pytextclassifier/blob/master/examples/thucnews_train_10w.txt)。 528 | 2. 
TNEWS今日头条中文新闻(短文本)分类 Short Text Classificaiton for News,该数据集(5.1MB)来自今日头条的新闻版块,共提取了15个类别的新闻,包括旅游,教育,金融,军事等,地址:[tnews_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip) 529 | 530 | ### Evaluation Result 531 | 在THUCNews中文文本10分类数据集(6MB)上评估,模型在测试集(test)评测效果如下: 532 | 533 | | 模型 | acc | 说明 | 534 | |-------------|------------|------------------------------| 535 | | LR | 0.8803 | 逻辑回归Logistics Regression | 536 | | TextCNN | 0.8809 | Kim 2014 经典的CNN文本分类 | 537 | | TextRNN_Att | 0.9022 | BiLSTM+Attention | 538 | | FastText | 0.9177 | bow+bigram+trigram, 效果出奇的好 | 539 | | DPCNN | 0.9125 | 深层金字塔CNN | 540 | | Transformer | 0.8991 | 效果较差 | 541 | | BERT-base | **0.9483** | bert + fc | 542 | | ERNIE | 0.9461 | 比bert略差 | 543 | 544 | 在中文新闻短文本分类数据集TNEWS上评估,模型在开发集(dev)评测效果如下: 545 | 546 | 模型|acc|说明 547 | --|--|-- 548 | BERT-base|**0.5660**|本项目实现 549 | BERT-base|0.5609|CLUE Benchmark Leaderboard结果 [CLUEbenchmark](https://github.com/CLUEbenchmark/CLUE) 550 | 551 | - 以上结果均为分类的准确率(accuracy)结果 552 | - THUCNews数据集评测结果可以基于`examples/thucnews_train_10w.txt`数据用`examples`下的各模型demo复现 553 | - TNEWS数据集评测结果可以下载TNEWS数据集,运行`examples/bert_classification_tnews_demo.py`复现 554 | 555 | ### 命令行调用 556 | 557 | 提供分类模型命令行调用脚本,文件树: 558 | ```bash 559 | pytextclassifier 560 | ├── bert_classifier.py 561 | ├── fasttext_classifier.py 562 | ├── classic_classifier.py 563 | ├── textcnn_classifier.py 564 | └── textrnn_classifier.py 565 | ``` 566 | 567 | 每个文件对应一个模型方法,各模型完全独立,可以直接运行,也方便修改,支持通过`argparse` 修改`--data_path`等参数。 568 | 569 | 直接在终端调用fasttext模型训练: 570 | ```bash 571 | python -m pytextclassifier.fasttext_classifier -h 572 | ``` 573 | 574 | ## Text Cluster 575 | 576 | 577 | Text clustering, for example [examples/cluster_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/cluster_demo.py) 578 | 579 | ```python 580 | import sys 581 | 582 | sys.path.append('..') 583 | from pytextclassifier.textcluster import TextCluster 584 | 585 | if __name__ == '__main__': 586 | m = TextCluster(output_dir='models/cluster-toy', n_clusters=2) 587 | print(m) 588 | data = [ 589 | 'Student debt to cost Britain billions within decades', 590 | 'Chinese education for TV experiment', 591 | 'Abbott government spends $8 million on higher education', 592 | 'Middle East and Asia boost investment in top level sports', 593 | 'Summit Series look launches HBO Canada sports doc series: Mudhar' 594 | ] 595 | m.train(data) 596 | m.load_model() 597 | r = m.predict(['Abbott government spends $8 million on higher education media blitz', 598 | 'Middle East and Asia boost investment in top level sports']) 599 | print(r) 600 | 601 | ########### load chinese train data from 1w data file 602 | from sklearn.feature_extraction.text import TfidfVectorizer 603 | 604 | tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10) 605 | data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1) 606 | feature, labels = tcluster.train(data[:5000]) 607 | tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png') 608 | r = tcluster.predict(data[:30]) 609 | print(r) 610 | ``` 611 | 612 | output: 613 | 614 | ``` 615 | TextCluster instance (MiniBatchKMeans(n_clusters=2, n_init=10), , TfidfVectorizer(ngram_range=(1, 2))) 616 | [1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 9 1 1 8 1 1 9 1] 617 | ``` 618 | clustering plot image: 619 | 620 | ![cluster_image](https://github.com/shibing624/pytextclassifier/blob/master/docs/cluster_train_seg_samples.png) 621 | 
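The `feature` argument of `TextCluster` shown above takes a scikit-learn style vectorizer, so the text representation can be swapped out. A minimal sketch, assuming only the `feature` and `n_clusters` parameters already used in the example above (the character n-gram setting is illustrative; see also `examples/my_vectorizer_demo.py`):

```python
from sklearn.feature_extraction.text import TfidfVectorizer

from pytextclassifier.textcluster import TextCluster

# character n-grams can help with short or noisy texts where word segmentation is unreliable
char_vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 3))
m = TextCluster(output_dir='models/cluster-char', feature=char_vec, n_clusters=2)
m.train([
    'Student debt to cost Britain billions within decades',
    'Chinese education for TV experiment',
    'Middle East and Asia boost investment in top level sports',
])
m.load_model()
print(m.predict(['Summit Series look launches HBO Canada sports doc series']))
```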
622 | 623 | ## Contact 624 | 625 | - Issue(建议):[![GitHub issues](https://img.shields.io/github/issues/shibing624/pytextclassifier.svg)](https://github.com/shibing624/pytextclassifier/issues) 626 | - 邮件我:xuming: xuming624@qq.com 627 | - 微信我:加我*微信号:xuming624*, 进Python-NLP交流群,备注:*姓名-公司名-NLP* 628 | 629 | 630 | 631 | ## Citation 632 | 633 | 如果你在研究中使用了pytextclassifier,请按如下格式引用: 634 | 635 | APA: 636 | ```latex 637 | Xu, M. Pytextclassifier: Text classifier toolkit for NLP (Version 1.2.0) [Computer software]. https://github.com/shibing624/pytextclassifier 638 | ``` 639 | 640 | BibTeX: 641 | ```latex 642 | @misc{Pytextclassifier, 643 | title={Pytextclassifier: Text classifier toolkit for NLP}, 644 | author={Xu Ming}, 645 | year={2022}, 646 | howpublished={\url{https://github.com/shibing624/pytextclassifier}}, 647 | } 648 | ``` 649 | 650 | 651 | ## License 652 | 653 | 654 | 授权协议为 [The Apache License 2.0](LICENSE),可免费用做商业用途。请在产品说明中附加**pytextclassifier**的链接和授权协议。 655 | 656 | 657 | ## Contribute 658 | 项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点: 659 | 660 | - 在`tests`添加相应的单元测试 661 | - 使用`python setup.py test`来运行所有单元测试,确保所有单测都是通过的 662 | 663 | 之后即可提交PR。 664 | 665 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /docs/Text classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/Text classification.png -------------------------------------------------------------------------------- /docs/cluster_train_seg_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/cluster_train_seg_samples.png -------------------------------------------------------------------------------- /docs/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/img.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/logo.png -------------------------------------------------------------------------------- /docs/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/wechat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/wechat.jpeg -------------------------------------------------------------------------------- /examples/albert_classification_zh_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import BertClassifier 10 | 11 | if __name__ == '__main__': 12 | m = 
BertClassifier(output_dir='models/albert-chinese-toy', num_classes=2, 13 | model_type='albert', model_name='uer/albert-base-chinese-cluecorpussmall', num_epochs=2) 14 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 15 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 16 | data = [ 17 | ('education', '名师指导托福语法技巧:名词的复数形式'), 18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 20 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 21 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 22 | ('sports', '米兰客场8战不败国米10年连胜'), 23 | ] 24 | m.train(data) 25 | print(m) 26 | # load trained best model from output_dir 27 | m.load_model() 28 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 29 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 30 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 31 | 32 | test_data = [ 33 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 34 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 35 | ] 36 | acc_score = m.evaluate_model(test_data) 37 | print(f'acc_score: {acc_score}') # 1.0 38 | 39 | # train model with 1w data file and 10 classes 40 | print('-' * 42) 41 | m = BertClassifier(output_dir='models/albert-chinese', num_classes=10, 42 | model_type='albert', model_name='uer/albert-base-chinese-cluecorpussmall', num_epochs=2, 43 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, }) 44 | data_file = 'thucnews_train_1w.txt' 45 | # 如果训练数据超过百万条,建议使用lazy_loading模式,减少内存占用 46 | m.train(data_file, test_size=0, names=('labels', 'text')) 47 | m.load_model() 48 | predict_label, predict_proba = m.predict( 49 | ['顺义北京苏活88平米起精装房在售', 50 | '美EB-5项目“15日快速移民”将推迟', 51 | '恒生AH溢指收平 A股对H股折价1.95%']) 52 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 53 | -------------------------------------------------------------------------------- /examples/bert_classification_en_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import BertClassifier 10 | 11 | if __name__ == '__main__': 12 | m = BertClassifier(output_dir='models/bert-english-toy', num_classes=2, 13 | model_type='bert', model_name='bert-base-uncased', num_epochs=2) 14 | data = [ 15 | ('education', 'Student debt to cost Britain billions within decades'), 16 | ('education', 'Chinese education for TV experiment'), 17 | ('sports', 'Middle East and Asia boost investment in top level sports'), 18 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 19 | ] 20 | m.train(data) 21 | print(m) 22 | # load trained best model 23 | m.load_model() 24 | predict_label, predict_proba = m.predict(['Abbott government spends $8 million on higher education media blitz', 25 | 'Middle East and Asia boost investment in top level sports']) 26 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 27 | 28 | test_data = [ 29 | ('education', 'Abbott government spends $8 million on higher education media blitz'), 30 | ('sports', 'Middle East and Asia boost investment in top level sports'), 31 | ] 32 | acc_score = m.evaluate_model(test_data) 33 | print(f'acc_score: {acc_score}') 34 | -------------------------------------------------------------------------------- /examples/bert_classification_tnews_demo.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 字节TNEWS新闻分类标准数据集评估模型 5 | """ 6 | import sys 7 | import json 8 | 9 | sys.path.append('..') 10 | from pytextclassifier import BertClassifier 11 | 12 | def convert_json_to_csv(path): 13 | lst = [] 14 | with open(path, 'r', encoding='utf-8') as f: 15 | for line in f: 16 | d = json.loads(line.strip('\n')) 17 | lst.append((d['label_desc'], d['sentence'])) 18 | return lst 19 | 20 | 21 | if __name__ == '__main__': 22 | # train model with TNEWS data file 23 | train_file = './TNEWS/train.json' 24 | dev_file = './TNEWS/dev.json' 25 | train_data = convert_json_to_csv(train_file)[:None] 26 | dev_data = convert_json_to_csv(dev_file)[:None] 27 | print('train_data head:', train_data[:10]) 28 | 29 | m = BertClassifier(output_dir='models/bert-tnews', num_classes=15, 30 | model_type='bert', model_name='bert-base-chinese', 31 | batch_size=32, num_epochs=5) 32 | m.train(train_data, test_size=0) 33 | m.load_model() 34 | # {"label": "102", "label_desc": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物", "keywords": "江疏影,美少女,经纪人,甜甜圈"} 35 | # {"label": "110", "label_desc": "news_military", "sentence": "以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击", "keywords": "伊朗,圣城军,叙利亚,以色列国防军,以色列"} 36 | predict_label, predict_proba = m.predict( 37 | ['江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物', 38 | '以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击', 39 | ]) 40 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 41 | 42 | score = m.evaluate_model(dev_data) 43 | print(f'score: {score}') # acc: 0.5643 44 | 45 | -------------------------------------------------------------------------------- /examples/bert_classification_zh_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import BertClassifier 10 | 11 | if __name__ == '__main__': 12 | m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2, 13 | model_type='bert', model_name='bert-base-chinese', num_epochs=2) 14 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 15 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 
16 | data = [ 17 | ('education', '名师指导托福语法技巧:名词的复数形式'), 18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 20 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 21 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 22 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 23 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 24 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 25 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 26 | ('sports', '米兰客场8战不败国米10年连胜1'), 27 | ('sports', '米兰客场8战不败国米10年连胜2'), 28 | ('sports', '米兰客场8战不败国米10年连胜3'), 29 | ('sports', '米兰客场8战不败国米10年连胜4'), 30 | ('sports', '米兰客场8战不败国米10年连胜5'), 31 | ] 32 | m.train(data) 33 | print(m) 34 | # load trained best model from output_dir 35 | m.load_model() 36 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 37 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 38 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 39 | 40 | test_data = [ 41 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 42 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 43 | ] 44 | acc_score = m.evaluate_model(test_data) 45 | print(f'acc_score: {acc_score}') # 1.0 46 | 47 | # train model with 1w data file and 10 classes 48 | print('-' * 42) 49 | m = BertClassifier(output_dir='models/bert-chinese', num_classes=10, 50 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, 51 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, }) 52 | data_file = 'thucnews_train_1w.txt' 53 | # 如果训练数据超过百万条,建议使用lazy_loading模式,减少内存占用 54 | m.train(data_file, test_size=0, names=('labels', 'text')) 55 | m.load_model() 56 | predict_label, predict_proba = m.predict( 57 | ['顺义北京苏活88平米起精装房在售', 58 | '美EB-5项目“15日快速移民”将推迟', 59 | '恒生AH溢指收平 A股对H股折价1.95%']) 60 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 61 | -------------------------------------------------------------------------------- /examples/bert_hierarchical_classification_zh_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | import pandas as pd 8 | 9 | sys.path.append('..') 10 | from pytextclassifier import BertClassifier 11 | 12 | 13 | def load_baidu_data(file_path): 14 | """ 15 | Load baidu data from file. 16 | @param file_path: 17 | format: content,labels 18 | @return: 19 | """ 20 | data = [] 21 | with open(file_path, 'r', encoding='utf-8') as f: 22 | for line in f: 23 | line = line.strip() 24 | if line.startswith('#'): 25 | continue 26 | if not line: 27 | continue 28 | terms = line.split('\t') 29 | if len(terms) != 2: 30 | continue 31 | data.append([terms[0], terms[1]]) 32 | return data 33 | 34 | 35 | if __name__ == '__main__': 36 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 37 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 38 | m = BertClassifier(output_dir='models/hierarchical-bert-zh-model', num_classes=34, 39 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True) 40 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists. 
41 | train_data = [ 42 | ["马国明承认与黄心颖分手,女方被娱乐圈封杀,现如今朋友关系", "人生,人生##分手"], 43 | ["RedmiBook14集显版明天首发:酷睿i5+8G内存3799元", "产品行为,产品行为##发布"], 44 | ] 45 | data = load_baidu_data('baidu_extract_2020_train.csv') 46 | train_data.extend(data) 47 | print(train_data[:5]) 48 | train_df = pd.DataFrame(train_data, columns=["text", "labels"]) 49 | 50 | print(train_df.head()) 51 | m.train(train_df) 52 | print(m) 53 | # Evaluate the model 54 | acc_score = m.evaluate_model(train_df[:20]) 55 | print(f'acc_score: {acc_score}') 56 | 57 | # load trained best model from output_dir 58 | m.load_model() 59 | predict_label, predict_proba = m.predict([ 60 | '马国明承认与黄心颖分手,女方被娱乐圈封杀', 'RedmiBook14集显版明天首发']) 61 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 62 | -------------------------------------------------------------------------------- /examples/bert_multilabel_classification_en_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | import pandas as pd 8 | 9 | sys.path.append('..') 10 | from pytextclassifier import BertClassifier 11 | 12 | if __name__ == '__main__': 13 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 14 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 15 | m = BertClassifier(output_dir='models/multilabel-bert-en-toy', num_classes=6, 16 | model_type='bert', model_name='bert-base-uncased', num_epochs=2, multi_label=True) 17 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists. 18 | train_data = [ 19 | ["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]], 20 | ["This is another example sentence. ", [0, 1, 1, 0, 0, 0]], 21 | ["This is the third example sentence. ", [1, 1, 1, 0, 0, 0]], 22 | ] 23 | train_df = pd.DataFrame(train_data, columns=["text", "labels"]) 24 | 25 | eval_data = [ 26 | ["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]], 27 | ["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]], 28 | ] 29 | eval_df = pd.DataFrame(eval_data, columns=["text", "labels"]) 30 | print(train_df.head()) 31 | # m.train(train_df) 32 | print(m) 33 | # Evaluate the model 34 | acc_score = m.evaluate_model(eval_df) 35 | print(f'acc_score: {acc_score}') 36 | 37 | # load trained best model from output_dir 38 | m.load_model() 39 | predict_label, predict_proba = m.predict(['some new sentence']) 40 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 41 | -------------------------------------------------------------------------------- /examples/bert_multilabel_classification_zh_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | import pandas as pd 8 | 9 | sys.path.append('..') 10 | from pytextclassifier import BertClassifier 11 | 12 | 13 | def load_jd_data(file_path): 14 | """ 15 | Load jd data from file. 
16 | @param file_path: 17 | format: content,其他,互联互通,产品功耗,滑轮提手,声音,APP操控性,呼吸灯,外观,底座,制热范围,遥控器电池,味道,制热效果,衣物烘干,体积大小 18 | @return: 19 | """ 20 | data = [] 21 | with open(file_path, 'r', encoding='utf-8') as f: 22 | for line in f: 23 | line = line.strip() 24 | if line.startswith('#'): 25 | continue 26 | if not line: 27 | continue 28 | terms = line.split(',') 29 | if len(terms) != 16: 30 | continue 31 | val = [int(i) for i in terms[1:]] 32 | data.append([terms[0], val]) 33 | return data 34 | 35 | 36 | if __name__ == '__main__': 37 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet' 38 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 39 | m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15, 40 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True) 41 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists. 42 | train_data = [ 43 | ["一个小时房间仍然没暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], 44 | ["耗电情况:这个没有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 45 | ] 46 | data = load_jd_data('multilabel_jd_comments.csv') 47 | train_data.extend(data) 48 | print(train_data[:5]) 49 | train_df = pd.DataFrame(train_data, columns=["text", "labels"]) 50 | 51 | print(train_df.head()) 52 | m.train(train_df) 53 | print(m) 54 | # Evaluate the model 55 | acc_score = m.evaluate_model(train_df[:20]) 56 | print(f'acc_score: {acc_score}') 57 | 58 | # load trained best model from output_dir 59 | m.load_model() 60 | predict_label, predict_proba = m.predict(['一个小时房间仍然没暖和', '耗电情况:这个没有注意']) 61 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 62 | -------------------------------------------------------------------------------- /examples/cluster_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier.textcluster import TextCluster 10 | 11 | if __name__ == '__main__': 12 | m = TextCluster(output_dir='models/cluster-toy', n_clusters=2) 13 | print(m) 14 | data = [ 15 | 'Student debt to cost Britain billions within decades', 16 | 'Chinese education for TV experiment', 17 | 'Abbott government spends $8 million on higher education', 18 | 'Middle East and Asia boost investment in top level sports', 19 | 'Summit Series look launches HBO Canada sports doc series: Mudhar' 20 | ] 21 | m.train(data) 22 | m.load_model() 23 | r = m.predict(['Abbott government spends $8 million on higher education media blitz', 24 | 'Middle East and Asia boost investment in top level sports']) 25 | print(r) 26 | 27 | ########### load chinese train data from 1w data file 28 | from sklearn.feature_extraction.text import TfidfVectorizer 29 | 30 | tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10) 31 | data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1) 32 | feature, labels = tcluster.train(data[:5000]) 33 | tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png') 34 | r = tcluster.predict(data[:30]) 35 | print(r) 36 | -------------------------------------------------------------------------------- /examples/fasttext_classification_demo.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import FastTextClassifier, load_data 10 | 11 | if __name__ == '__main__': 12 | m = FastTextClassifier(output_dir='models/fasttext-toy', enable_ngram=False) 13 | data = [ 14 | ('education', '名师指导托福语法技巧:名词的复数形式'), 15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 19 | ('sports', '米兰客场8战不败保持连胜'), 20 | ] 21 | m.train(data, num_epochs=3) 22 | print(m) 23 | # load trained best model 24 | m.load_model() 25 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 26 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 27 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 28 | test_data = [ 29 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 30 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 31 | ] 32 | acc_score = m.evaluate_model(test_data) 33 | print(f'acc_score: {acc_score}') # 1.0 34 | 35 | #### train model with 1w data 36 | print('-' * 42) 37 | data_file = 'thucnews_train_1w.txt' 38 | m = FastTextClassifier(output_dir='models/fasttext') 39 | m.train(data_file, names=('labels', 'text'), num_epochs=3) 40 | # load best trained model from output_dir 41 | m.load_model() 42 | predict_label, predict_proba = m.predict( 43 | ['顺义北京苏活88平米起精装房在售', 44 | '美EB-5项目“15日快速移民”将推迟'] 45 | ) 46 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 47 | x, y, df = load_data(data_file) 48 | test_data = df[:100] 49 | acc_score = m.evaluate_model(test_data) 50 | print(f'acc_score: {acc_score}') 51 | -------------------------------------------------------------------------------- /examples/lr_classification_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import ClassicClassifier 10 | 11 | if __name__ == '__main__': 12 | m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr') 13 | # 经典分类方法,支持的模型包括:lr, random_forest, decision_tree, knn, bayes, svm, xgboost 14 | data = [ 15 | ('education', '名师指导托福语法技巧:名词的复数形式'), 16 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 17 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 18 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 19 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 20 | ('sports', '米兰客场8战不败国米10年连胜'), 21 | ] 22 | m.train(data) 23 | print(m) 24 | # load best model from output_dir 25 | m.load_model() 26 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 29 | 30 | test_data = [ 31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 33 | ] 34 | acc_score = m.evaluate_model(test_data) 35 | print(f'acc_score: {acc_score}') # 1.0 36 | 37 | # train model with 1w data 38 | print('-' * 42) 39 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr') 40 | data_file = 'thucnews_train_1w.txt' 41 | m.train(data_file) 42 | m.load_model() 43 | predict_label, predict_proba = m.predict( 44 | ['顺义北京苏活88平米起精装房在售', 45 | '美EB-5项目“15日快速移民”将推迟']) 46 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 47 
| -------------------------------------------------------------------------------- /examples/lr_en_classification_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import ClassicClassifier 10 | from loguru import logger 11 | 12 | logger.remove() # Remove default log handler 13 | logger.add(sys.stderr, level="INFO") # 设置log级别 14 | 15 | if __name__ == '__main__': 16 | m = ClassicClassifier(output_dir='models/lr-english-toy', model_name_or_model='lr') 17 | # 经典分类方法,支持的模型包括:lr, random_forest, decision_tree, knn, bayes, svm, xgboost 18 | print(m) 19 | data = [ 20 | ('education', 'Student debt to cost Britain billions within decades'), 21 | ('education', 'Chinese education for TV experiment'), 22 | ('sports', 'Middle East and Asia boost investment in top level sports'), 23 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 24 | ] 25 | # train and save best model 26 | m.train(data) 27 | # load best model from output_dir 28 | m.load_model() 29 | predict_label, predict_proba = m.predict([ 30 | 'Abbott government spends $8 million on higher education media blitz']) 31 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 32 | 33 | test_data = [ 34 | ('education', 'Abbott government spends $8 million on higher education media blitz'), 35 | ('sports', 'Middle East and Asia boost investment in top level sports'), 36 | ] 37 | acc_score = m.evaluate_model(test_data) 38 | print(f'acc_score: {acc_score}') 39 | -------------------------------------------------------------------------------- /examples/my_vectorizer_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import ClassicClassifier 10 | from sklearn.feature_extraction.text import CountVectorizer 11 | 12 | if __name__ == '__main__': 13 | vec = CountVectorizer(ngram_range=(1, 3)) 14 | m = ClassicClassifier(output_dir='models/lr-vec', model_name_or_model='lr', feature_name_or_feature=vec) 15 | 16 | data = [ 17 | ('education', '名师指导托福语法技巧:名词的复数形式'), 18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 20 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 21 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 22 | ('sports', '米兰客场8战不败国米10年连胜'), 23 | ] 24 | m.train(data) 25 | m.load_model() 26 | predict_label, predict_label_prob = m.predict(['福建春季公务员考试报名18日截止 2月6日考试']) 27 | print(predict_label, predict_label_prob) 28 | print('classes_: ', m.model.classes_) # the classes ordered as prob 29 | 30 | test_data = [ 31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 33 | ] 34 | acc_score = m.evaluate_model(test_data) 35 | print(acc_score) # 1.0 36 | -------------------------------------------------------------------------------- /examples/onnx_predict_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | 7 | import os 8 | import shutil 9 | import sys 10 | import time 11 | 12 | import torch 13 | 14 | sys.path.append('..') 15 | from pytextclassifier import BertClassifier 16 | 17 | if __name__ == '__main__': 18 | m = 
BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2, 19 | model_type='bert', model_name='bert-base-chinese', num_epochs=1) 20 | data = [ 21 | ('education', '名师指导托福语法技巧:名词的复数形式'), 22 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 23 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 24 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 25 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 26 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 27 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 28 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 29 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 30 | ('sports', '米兰客场8战不败国米10年连胜1'), 31 | ('sports', '米兰客场8战不败国米10年连胜2'), 32 | ('sports', '米兰客场8战不败国米10年连胜3'), 33 | ('sports', '米兰客场8战不败国米10年连胜4'), 34 | ('sports', '米兰客场8战不败国米10年连胜5'), 35 | ] 36 | # m.train(data * 10) 37 | m.load_model() 38 | 39 | samples = ['名师指导托福语法技巧', 40 | '米兰客场8战不败', 41 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100 42 | 43 | start_time = time.time() 44 | predict_label_bert, predict_proba_bert = m.predict(samples) 45 | print(f'predict_label_bert size: {len(predict_label_bert)}') 46 | end_time = time.time() 47 | elapsed_time_bert = end_time - start_time 48 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds') 49 | 50 | # convert to onnx, and load onnx model to predict, speed up 10x 51 | save_onnx_dir = 'models/bert-chinese-v1/onnx' 52 | m.model.convert_to_onnx(save_onnx_dir) 53 | # copy label_vocab.json to save_onnx_dir 54 | if os.path.exists(m.label_vocab_path): 55 | shutil.copy(m.label_vocab_path, save_onnx_dir) 56 | 57 | # Manually delete the model and clear CUDA cache 58 | del m 59 | torch.cuda.empty_cache() 60 | 61 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir, 62 | args={"onnx": True}) 63 | m.load_model() 64 | start_time = time.time() 65 | predict_label_bert, predict_proba_bert = m.predict(samples) 66 | print(f'predict_label_bert size: {len(predict_label_bert)}') 67 | end_time = time.time() 68 | elapsed_time_onnx = end_time - start_time 69 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds') 70 | -------------------------------------------------------------------------------- /examples/onnx_xlnet_predict_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | 7 | import os 8 | import shutil 9 | import sys 10 | import time 11 | 12 | import torch 13 | 14 | sys.path.append('..') 15 | from pytextclassifier import BertClassifier 16 | 17 | if __name__ == '__main__': 18 | m = BertClassifier(output_dir='models/xlnet-chinese-v1', num_classes=2, 19 | model_type='xlnet', model_name='hfl/chinese-xlnet-base', num_epochs=1) 20 | data = [ 21 | ('education', '名师指导托福语法技巧:名词的复数形式'), 22 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 23 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 24 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 25 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 26 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 27 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 28 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 29 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 30 | ('sports', '米兰客场8战不败国米10年连胜1'), 31 | ('sports', '米兰客场8战不败国米10年连胜2'), 32 | ('sports', '米兰客场8战不败国米10年连胜3'), 33 | ('sports', '米兰客场8战不败国米10年连胜4'), 34 | ('sports', '米兰客场8战不败国米10年连胜5'), 35 | ] 36 | m.train(data * 1) 37 | m.load_model() 38 | 39 | samples = ['名师指导托福语法技巧', 40 | '米兰客场8战不败', 41 | '恒生AH溢指收平 A股对H股折价1.95%'] * 10 42 | 43 | start_time = time.time() 44 | predict_label_bert, predict_proba_bert = m.predict(samples) 45 | 
print(f'predict_label_bert size: {len(predict_label_bert)}') 46 | end_time = time.time() 47 | elapsed_time_bert = end_time - start_time 48 | print(f'Standard xlnet model prediction time: {elapsed_time_bert} seconds') 49 | 50 | # convert to onnx, and load onnx model to predict, speed up 10x 51 | save_onnx_dir = 'models/xlnet-chinese-v1/onnx' 52 | m.model.convert_to_onnx(save_onnx_dir) 53 | # copy label_vocab.json to save_onnx_dir 54 | if os.path.exists(m.label_vocab_path): 55 | shutil.copy(m.label_vocab_path, save_onnx_dir) 56 | 57 | # Manually delete the model and clear CUDA cache 58 | del m 59 | torch.cuda.empty_cache() 60 | 61 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='xlnet', model_name=save_onnx_dir, 62 | args={"onnx": True}) 63 | m.load_model() 64 | start_time = time.time() 65 | predict_label_bert, predict_proba_bert = m.predict(samples) 66 | print(f'predict_label_bert size: {len(predict_label_bert)}') 67 | end_time = time.time() 68 | elapsed_time_onnx = end_time - start_time 69 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds') 70 | -------------------------------------------------------------------------------- /examples/random_forest_classification_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import ClassicClassifier 10 | 11 | if __name__ == '__main__': 12 | m = ClassicClassifier(output_dir='models/random_forest-toy', model_name_or_model='random_forest') 13 | # 经典分类方法,支持的模型包括:lr, random_forest, decision_tree, knn, bayes, svm, xgboost 14 | print(m) 15 | data = [ 16 | ('education', 'Student debt to cost Britain billions within decades'), 17 | ('education', 'Chinese education for TV experiment'), 18 | ('sports', 'Middle East and Asia boost investment in top level sports'), 19 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 20 | ] 21 | m.train(data) 22 | m.load_model() 23 | predict_label, predict_proba = m.predict([ 24 | 'Abbott government spends $8 million on higher education media blitz']) 25 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 26 | 27 | test_data = [ 28 | ('education', 'Abbott government spends $8 million on higher education media blitz'), 29 | ('sports', 'Middle East and Asia boost investment in top level sports'), 30 | ] 31 | acc_score = m.evaluate_model(test_data) 32 | print(f'acc_score: {acc_score}') # 1.0 33 | -------------------------------------------------------------------------------- /examples/textcnn_classification_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import TextCNNClassifier 10 | 11 | if __name__ == '__main__': 12 | m = TextCNNClassifier(output_dir='models/textcnn-toy') 13 | data = [ 14 | ('education', '名师指导托福语法技巧:名词的复数形式'), 15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 19 | ('sports', '米兰客场8战不败国米10年连胜') 20 | ] 21 | # train and save best model 22 | m.train(data, num_epochs=3, evaluate_during_training_steps=1) 23 | print(m) 24 | # load best model from output_dir 25 | m.load_model() 26 | predict_label, 
predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 29 | 30 | test_data = [ 31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 33 | ] 34 | acc_score = m.evaluate_model(test_data) 35 | print(f'acc_score: {acc_score}') # 1.0 36 | -------------------------------------------------------------------------------- /examples/textrnn_classification_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import sys 7 | 8 | sys.path.append('..') 9 | from pytextclassifier import TextRNNClassifier 10 | 11 | if __name__ == '__main__': 12 | m = TextRNNClassifier(output_dir='models/textrnn-toy') 13 | data = [ 14 | ('education', '名师指导托福语法技巧:名词的复数形式'), 15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 19 | ('sports', '米兰客场8战不败国米10年连胜') 20 | ] 21 | # train and save best model 22 | m.train(data, num_epochs=3, evaluate_during_training_steps=1) 23 | print(m) 24 | # load best model from output_dir 25 | m.load_model() 26 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试', 27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']) 28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') 29 | 30 | test_data = [ 31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'), 32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'), 33 | ] 34 | acc_score = m.evaluate_model(test_data) 35 | print(f'acc_score: {acc_score}') 36 | -------------------------------------------------------------------------------- /pytextclassifier/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | __version__ = '1.4.0' 7 | 8 | from pytextclassifier.classic_classifier import ClassicClassifier 9 | from pytextclassifier.fasttext_classifier import FastTextClassifier 10 | from pytextclassifier.textcnn_classifier import TextCNNClassifier 11 | from pytextclassifier.textrnn_classifier import TextRNNClassifier 12 | from pytextclassifier.bert_classifier import BertClassifier 13 | from pytextclassifier.base_classifier import load_data 14 | from pytextclassifier.textcluster import TextCluster 15 | from pytextclassifier.bert_classification_model import BertClassificationModel, BertClassificationArgs 16 | -------------------------------------------------------------------------------- /pytextclassifier/base_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os 7 | 8 | import pandas as pd 9 | from loguru import logger 10 | 11 | 12 | def load_data(data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t', 13 | labels_sep=',', is_train=False): 14 | """ 15 | Encoding data_list text 16 | @param data_list_or_path: list of (label, text), eg: [(label, text), (label, text) ...] 
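        a tab-separated file path or a pandas DataFrame with 'labels' and 'text' columns is also accepted (see the isinstance checks below), e.g. X, y, df = load_data('thucnews_train_1w.txt')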
17 | @param header: read_csv header 18 | @param names: read_csv names 19 | @param delimiter: read_csv sep 20 | @param labels_sep: multi label split 21 | @param is_train: is train data 22 | @return: X, y, data_df 23 | """ 24 | if isinstance(data_list_or_path, list): 25 | data_df = pd.DataFrame(data_list_or_path, columns=names) 26 | elif isinstance(data_list_or_path, str) and os.path.exists(data_list_or_path): 27 | data_df = pd.read_csv(data_list_or_path, header=header, delimiter=delimiter, names=names) 28 | elif isinstance(data_list_or_path, pd.DataFrame): 29 | data_df = data_list_or_path 30 | else: 31 | raise TypeError('should be list or file path, eg: [(label, text), ... ]') 32 | X, y = data_df['text'], data_df['labels'] 33 | labels = set() 34 | if y.size: 35 | for label in y.tolist(): 36 | if isinstance(label, str): 37 | labels.update(label.split(labels_sep)) 38 | elif isinstance(label, list): 39 | labels.update(range(len(label))) 40 | else: 41 | labels.add(label) 42 | num_classes = len(labels) 43 | labels = sorted(list(labels)) 44 | logger.debug(f'loaded data list, X size: {len(X)}, y size: {len(y)}') 45 | if is_train: 46 | logger.debug('num_classes: %d, labels: %s' % (num_classes, labels)) 47 | assert len(X) == len(y) 48 | 49 | return X, y, data_df 50 | 51 | 52 | class ClassifierABC: 53 | """ 54 | Abstract class for classifier 55 | """ 56 | 57 | def train(self, data_list_or_path, model_dir: str, **kwargs): 58 | raise NotImplementedError('train method not implemented.') 59 | 60 | def predict(self, sentences: list): 61 | raise NotImplementedError('predict method not implemented.') 62 | 63 | def evaluate_model(self, **kwargs): 64 | raise NotImplementedError('evaluate_model method not implemented.') 65 | 66 | def evaluate(self, **kwargs): 67 | raise NotImplementedError('evaluate method not implemented.') 68 | 69 | def load_model(self): 70 | raise NotImplementedError('load method not implemented.') 71 | 72 | def save_model(self): 73 | raise NotImplementedError('save method not implemented.') 74 | -------------------------------------------------------------------------------- /pytextclassifier/bert_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: BERT Classifier, support 'bert', 'albert', 'roberta', 'xlnet' model 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import sys 10 | 11 | import numpy as np 12 | import torch 13 | from loguru import logger 14 | from sklearn.model_selection import train_test_split 15 | 16 | sys.path.append('..') 17 | from pytextclassifier.base_classifier import ClassifierABC, load_data 18 | from pytextclassifier.data_helper import set_seed 19 | from pytextclassifier.bert_classification_model import BertClassificationModel, BertClassificationArgs 20 | 21 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 22 | device = 'cuda' if torch.cuda.is_available() else ( 23 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu') 24 | default_use_cuda = torch.cuda.is_available() 25 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 26 | 27 | 28 | class BertClassifier(ClassifierABC): 29 | def __init__( 30 | self, 31 | num_classes, 32 | output_dir="outputs", 33 | model_type='bert', 34 | model_name='bert-base-chinese', 35 | num_epochs=3, 36 | batch_size=8, 37 | max_seq_length=256, 38 | multi_label=False, 39 | labels_sep=',', 40 | use_cuda=None, 41 | args=None, 42 | ): 43 | 44 | """ 45 | Init 
classification model 46 | @param output_dir: output model dir 47 | @param model_type: support 'bert', 'albert', 'roberta', 'xlnet' 48 | @param model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ... 49 | @param num_classes: number of label classes 50 | @param num_epochs: train epochs 51 | @param batch_size: train batch size 52 | @param max_seq_length: max seq length, trim longer sentence. 53 | @param multi_label: bool, multi label or single label 54 | @param labels_sep: label separator, default is ',' 55 | @param use_cuda: bool, use cuda or not 56 | @param args: dict, train args 57 | """ 58 | default_args = { 59 | "output_dir": output_dir, 60 | "max_seq_length": max_seq_length, 61 | "num_train_epochs": num_epochs, 62 | "train_batch_size": batch_size, 63 | "best_model_dir": os.path.join(output_dir, 'best_model'), 64 | "labels_sep": labels_sep, 65 | } 66 | train_args = BertClassificationArgs() 67 | if args and isinstance(args, dict): 68 | train_args.update_from_dict(args) 69 | train_args.update_from_dict(default_args) 70 | if use_cuda is None: 71 | use_cuda = default_use_cuda 72 | 73 | self.model = BertClassificationModel( 74 | model_type=model_type, 75 | model_name=model_name, 76 | num_labels=num_classes, 77 | multi_label=multi_label, 78 | args=train_args, 79 | use_cuda=use_cuda, 80 | ) 81 | self.output_dir = output_dir 82 | self.model_type = model_type 83 | self.num_classes = num_classes 84 | self.batch_size = batch_size 85 | self.max_seq_length = max_seq_length 86 | self.num_epochs = num_epochs 87 | self.train_args = train_args 88 | self.use_cuda = use_cuda 89 | self.multi_label = multi_label 90 | self.labels_sep = labels_sep 91 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json') 92 | self.is_trained = False 93 | 94 | def __str__(self): 95 | return f'BertClassifier instance ({self.model})' 96 | 97 | def train( 98 | self, 99 | data_list_or_path, 100 | dev_data_list_or_path=None, 101 | header=None, 102 | names=('labels', 'text'), 103 | delimiter='\t', 104 | test_size=0.1, 105 | ): 106 | """ 107 | Train model with data_list_or_path and save model to output_dir 108 | @param data_list_or_path: 109 | @param dev_data_list_or_path: 110 | @param header: 111 | @param names: 112 | @param delimiter: 113 | @param test_size: 114 | @return: 115 | """ 116 | logger.debug('train model ...') 117 | SEED = 1 118 | set_seed(SEED) 119 | # load data 120 | X, y, data_df = load_data( 121 | data_list_or_path, 122 | header=header, 123 | names=names, 124 | delimiter=delimiter, 125 | labels_sep=self.labels_sep, 126 | is_train=True 127 | ) 128 | if self.output_dir: 129 | os.makedirs(self.output_dir, exist_ok=True) 130 | labels_map = self.build_labels_map(y, self.label_vocab_path, self.multi_label, self.labels_sep) 131 | labels_list = sorted(list(labels_map.keys())) 132 | if dev_data_list_or_path is not None: 133 | dev_X, dev_y, dev_df = load_data( 134 | dev_data_list_or_path, 135 | header=header, names=names, 136 | delimiter=delimiter, 137 | labels_sep=self.labels_sep, 138 | is_train=False 139 | ) 140 | train_data = data_df 141 | dev_data = dev_df 142 | else: 143 | if test_size > 0: 144 | train_data, dev_data = train_test_split(data_df, test_size=test_size, random_state=SEED) 145 | else: 146 | train_data = data_df 147 | dev_data = None 148 | logger.debug(f"train_data size: {len(train_data)}") 149 | logger.debug(f'train_data sample:\n{train_data[:3]}') 150 | if dev_data is not None and dev_data.size: 151 | logger.debug(f"dev_data size: {len(dev_data)}") 152 
| logger.debug(f'dev_data sample:\n{dev_data[:3]}') 153 | # train model 154 | if self.train_args.lazy_loading: 155 | train_data = data_list_or_path 156 | dev_data = dev_data_list_or_path 157 | if dev_data is not None: 158 | self.model.train_model( 159 | train_data, eval_df=dev_data, 160 | args={'labels_map': labels_map, 'labels_list': labels_list} 161 | ) 162 | else: 163 | self.model.train_model( 164 | train_data, 165 | args={'labels_map': labels_map, 'labels_list': labels_list} 166 | ) 167 | self.is_trained = True 168 | logger.debug('train model done') 169 | 170 | def predict(self, sentences: list): 171 | """ 172 | Predict labels and label probability for sentences. 173 | @param sentences: list, input text list, eg: [text1, text2, ...] 174 | @return: predict_labels, predict_probs 175 | """ 176 | if not self.is_trained: 177 | raise ValueError('model not trained.') 178 | # predict 179 | predictions, raw_outputs = self.model.predict(sentences) 180 | if self.multi_label: 181 | return predictions, raw_outputs 182 | else: 183 | # predict probability 184 | predict_probs = [1 - np.exp(-np.max(raw_output)) for raw_output, prediction in 185 | zip(raw_outputs, predictions)] 186 | return predictions, predict_probs 187 | 188 | def evaluate_model( 189 | self, 190 | data_list_or_path, 191 | header=None, 192 | names=('labels', 'text'), 193 | delimiter='\t', 194 | **kwargs 195 | ): 196 | """ 197 | Evaluate model with data_list_or_path 198 | @param data_list_or_path: 199 | @param header: 200 | @param names: 201 | @param delimiter: 202 | @param kwargs: 203 | @return: 204 | """ 205 | if self.train_args.lazy_loading: 206 | eval_df = data_list_or_path 207 | else: 208 | X_test, y_test, eval_df = load_data( 209 | data_list_or_path, 210 | header=header, 211 | names=names, 212 | delimiter=delimiter, 213 | labels_sep=self.labels_sep 214 | ) 215 | if not self.is_trained: 216 | self.load_model() 217 | result, model_outputs, wrong_predictions = self.model.eval_model( 218 | eval_df, 219 | output_dir=self.output_dir, 220 | **kwargs, 221 | ) 222 | return result 223 | 224 | def load_model(self): 225 | """ 226 | Load model from output_dir 227 | @return: 228 | """ 229 | model_config_file = os.path.join(self.output_dir, 'config.json') 230 | if os.path.exists(model_config_file): 231 | labels_map = json.load(open(self.label_vocab_path, 'r', encoding='utf-8')) 232 | labels_list = sorted(list(labels_map.keys())) 233 | num_classes = len(labels_map) 234 | assert num_classes == self.num_classes, f'num_classes not match, {num_classes} != {self.num_classes}' 235 | self.train_args.update_from_dict({'labels_map': labels_map, 'labels_list': labels_list}) 236 | self.model = BertClassificationModel( 237 | model_type=self.model_type, 238 | model_name=self.output_dir, 239 | num_labels=self.num_classes, 240 | multi_label=self.multi_label, 241 | args=self.train_args, 242 | use_cuda=self.use_cuda, 243 | ) 244 | self.is_trained = True 245 | else: 246 | logger.error(f'{model_config_file} not exists.') 247 | self.is_trained = False 248 | return self.is_trained 249 | 250 | @staticmethod 251 | def build_labels_map(y, label_vocab_path, multi_label=False, labels_sep=','): 252 | """ 253 | Build labels map 254 | @param y: 255 | @param label_vocab_path: 256 | @param multi_label: 257 | @param labels_sep: 258 | @return: 259 | """ 260 | if multi_label: 261 | labels = set() 262 | for label in y.tolist(): 263 | if isinstance(label, str): 264 | labels.update(label.split(labels_sep)) 265 | elif isinstance(label, list): 266 | labels.update(range(len(label))) 
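                        # multi-hot list labels (e.g. [1, 0, 1, 0]) are mapped by position, so the class ids are simply 0..len(label)-1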
267 | else: 268 | labels.add(label) 269 | else: 270 | labels = set(y.tolist()) 271 | labels = sorted(list(labels)) 272 | id_label_map = {id: v for id, v in enumerate(labels)} 273 | label_id_map = {v: k for k, v in id_label_map.items()} 274 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 275 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}") 276 | return label_id_map 277 | 278 | 279 | if __name__ == '__main__': 280 | parser = argparse.ArgumentParser(description='Bert Text Classification') 281 | parser.add_argument('--pretrain_model_type', default='bert', type=str, 282 | help='pretrained huggingface model type') 283 | parser.add_argument('--pretrain_model_name', default='bert-base-chinese', type=str, 284 | help='pretrained huggingface model name') 285 | parser.add_argument('--output_dir', default='models/bert', type=str, help='save model dir') 286 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'), 287 | type=str, help='sample data file path') 288 | parser.add_argument('--num_classes', default=10, type=int, help='number of label classes') 289 | parser.add_argument('--num_epochs', default=3, type=int, help='train epochs') 290 | parser.add_argument('--batch_size', default=64, type=int, help='train batch size') 291 | parser.add_argument('--max_seq_length', default=128, type=int, help='max seq length, trim longer sentence.') 292 | args = parser.parse_args() 293 | print(args) 294 | # create model 295 | m = BertClassifier( 296 | num_classes=args.num_classes, 297 | output_dir=args.output_dir, 298 | model_type=args.pretrain_model_type, 299 | model_name=args.pretrain_model_name, 300 | num_epochs=args.num_epochs, 301 | batch_size=args.batch_size, 302 | max_seq_length=args.max_seq_length, 303 | multi_label=False, 304 | ) 305 | # train model 306 | m.train(data_list_or_path=args.data_path) 307 | # load trained model and predict 308 | m.load_model() 309 | print('best model loaded from file, and predict') 310 | X, y, _ = load_data(args.data_path) 311 | X = X[:5] 312 | y = y[:5] 313 | predict_labels, predict_probs = m.predict(X) 314 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y): 315 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth) 316 | -------------------------------------------------------------------------------- /pytextclassifier/bert_multi_label_classification_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn import BCEWithLogitsLoss 9 | from transformers import ( 10 | BertModel, 11 | BertPreTrainedModel, 12 | FlaubertModel, 13 | LongformerModel, 14 | RemBertModel, 15 | RemBertPreTrainedModel, 16 | XLMModel, 17 | XLMPreTrainedModel, 18 | XLNetModel, 19 | XLNetPreTrainedModel, 20 | ) 21 | from transformers.modeling_utils import SequenceSummary 22 | from transformers.models.albert.modeling_albert import ( 23 | AlbertModel, 24 | AlbertPreTrainedModel, 25 | ) 26 | from transformers.models.longformer.modeling_longformer import ( 27 | LongformerClassificationHead, 28 | LongformerPreTrainedModel, 29 | ) 30 | 31 | try: 32 | import wandb 33 | 34 | wandb_available = True 35 | except ImportError: 36 | wandb_available = False 37 | 38 | 39 | class BertForMultiLabelSequenceClassification(BertPreTrainedModel): 40 
| """ 41 | Bert model adapted for multi-label sequence classification 42 | """ 43 | 44 | def __init__(self, config, pos_weight=None): 45 | super(BertForMultiLabelSequenceClassification, self).__init__(config) 46 | self.num_labels = config.num_labels 47 | self.bert = BertModel(config) 48 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 49 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 50 | self.pos_weight = pos_weight 51 | 52 | self.init_weights() 53 | 54 | def forward( 55 | self, 56 | input_ids, 57 | attention_mask=None, 58 | token_type_ids=None, 59 | position_ids=None, 60 | head_mask=None, 61 | labels=None, 62 | ): 63 | outputs = self.bert( 64 | input_ids, 65 | attention_mask=attention_mask, 66 | token_type_ids=token_type_ids, 67 | position_ids=position_ids, 68 | head_mask=head_mask, 69 | ) 70 | 71 | pooled_output = outputs[1] 72 | 73 | pooled_output = self.dropout(pooled_output) 74 | logits = self.classifier(pooled_output) 75 | 76 | outputs = (logits,) + outputs[ 77 | 2: 78 | ] # add hidden states and attention if they are here 79 | 80 | if labels is not None: 81 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 82 | labels = labels.float() 83 | loss = loss_fct( 84 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 85 | ) 86 | outputs = (loss,) + outputs 87 | 88 | return outputs # (loss), logits, (hidden_states), (attentions) 89 | 90 | 91 | class BertForHierarchicalMultiLabelSequenceClassification(BertPreTrainedModel): 92 | def __init__(self, config, pos_weight=None): 93 | super(BertForHierarchicalMultiLabelSequenceClassification, self).__init__(config) 94 | config.update({"mlp_size": 1024}) 95 | self.num_labels = config.num_labels 96 | self.bert = BertModel(config) 97 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 98 | self.mlp = nn.Sequential( 99 | nn.Linear(config.hidden_size + config.num_labels, config.mlp_size), 100 | nn.ReLU(), 101 | nn.Linear(config.mlp_size, config.mlp_size), 102 | nn.ReLU(), 103 | ) 104 | self.classifier = nn.Linear(config.mlp_size, config.num_labels) 105 | self.pos_weight = pos_weight 106 | 107 | self.init_weights() 108 | 109 | def forward( 110 | self, 111 | input_ids, 112 | attention_mask=None, 113 | token_type_ids=None, 114 | position_ids=None, 115 | head_mask=None, 116 | labels=None, 117 | parent_labels=None, 118 | ): 119 | outputs = self.bert( 120 | input_ids, 121 | attention_mask=attention_mask, 122 | token_type_ids=token_type_ids, 123 | position_ids=position_ids, 124 | head_mask=head_mask, 125 | ) 126 | pooled_output = outputs[1] 127 | pooled_output = self.dropout(pooled_output) 128 | concat_output = torch.cat((pooled_output, parent_labels), dim=1) 129 | mlp_output = self.mlp(concat_output) 130 | logits = self.classifier(mlp_output) 131 | 132 | outputs = (logits,) + outputs[2:] 133 | 134 | if labels is not None: 135 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 136 | labels = labels.float() 137 | loss = loss_fct( 138 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 139 | ) 140 | outputs = (loss,) + outputs 141 | 142 | return outputs # (loss), logits, (hidden_states), (attentions) 143 | 144 | 145 | class RemBertForMultiLabelSequenceClassification(RemBertPreTrainedModel): 146 | """ 147 | Bert model adapted for multi-label sequence classification 148 | """ 149 | 150 | def __init__(self, config, pos_weight=None): 151 | super(RemBertForMultiLabelSequenceClassification, self).__init__(config) 152 | self.num_labels = config.num_labels 153 | self.rembert = 
RemBertModel(config) 154 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 155 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 156 | self.pos_weight = pos_weight 157 | 158 | self.init_weights() 159 | 160 | def forward( 161 | self, 162 | input_ids, 163 | attention_mask=None, 164 | token_type_ids=None, 165 | position_ids=None, 166 | head_mask=None, 167 | labels=None, 168 | ): 169 | outputs = self.rembert( 170 | input_ids, 171 | attention_mask=attention_mask, 172 | token_type_ids=token_type_ids, 173 | position_ids=position_ids, 174 | head_mask=head_mask, 175 | ) 176 | 177 | pooled_output = outputs[1] 178 | 179 | pooled_output = self.dropout(pooled_output) 180 | logits = self.classifier(pooled_output) 181 | 182 | outputs = (logits,) + outputs[ 183 | 2: 184 | ] # add hidden states and attention if they are here 185 | 186 | if labels is not None: 187 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 188 | labels = labels.float() 189 | loss = loss_fct( 190 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 191 | ) 192 | outputs = (loss,) + outputs 193 | 194 | return outputs # (loss), logits, (hidden_states), (attentions) 195 | 196 | 197 | class XLNetForMultiLabelSequenceClassification(XLNetPreTrainedModel): 198 | """ 199 | XLNet model adapted for multi-label sequence classification 200 | """ 201 | 202 | def __init__(self, config, pos_weight=None): 203 | super(XLNetForMultiLabelSequenceClassification, self).__init__(config) 204 | self.num_labels = config.num_labels 205 | self.pos_weight = pos_weight 206 | 207 | self.transformer = XLNetModel(config) 208 | self.sequence_summary = SequenceSummary(config) 209 | self.logits_proj = nn.Linear(config.d_model, config.num_labels) 210 | 211 | self.init_weights() 212 | 213 | def forward( 214 | self, 215 | input_ids=None, 216 | attention_mask=None, 217 | mems=None, 218 | perm_mask=None, 219 | target_mapping=None, 220 | token_type_ids=None, 221 | input_mask=None, 222 | head_mask=None, 223 | inputs_embeds=None, 224 | labels=None, 225 | ): 226 | transformer_outputs = self.transformer( 227 | input_ids, 228 | attention_mask=attention_mask, 229 | mems=mems, 230 | perm_mask=perm_mask, 231 | target_mapping=target_mapping, 232 | token_type_ids=token_type_ids, 233 | input_mask=input_mask, 234 | head_mask=head_mask, 235 | ) 236 | output = transformer_outputs[0] 237 | 238 | output = self.sequence_summary(output) 239 | logits = self.logits_proj(output) 240 | 241 | outputs = (logits,) + transformer_outputs[ 242 | 1: 243 | ] # Keep mems, hidden states, attentions if there are in it 244 | 245 | if labels is not None: 246 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 247 | labels = labels.float() 248 | loss = loss_fct( 249 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 250 | ) 251 | outputs = (loss,) + outputs 252 | 253 | return outputs 254 | 255 | 256 | class XLMForMultiLabelSequenceClassification(XLMPreTrainedModel): 257 | """ 258 | XLM model adapted for multi-label sequence classification 259 | """ 260 | 261 | def __init__(self, config, pos_weight=None): 262 | super(XLMForMultiLabelSequenceClassification, self).__init__(config) 263 | self.num_labels = config.num_labels 264 | self.pos_weight = pos_weight 265 | 266 | self.transformer = XLMModel(config) 267 | self.sequence_summary = SequenceSummary(config) 268 | 269 | self.init_weights() 270 | 271 | def forward( 272 | self, 273 | input_ids=None, 274 | attention_mask=None, 275 | langs=None, 276 | token_type_ids=None, 277 | 
position_ids=None, 278 | lengths=None, 279 | cache=None, 280 | head_mask=None, 281 | inputs_embeds=None, 282 | labels=None, 283 | ): 284 | transformer_outputs = self.transformer( 285 | input_ids, 286 | attention_mask=attention_mask, 287 | langs=langs, 288 | token_type_ids=token_type_ids, 289 | position_ids=position_ids, 290 | lengths=lengths, 291 | cache=cache, 292 | head_mask=head_mask, 293 | ) 294 | 295 | output = transformer_outputs[0] 296 | logits = self.sequence_summary(output) 297 | 298 | outputs = (logits,) + transformer_outputs[ 299 | 1: 300 | ] # Keep new_mems and attention/hidden states if they are here 301 | 302 | if labels is not None: 303 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 304 | labels = labels.float() 305 | loss = loss_fct( 306 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 307 | ) 308 | outputs = (loss,) + outputs 309 | 310 | return outputs 311 | 312 | 313 | class AlbertForMultiLabelSequenceClassification(AlbertPreTrainedModel): 314 | """ 315 | Alber model adapted for multi-label sequence classification 316 | """ 317 | 318 | def __init__(self, config, pos_weight=None): 319 | super(AlbertForMultiLabelSequenceClassification, self).__init__(config) 320 | 321 | self.num_labels = config.num_labels 322 | self.pos_weight = pos_weight 323 | 324 | self.albert = AlbertModel(config) 325 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 326 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 327 | 328 | self.init_weights() 329 | 330 | def forward( 331 | self, 332 | input_ids=None, 333 | attention_mask=None, 334 | token_type_ids=None, 335 | position_ids=None, 336 | head_mask=None, 337 | inputs_embeds=None, 338 | labels=None, 339 | ): 340 | outputs = self.albert( 341 | input_ids=input_ids, 342 | attention_mask=attention_mask, 343 | token_type_ids=token_type_ids, 344 | position_ids=position_ids, 345 | head_mask=head_mask, 346 | inputs_embeds=inputs_embeds, 347 | ) 348 | 349 | pooled_output = outputs[1] 350 | 351 | pooled_output = self.dropout(pooled_output) 352 | logits = self.classifier(pooled_output) 353 | 354 | outputs = (logits,) + outputs[ 355 | 2: 356 | ] # add hidden states and attention if they are here 357 | 358 | if labels is not None: 359 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 360 | labels = labels.float() 361 | loss = loss_fct( 362 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 363 | ) 364 | outputs = (loss,) + outputs 365 | 366 | return outputs # (loss), logits, (hidden_states), (attentions) 367 | 368 | 369 | class FlaubertForMultiLabelSequenceClassification(FlaubertModel): 370 | """ 371 | Flaubert model adapted for multi-label sequence classification 372 | """ 373 | 374 | def __init__(self, config, pos_weight=None): 375 | super(FlaubertForMultiLabelSequenceClassification, self).__init__(config) 376 | self.num_labels = config.num_labels 377 | self.pos_weight = pos_weight 378 | 379 | self.transformer = FlaubertModel(config) 380 | self.sequence_summary = SequenceSummary(config) 381 | 382 | self.init_weights() 383 | 384 | def forward( 385 | self, 386 | input_ids=None, 387 | attention_mask=None, 388 | langs=None, 389 | token_type_ids=None, 390 | position_ids=None, 391 | lengths=None, 392 | cache=None, 393 | head_mask=None, 394 | inputs_embeds=None, 395 | labels=None, 396 | ): 397 | transformer_outputs = self.transformer( 398 | input_ids, 399 | attention_mask=attention_mask, 400 | langs=langs, 401 | token_type_ids=token_type_ids, 402 | position_ids=position_ids, 403 | 
lengths=lengths, 404 | cache=cache, 405 | head_mask=head_mask, 406 | ) 407 | 408 | output = transformer_outputs[0] 409 | logits = self.sequence_summary(output) 410 | 411 | outputs = (logits,) + transformer_outputs[ 412 | 1: 413 | ] # Keep new_mems and attention/hidden states if they are here 414 | 415 | if labels is not None: 416 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 417 | labels = labels.float() 418 | loss = loss_fct( 419 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 420 | ) 421 | outputs = (loss,) + outputs 422 | 423 | return outputs 424 | 425 | 426 | class LongformerForMultiLabelSequenceClassification(LongformerPreTrainedModel): 427 | """ 428 | Longformer model adapted for multilabel sequence classification. 429 | """ 430 | 431 | def __init__(self, config, pos_weight=None): 432 | super(LongformerForMultiLabelSequenceClassification, self).__init__(config) 433 | self.num_labels = config.num_labels 434 | self.pos_weight = pos_weight 435 | 436 | self.longformer = LongformerModel(config) 437 | self.classifier = LongformerClassificationHead(config) 438 | 439 | self.init_weights() 440 | 441 | def forward( 442 | self, 443 | input_ids=None, 444 | attention_mask=None, 445 | global_attention_mask=None, 446 | token_type_ids=None, 447 | position_ids=None, 448 | inputs_embeds=None, 449 | labels=None, 450 | ): 451 | if global_attention_mask is None: 452 | global_attention_mask = torch.zeros_like(input_ids) 453 | # global attention on cls token 454 | global_attention_mask[:, 0] = 1 455 | 456 | outputs = self.longformer( 457 | input_ids, 458 | attention_mask=attention_mask, 459 | global_attention_mask=global_attention_mask, 460 | token_type_ids=token_type_ids, 461 | position_ids=position_ids, 462 | ) 463 | sequence_output = outputs[0] 464 | logits = self.classifier(sequence_output) 465 | 466 | outputs = (logits,) + outputs[2:] 467 | if labels is not None: 468 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) 469 | labels = labels.float() 470 | loss = loss_fct( 471 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) 472 | ) 473 | outputs = (loss,) + outputs 474 | 475 | return outputs 476 | -------------------------------------------------------------------------------- /pytextclassifier/classic_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: Classic Classifier, support Naive Bayes, Logistic Regression, Random Forest, SVM, XGBoost 5 | and so on sklearn classification model 6 | """ 7 | import argparse 8 | import os 9 | import sys 10 | import pickle 11 | import numpy as np 12 | from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer 13 | from sklearn import metrics 14 | from sklearn.ensemble import RandomForestClassifier 15 | from sklearn.linear_model import LogisticRegression 16 | from sklearn.naive_bayes import MultinomialNB 17 | from sklearn.neighbors import KNeighborsClassifier 18 | from sklearn.svm import SVC 19 | from sklearn.tree import DecisionTreeClassifier 20 | from sklearn.model_selection import train_test_split 21 | from loguru import logger 22 | 23 | sys.path.append('..') 24 | from pytextclassifier.base_classifier import ClassifierABC, load_data 25 | from pytextclassifier.tokenizer import Tokenizer 26 | 27 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 28 | default_stopwords_path = os.path.join(pwd_path, 'stopwords.txt') 29 | 30 | 31 | class 
ClassicClassifier(ClassifierABC): 32 | def __init__( 33 | self, 34 | output_dir="outputs", 35 | model_name_or_model='lr', 36 | feature_name_or_feature='tfidf', 37 | stopwords_path=default_stopwords_path, 38 | tokenizer=None 39 | ): 40 | """ 41 | Classic machine learning classifier, supports lr, random_forest, decision_tree, knn, bayes, svm, xgboost 42 | @param output_dir: model save dir 43 | @param model_name_or_model: model name string (one of the list above) or a fitted sklearn-style estimator 44 | @param feature_name_or_feature: 'tfidf' or 'count', or a vectorizer object with fit_transform 45 | @param stopwords_path: stopwords file path 46 | @param tokenizer: word tokenizer, defaults to jieba segmentation 47 | """ 48 | self.output_dir = output_dir 49 | if isinstance(model_name_or_model, str): 50 | model_name = model_name_or_model.lower() 51 | if model_name not in ['lr', 'random_forest', 'decision_tree', 'knn', 'bayes', 'xgboost', 'svm']: 52 | raise ValueError('model_name not found.') 53 | logger.debug(f'model_name: {model_name}') 54 | self.model = self.get_model(model_name) 55 | elif hasattr(model_name_or_model, 'fit'): 56 | self.model = model_name_or_model 57 | else: 58 | raise ValueError('model_name_or_model set error.') 59 | if isinstance(feature_name_or_feature, str): 60 | feature_name = feature_name_or_feature.lower() 61 | if feature_name not in ['tfidf', 'count']: 62 | raise ValueError('feature_name not found.') 63 | logger.debug(f'feature_name: {feature_name}') 64 | if feature_name == 'tfidf': 65 | self.feature = TfidfVectorizer(ngram_range=(1, 2)) 66 | else: 67 | self.feature = CountVectorizer(ngram_range=(1, 2)) 68 | elif hasattr(feature_name_or_feature, 'fit_transform'): 69 | self.feature = feature_name_or_feature 70 | else: 71 | raise ValueError('feature_name_or_feature set error.') 72 | self.is_trained = False 73 | self.stopwords = set(self.load_list(stopwords_path)) if stopwords_path and os.path.exists( 74 | stopwords_path) else set() 75 | self.tokenizer = tokenizer if tokenizer else Tokenizer() 76 | 77 | def __str__(self): 78 | return f'ClassicClassifier instance ({self.model}, stopwords size: {len(self.stopwords)})' 79 | 80 | @staticmethod 81 | def get_model(model_type): 82 | if model_type in ["lr", "logistic_regression"]: 83 | model = LogisticRegression(solver='lbfgs', fit_intercept=False) # fast, moderate accuracy. val mean acc: 0.91 84 | elif model_type == "random_forest": 85 | model = RandomForestClassifier(n_estimators=300) # decent speed, moderate accuracy. val mean acc: 0.93125 86 | elif model_type == "decision_tree": 87 | model = DecisionTreeClassifier() # fast, low accuracy. val mean acc: 0.62 88 | elif model_type == "knn": 89 | model = KNeighborsClassifier() # moderate speed, low accuracy. val mean acc: 0.675 90 | elif model_type == "bayes": 91 | model = MultinomialNB(alpha=0.1, fit_prior=False) # fast, low accuracy. val mean acc: 0.62 92 | elif model_type == "xgboost": 93 | try: 94 | from xgboost import XGBClassifier 95 | except ImportError: 96 | raise ImportError('xgboost not installed, please install it with "pip install xgboost"') 97 | model = XGBClassifier() # slow, high accuracy. val mean acc: 0.95 98 | elif model_type == "svm": 99 | model = SVC(kernel='linear', probability=True) # slow, high accuracy. val mean acc: 0.945 100 | else: 101 | raise ValueError('model type set error.') 102 | return model 103 | 104 | @staticmethod 105 | def load_list(path): 106 | return [word for word in open(path, 'r', encoding='utf-8').read().split()] 107 | 108 | def tokenize_sentences(self, sentences): 109 | """ 110 | Tokenize input text 111 | :param sentences: list of text, eg: [text1, text2, ...]
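        each sentence is segmented, stopwords are dropped, and the remaining tokens are re-joined with spaces, e.g. 'token1 token2 token3' (exact tokens depend on the tokenizer and stopword list)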
112 | :return: X_tokens 113 | """ 114 | X_tokens = [' '.join([w for w in self.tokenizer.tokenize(line) if w not in self.stopwords]) for line in 115 | sentences] 116 | return X_tokens 117 | 118 | def load_pkl(self, pkl_path): 119 | """ 120 | 加载词典文件 121 | :param pkl_path: 122 | :return: 123 | """ 124 | with open(pkl_path, 'rb') as f: 125 | result = pickle.load(f) 126 | return result 127 | 128 | def save_pkl(self, vocab, pkl_path, overwrite=True): 129 | """ 130 | 存储文件 131 | :param pkl_path: 132 | :param overwrite: 133 | :return: 134 | """ 135 | if pkl_path and os.path.exists(pkl_path) and not overwrite: 136 | return 137 | if pkl_path: 138 | with open(pkl_path, 'wb') as f: 139 | pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL) # python3 140 | 141 | def train(self, data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1): 142 | """ 143 | Train model with data_list_or_path and save model to output_dir 144 | @param data_list_or_path: 145 | @param header: 146 | @param names: 147 | @param delimiter: 148 | @param test_size: 149 | @return: 150 | """ 151 | # load data 152 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True) 153 | # split validation set 154 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=1) 155 | # train model 156 | logger.debug(f"X_train size: {len(X_train)}, X_test size: {len(X_test)}") 157 | assert len(X_train) == len(y_train) 158 | logger.debug(f'X_train sample:\n{X_train[:3]}\ny_train sample:\n{y_train[:3]}') 159 | logger.debug(f'num_classes:{len(set(y))}') 160 | # tokenize text 161 | X_train_tokens = self.tokenize_sentences(X_train) 162 | logger.debug(f'X_train_tokens sample:\n{X_train_tokens[:3]}') 163 | X_train_feat = self.feature.fit_transform(X_train_tokens) 164 | # fit 165 | self.model.fit(X_train_feat, y_train) 166 | self.is_trained = True 167 | # evaluate 168 | test_acc = self.evaluate(X_test, y_test) 169 | logger.debug(f'evaluate, X size: {len(X_test)}, y size: {len(y_test)}, acc: {test_acc}') 170 | # save model 171 | self.save_model() 172 | return test_acc 173 | 174 | def predict(self, sentences: list): 175 | """ 176 | Predict labels and label probability for sentences. 177 | @param sentences: list, input text list, eg: [text1, text2, ...] 178 | @return: predict_label, predict_prob 179 | """ 180 | if not self.is_trained: 181 | raise ValueError('model not trained.') 182 | # tokenize text 183 | X_tokens = self.tokenize_sentences(sentences) 184 | # transform 185 | X_feat = self.feature.transform(X_tokens) 186 | predict_labels = self.model.predict(X_feat) 187 | probs = self.model.predict_proba(X_feat) 188 | predict_probs = [prob[np.where(self.model.classes_ == label)][0] for label, prob in zip(predict_labels, probs)] 189 | return predict_labels, predict_probs 190 | 191 | def evaluate_model(self, data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t'): 192 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter) 193 | return self.evaluate(X_test, y_test) 194 | 195 | def evaluate(self, X_test, y_test): 196 | """ 197 | Evaluate model. 
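        accuracy is computed with sklearn.metrics.accuracy_score on this classifier's own predictions for X_test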
198 | @param X_test: 199 | @param y_test: 200 | @return: accuracy score 201 | """ 202 | if not self.is_trained: 203 | raise ValueError('model not trained.') 204 | # evaluate the model 205 | y_pred, _ = self.predict(X_test) 206 | acc_score = metrics.accuracy_score(y_test, y_pred) 207 | return acc_score 208 | 209 | def load_model(self): 210 | """ 211 | Load model from output_dir 212 | @return: 213 | """ 214 | model_path = os.path.join(self.output_dir, 'classifier_model.pkl') 215 | if os.path.exists(model_path): 216 | self.model = self.load_pkl(model_path) 217 | feature_path = os.path.join(self.output_dir, 'classifier_feature.pkl') 218 | self.feature = self.load_pkl(feature_path) 219 | logger.info(f'Loaded model: {model_path}.') 220 | self.is_trained = True 221 | else: 222 | logger.error(f'{model_path} not exists.') 223 | self.is_trained = False 224 | return self.is_trained 225 | 226 | def save_model(self): 227 | """ 228 | Save model to output_dir 229 | @return: 230 | """ 231 | if self.output_dir: 232 | os.makedirs(self.output_dir, exist_ok=True) 233 | if self.is_trained: 234 | feature_path = os.path.join(self.output_dir, 'classifier_feature.pkl') 235 | self.save_pkl(self.feature, feature_path) 236 | model_path = os.path.join(self.output_dir, 'classifier_model.pkl') 237 | self.save_pkl(self.model, model_path) 238 | logger.info(f'Saved model: {model_path}, feature_path: {feature_path}') 239 | else: 240 | logger.error('model is not trained, please train model first') 241 | return self.model, self.feature 242 | 243 | 244 | if __name__ == '__main__': 245 | parser = argparse.ArgumentParser(description='Text Classification') 246 | parser.add_argument('--model_name', default='lr', type=str, help='model name') 247 | parser.add_argument('--output_dir', default='models/lr', type=str, help='saved model dir') 248 | parser.add_argument('--feature_name', default='tfidf', type=str, help='feature name') 249 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'), 250 | type=str, help='sample data file path') 251 | args = parser.parse_args() 252 | print(args) 253 | # create model 254 | m = ClassicClassifier(output_dir=args.output_dir, model_name_or_model=args.model_name, 255 | feature_name_or_feature=args.feature_name) 256 | # train model 257 | m.train(args.data_path) 258 | # load best trained model and predict 259 | m.load_model() 260 | X, y, _ = load_data(args.data_path) 261 | X = X[:5] 262 | y = y[:5] 263 | predict_labels, predict_probs = m.predict(X) 264 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y): 265 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth) 266 | -------------------------------------------------------------------------------- /pytextclassifier/data_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import json 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | from loguru import logger 12 | from tqdm import tqdm 13 | 14 | 15 | def set_seed(seed): 16 | """ 17 | Set seed for random number generators. 
18 | """ 19 | logger.info(f"Set seed for random, numpy and torch: {seed}") 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | 27 | def build_vocab(contents, tokenizer, max_size, min_freq, unk_token, pad_token): 28 | vocab_dic = {} 29 | for line in tqdm(contents): 30 | line = line.strip() 31 | if not line: 32 | continue 33 | content = line.split('\t')[0] 34 | for word in tokenizer(content): 35 | vocab_dic[word] = vocab_dic.get(word, 0) + 1 36 | vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[ 37 | :max_size] 38 | vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)} 39 | vocab_dic.update({unk_token: len(vocab_dic), pad_token: len(vocab_dic) + 1}) 40 | return vocab_dic 41 | 42 | 43 | def load_vocab(vocab_path): 44 | """ 45 | :param vocab_path: 46 | :return: 47 | """ 48 | with open(vocab_path, 'r', encoding='utf-8') as fr: 49 | vocab = json.load(fr) 50 | return vocab 51 | -------------------------------------------------------------------------------- /pytextclassifier/fasttext_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: Fasttext Classifier 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import sys 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from loguru import logger 17 | from sklearn import metrics 18 | from sklearn.model_selection import train_test_split 19 | 20 | sys.path.append('..') 21 | from pytextclassifier.base_classifier import ClassifierABC, load_data 22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab 23 | from pytextclassifier.time_util import get_time_spend 24 | 25 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 26 | device = 'cuda' if torch.cuda.is_available() else ( 27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu') 28 | 29 | 30 | def build_dataset( 31 | tokenizer, X, y, word_vocab_path, label_vocab_path, max_vocab_size=10000, 32 | max_seq_length=128, unk_token='[UNK]', pad_token='[PAD]' 33 | ): 34 | if os.path.exists(word_vocab_path): 35 | word_id_map = load_vocab(word_vocab_path) 36 | else: 37 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1, 38 | unk_token=unk_token, pad_token=pad_token) 39 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 40 | logger.debug('save vocab_path: {}'.format(word_vocab_path)) 41 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}") 42 | 43 | if os.path.exists(label_vocab_path): 44 | label_id_map = load_vocab(label_vocab_path) 45 | else: 46 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))} 47 | label_id_map = {v: k for k, v in id_label_map.items()} 48 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 49 | logger.debug('save label_vocab_path: {}'.format(label_vocab_path)) 50 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}") 51 | 52 | def load_dataset(X, y, max_seq_length=128): 53 | contents = [] 54 | for content, label in zip(X, y): 55 | words_line = [] 56 | token = tokenizer(content) 57 | seq_len = 
len(token) 58 | if max_seq_length: 59 | if len(token) < max_seq_length: 60 | token.extend([pad_token] * (max_seq_length - len(token))) 61 | else: 62 | token = token[:max_seq_length] 63 | seq_len = max_seq_length 64 | # word to id 65 | for word in token: 66 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token))) 67 | label_id = label_id_map.get(label) 68 | contents.append((words_line, label_id, seq_len, None, None)) 69 | return contents 70 | 71 | dataset = load_dataset(X, y, max_seq_length) 72 | return dataset, word_id_map, label_id_map 73 | 74 | 75 | class DatasetIterater: 76 | def __init__(self, dataset, device, batch_size=32, enable_ngram=True, n_gram_vocab=250499, max_seq_length=128): 77 | self.batch_size = batch_size 78 | self.dataset = dataset 79 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1 80 | self.residue = False # 记录batch数量是否为整数 81 | if len(dataset) % self.n_batches != 0: 82 | self.residue = True 83 | self.index = 0 84 | self.device = device 85 | self.enable_ngram = enable_ngram 86 | self.n_gram_vocab = n_gram_vocab 87 | self.max_seq_length = max_seq_length 88 | 89 | def _to_tensor(self, datas): 90 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 91 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 92 | 93 | def biGramHash(sequence, t, buckets): 94 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 95 | return (t1 * 14918087) % buckets 96 | 97 | def triGramHash(sequence, t, buckets): 98 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 99 | t2 = sequence[t - 2] if t - 2 >= 0 else 0 100 | return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets 101 | 102 | # calculate bigram and trigram here fore memory efficiency 103 | bigram = [] 104 | trigram = [] 105 | for _ in datas: 106 | words_line = _[0] 107 | bi = [] 108 | tri = [] 109 | if self.enable_ngram: 110 | buckets = self.n_gram_vocab 111 | # ------ngram------ 112 | for i in range(self.max_seq_length): 113 | bi.append(biGramHash(words_line, i, buckets)) 114 | tri.append(triGramHash(words_line, i, buckets)) 115 | # ----------------- 116 | else: 117 | bi = [0] * self.max_seq_length 118 | tri = [0] * self.max_seq_length 119 | bigram.append(bi) 120 | trigram.append(tri) 121 | bigram = torch.LongTensor(bigram).to(self.device) 122 | trigram = torch.LongTensor(trigram).to(self.device) 123 | 124 | # pad_token前的长度(超过max_seq_length的设为max_seq_length) 125 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 126 | return (x, seq_len, bigram, trigram), y 127 | 128 | def __next__(self): 129 | if self.residue and self.index == self.n_batches: 130 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)] 131 | self.index += 1 132 | batches = self._to_tensor(batches) 133 | return batches 134 | 135 | elif self.index >= self.n_batches: 136 | self.index = 0 137 | raise StopIteration 138 | else: 139 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size] 140 | self.index += 1 141 | batches = self._to_tensor(batches) 142 | return batches 143 | 144 | def __iter__(self): 145 | return self 146 | 147 | def __len__(self): 148 | if self.residue: 149 | return self.n_batches + 1 150 | else: 151 | return self.n_batches 152 | 153 | 154 | def build_iterator( 155 | dataset, 156 | device, 157 | batch_size=32, 158 | enable_ngram=True, 159 | n_gram_vocab=250499, 160 | max_seq_length=128 161 | ): 162 | return DatasetIterater(dataset, device, batch_size, enable_ngram, n_gram_vocab, max_seq_length) 163 | 164 | 165 | class 
FastTextModel(nn.Module): 166 | """Bag of Tricks for Efficient Text Classification""" 167 | 168 | def __init__( 169 | self, 170 | vocab_size, 171 | num_classes, 172 | embed_size=200, 173 | n_gram_vocab=250499, 174 | hidden_size=256, 175 | dropout_rate=0.5 176 | ): 177 | super().__init__() 178 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1) 179 | self.embedding_ngram2 = nn.Embedding(n_gram_vocab, embed_size) 180 | self.embedding_ngram3 = nn.Embedding(n_gram_vocab, embed_size) 181 | self.dropout = nn.Dropout(dropout_rate) 182 | self.fc1 = nn.Linear(embed_size * 3, hidden_size) 183 | self.fc2 = nn.Linear(hidden_size, num_classes) 184 | 185 | def forward(self, x): 186 | out_word = self.embedding(x[0]) 187 | out_bigram = self.embedding_ngram2(x[2]) 188 | out_trigram = self.embedding_ngram3(x[3]) 189 | out = torch.cat((out_word, out_bigram, out_trigram), -1) 190 | 191 | out = out.mean(dim=1) 192 | out = self.dropout(out) 193 | out = self.fc1(out) 194 | out = F.relu(out) 195 | out = self.fc2(out) 196 | return out 197 | 198 | 199 | class FastTextClassifier(ClassifierABC): 200 | def __init__( 201 | self, 202 | output_dir="outputs", 203 | dropout_rate=0.5, batch_size=64, max_seq_length=128, 204 | embed_size=200, hidden_size=256, n_gram_vocab=250499, 205 | max_vocab_size=10000, unk_token='[UNK]', pad_token='[PAD]', 206 | tokenizer=None, 207 | enable_ngram=True, 208 | ): 209 | """ 210 | 初始化 211 | @param output_dir: 模型的保存路径 212 | @param dropout_rate: 随机失活 213 | @param batch_size: mini-batch大小 214 | @param max_seq_length: 每句话处理成的长度(短填长切) 215 | @param embed_size: 字向量维度 216 | @param hidden_size: 隐藏层大小 217 | @param n_gram_vocab: ngram 词表大小 218 | @param max_vocab_size: 词表长度限制 219 | @param unk_token: 未知字 220 | @param pad_token: padding符号 221 | @param tokenizer: 切词器 222 | @param enable_ngram: 是否使用ngram 223 | """ 224 | self.output_dir = output_dir 225 | self.is_trained = False 226 | self.model = None 227 | logger.debug(f'Device: {device}') 228 | self.dropout_rate = dropout_rate 229 | self.batch_size = batch_size 230 | self.max_seq_length = max_seq_length 231 | self.embed_size = embed_size 232 | self.hidden_size = hidden_size 233 | self.n_gram_vocab = n_gram_vocab 234 | self.max_vocab_size = max_vocab_size 235 | self.unk_token = unk_token 236 | self.pad_token = pad_token 237 | self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x] # char-level 238 | self.enable_ngram = enable_ngram 239 | 240 | def __str__(self): 241 | return f'FasttextClassifier instance ({self.model})' 242 | 243 | def train( 244 | self, 245 | data_list_or_path, 246 | header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1, 247 | num_epochs=20, learning_rate=1e-3, 248 | require_improvement=1000, evaluate_during_training_steps=100 249 | ): 250 | """ 251 | Train model with data_list_or_path and save model to output_dir 252 | @param data_list_or_path: 253 | @param header: 254 | @param names: 255 | @param delimiter: 256 | @param test_size: 257 | @param num_epochs: epoch数 258 | @param learning_rate: 学习率 259 | @param require_improvement: 若超过1000batch效果还没提升,则提前结束训练 260 | @param evaluate_during_training_steps: 每隔多少step评估一次模型 261 | @return: 262 | """ 263 | logger.debug('train model...') 264 | SEED = 1 265 | set_seed(SEED) 266 | # load data 267 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True) 268 | del data_df 269 | output_dir = self.output_dir 270 | if output_dir: 271 | os.makedirs(output_dir, exist_ok=True) 272 | word_vocab_path 
= os.path.join(output_dir, 'word_vocab.json')
273 |         label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
274 |         save_model_path = os.path.join(output_dir, 'model.pth')
275 | 
276 |         dataset, self.word_id_map, self.label_id_map = build_dataset(
277 |             self.tokenizer, X, y, word_vocab_path,
278 |             label_vocab_path,
279 |             max_vocab_size=self.max_vocab_size,
280 |             max_seq_length=self.max_seq_length,
281 |             unk_token=self.unk_token,
282 |             pad_token=self.pad_token
283 |         )
284 |         train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
285 |         logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
286 |         logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
287 |         train_iter = build_iterator(train_data, device, self.batch_size, self.enable_ngram,
288 |                                     self.n_gram_vocab, self.max_seq_length)
289 |         dev_iter = build_iterator(dev_data, device, self.batch_size, self.enable_ngram,
290 |                                   self.n_gram_vocab, self.max_seq_length)
291 |         # create model
292 |         vocab_size = len(self.word_id_map)
293 |         num_classes = len(self.label_id_map)
294 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
295 |         self.model = FastTextModel(
296 |             vocab_size, num_classes, self.embed_size, self.n_gram_vocab, self.hidden_size,
297 |             self.dropout_rate
298 |         )
299 |         self.model.to(device)
300 |         # init_network(self.model)
301 |         logger.info(self.model.parameters)
302 |         # train model
303 |         history = self.train_model_from_data_iterator(
304 |             save_model_path, train_iter, dev_iter, num_epochs, learning_rate,
305 |             require_improvement, evaluate_during_training_steps
306 |         )
307 |         self.is_trained = True
308 |         logger.debug('train model done')
309 |         return history
310 | 
311 |     def train_model_from_data_iterator(
312 |             self, save_model_path, train_iter, dev_iter,
313 |             num_epochs=10, learning_rate=1e-3,
314 |             require_improvement=1000, evaluate_during_training_steps=100
315 |     ):
316 |         history = []
317 |         # train
318 |         start_time = time.time()
319 |         optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
320 | 
321 |         total_batch = 0  # number of batches processed so far
322 |         dev_best_loss = 1e10
323 |         last_improve = 0  # batch index of the last dev loss improvement
324 |         flag = False  # whether training has gone too long without improvement
325 |         for epoch in range(num_epochs):
326 |             logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs))
327 |             for i, (trains, labels) in enumerate(train_iter):
328 |                 self.model.train()
329 |                 outputs = self.model(trains)
330 |                 loss = F.cross_entropy(outputs, labels)
331 |                 # compute gradient and do SGD step
332 |                 optimizer.zero_grad()
333 |                 loss.backward()
334 |                 optimizer.step()
335 |                 if total_batch % evaluate_during_training_steps == 0:
336 |                     # report metrics on the training and dev sets
337 |                     y_true = labels.cpu()
338 |                     y_pred = torch.max(outputs, 1)[1].cpu()
339 |                     train_acc = metrics.accuracy_score(y_true, y_pred)
340 |                     if dev_iter is not None:
341 |                         dev_acc, dev_loss = self.evaluate(dev_iter)
342 |                         if dev_loss < dev_best_loss:
343 |                             dev_best_loss = dev_loss
344 |                             torch.save(self.model.state_dict(), save_model_path)
345 |                             logger.debug(f'Saved model: {save_model_path}')
346 |                             improve = '*'
347 |                             last_improve = total_batch
348 |                         else:
349 |                             improve = ''
350 |                         time_dif = get_time_spend(start_time)
351 |                         msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},' \
352 |                               'Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format(
353 |                             total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)
354 |                     else:
355 |                         time_dif = get_time_spend(start_time)
356 |                         msg = 'Iter:{0:>6},Train
Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format( 357 | total_batch, loss.item(), train_acc, time_dif) 358 | logger.debug(msg) 359 | history.append(msg) 360 | self.model.train() 361 | total_batch += 1 362 | if total_batch - last_improve > require_improvement: 363 | # 验证集loss超过1000batch没下降,结束训练 364 | logger.debug("No optimization for a long time, auto-stopping...") 365 | flag = True 366 | break 367 | if flag: 368 | break 369 | return history 370 | 371 | def predict(self, sentences: list): 372 | """ 373 | Predict labels and label probability for sentences. 374 | @param sentences: list, input text list, eg: [text1, text2, ...] 375 | @return: predict_label, predict_prob 376 | """ 377 | if not self.is_trained: 378 | raise ValueError('model not trained.') 379 | self.model.eval() 380 | 381 | def biGramHash(sequence, t, buckets): 382 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 383 | return (t1 * 14918087) % buckets 384 | 385 | def triGramHash(sequence, t, buckets): 386 | t1 = sequence[t - 1] if t - 1 >= 0 else 0 387 | t2 = sequence[t - 2] if t - 2 >= 0 else 0 388 | return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets 389 | 390 | def load_dataset(X, max_seq_length=128): 391 | contents = [] 392 | for content in X: 393 | words_line = [] 394 | token = self.tokenizer(content) 395 | seq_len = len(token) 396 | if max_seq_length: 397 | if len(token) < max_seq_length: 398 | token.extend([self.pad_token] * (max_seq_length - len(token))) 399 | else: 400 | token = token[:max_seq_length] 401 | seq_len = max_seq_length 402 | # word to id 403 | for word in token: 404 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token))) 405 | # fasttext ngram 406 | bigram = [] 407 | trigram = [] 408 | if self.enable_ngram: 409 | buckets = self.n_gram_vocab 410 | # ------ngram------ 411 | for i in range(max_seq_length): 412 | bigram.append(biGramHash(words_line, i, buckets)) 413 | trigram.append(triGramHash(words_line, i, buckets)) 414 | # ----------------- 415 | else: 416 | bigram = [0] * max_seq_length 417 | trigram = [0] * max_seq_length 418 | contents.append((words_line, 0, seq_len, bigram, trigram)) 419 | return contents 420 | 421 | data = load_dataset(sentences, self.max_seq_length) 422 | data_iter = build_iterator(data, device, self.batch_size) 423 | # predict probs 424 | predict_all = np.array([], dtype=int) 425 | proba_all = np.array([], dtype=float) 426 | with torch.no_grad(): 427 | for texts, _ in data_iter: 428 | outputs = self.model(texts) 429 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy() 430 | pred = np.argmax(logit, axis=1) 431 | proba = np.max(logit, axis=1) 432 | 433 | predict_all = np.append(predict_all, pred) 434 | proba_all = np.append(proba_all, proba) 435 | id_label_map = {v: k for k, v in self.label_id_map.items()} 436 | predict_labels = [id_label_map.get(i) for i in predict_all] 437 | predict_probs = proba_all.tolist() 438 | return predict_labels, predict_probs 439 | 440 | def evaluate_model(self, data_list_or_path, header=None, 441 | names=('labels', 'text'), delimiter='\t'): 442 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter) 443 | self.load_model() 444 | data, word_id_map, label_id_map = build_dataset( 445 | self.tokenizer, X_test, y_test, 446 | self.word_vocab_path, 447 | self.label_vocab_path, 448 | max_vocab_size=self.max_vocab_size, 449 | max_seq_length=self.max_seq_length, 450 | unk_token=self.unk_token, 451 | pad_token=self.pad_token, 452 | ) 453 | data_iter = build_iterator(data, device, 
self.batch_size) 454 | return self.evaluate(data_iter)[0] 455 | 456 | def evaluate(self, data_iter): 457 | """ 458 | Evaluate model. 459 | @param data_iter: 460 | @return: accuracy score, loss 461 | """ 462 | if not self.model: 463 | raise ValueError('model not trained.') 464 | self.model.eval() 465 | loss_total = 0.0 466 | predict_all = np.array([], dtype=int) 467 | labels_all = np.array([], dtype=int) 468 | with torch.no_grad(): 469 | for texts, labels in data_iter: 470 | outputs = self.model(texts) 471 | loss = F.cross_entropy(outputs, labels) 472 | loss_total += loss 473 | labels = labels.cpu().numpy() 474 | predic = torch.max(outputs, 1)[1].cpu().numpy() 475 | labels_all = np.append(labels_all, labels) 476 | predict_all = np.append(predict_all, predic) 477 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}") 478 | acc = metrics.accuracy_score(labels_all, predict_all) 479 | return acc, loss_total / len(data_iter) 480 | 481 | def load_model(self): 482 | """ 483 | Load model from output_dir 484 | @return: 485 | """ 486 | model_path = os.path.join(self.output_dir, 'model.pth') 487 | if os.path.exists(model_path): 488 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json') 489 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json') 490 | self.word_id_map = load_vocab(self.word_vocab_path) 491 | self.label_id_map = load_vocab(self.label_vocab_path) 492 | vocab_size = len(self.word_id_map) 493 | num_classes = len(self.label_id_map) 494 | self.model = FastTextModel( 495 | vocab_size, num_classes, self.embed_size, self.n_gram_vocab, self.hidden_size, 496 | self.dropout_rate 497 | ) 498 | self.model.load_state_dict(torch.load(model_path, map_location=device)) 499 | self.model.to(device) 500 | self.is_trained = True 501 | else: 502 | logger.error(f'{model_path} not exists.') 503 | self.is_trained = False 504 | return self.is_trained 505 | 506 | 507 | if __name__ == '__main__': 508 | parser = argparse.ArgumentParser(description='Text Classification') 509 | parser.add_argument('--output_dir', default='models/fasttext', type=str, help='save model dir') 510 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'), 511 | type=str, help='sample data file path') 512 | args = parser.parse_args() 513 | print(args) 514 | # create model 515 | m = FastTextClassifier(args.output_dir) 516 | # train model 517 | m.train(data_list_or_path=args.data_path, num_epochs=3) 518 | # load trained model and predict 519 | m.load_model() 520 | print('best model loaded from file, and predict') 521 | X, y, _ = load_data(args.data_path) 522 | X = X[:5] 523 | y = y[:5] 524 | predict_labels, predict_probs = m.predict(X) 525 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y): 526 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth) 527 | -------------------------------------------------------------------------------- /pytextclassifier/textcluster.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os 7 | import sys 8 | from codecs import open 9 | import pickle 10 | from sklearn.cluster import MiniBatchKMeans 11 | from sklearn.feature_extraction.text import TfidfVectorizer 12 | from loguru import logger 13 | 14 | sys.path.append('..') 15 | from pytextclassifier.tokenizer import Tokenizer 16 | 17 | pwd_path = 
os.path.abspath(os.path.dirname(__file__)) 18 | default_stopwords_path = os.path.join(pwd_path, 'stopwords.txt') 19 | 20 | 21 | class TextCluster(object): 22 | def __init__( 23 | self, 24 | output_dir="outputs", 25 | model=None, tokenizer=None, feature=None, 26 | stopwords_path=default_stopwords_path, 27 | n_clusters=3, n_init=10, ngram_range=(1, 2), **kwargs 28 | ): 29 | self.output_dir = output_dir 30 | self.model = model if model else MiniBatchKMeans(n_clusters=n_clusters, n_init=n_init) 31 | self.tokenizer = tokenizer if tokenizer else Tokenizer() 32 | self.feature = feature if feature else TfidfVectorizer(ngram_range=ngram_range, **kwargs) 33 | self.stopwords = set(self.load_list(stopwords_path)) if stopwords_path and os.path.exists( 34 | stopwords_path) else set() 35 | self.is_trained = False 36 | 37 | def __str__(self): 38 | return 'TextCluster instance ({}, {}, {})'.format(self.model, self.tokenizer, self.feature) 39 | 40 | @staticmethod 41 | def load_file_data(file_path, sep='\t', use_col=1): 42 | """ 43 | Load text file, format(txt): text 44 | :param file_path: str 45 | :param sep: \t 46 | :param use_col: int or None 47 | :return: list, text list 48 | """ 49 | contents = [] 50 | if not os.path.exists(file_path): 51 | raise ValueError('file not found. path: {}'.format(file_path)) 52 | with open(file_path, "r", encoding='utf-8') as f: 53 | for line in f: 54 | line = line.strip() 55 | if use_col: 56 | contents.append(line.split(sep)[use_col]) 57 | else: 58 | contents.append(line) 59 | logger.info('load file done. path: {}, size: {}'.format(file_path, len(contents))) 60 | return contents 61 | 62 | @staticmethod 63 | def load_list(path): 64 | """ 65 | 加载停用词 66 | :param path: 67 | :return: list 68 | """ 69 | return [word for word in open(path, 'r', encoding='utf-8').read().split()] 70 | 71 | @staticmethod 72 | def load_pkl(pkl_path): 73 | """ 74 | 加载词典文件 75 | :param pkl_path: 76 | :return: 77 | """ 78 | with open(pkl_path, 'rb') as f: 79 | result = pickle.load(f) 80 | return result 81 | 82 | @staticmethod 83 | def save_pkl(vocab, pkl_path, overwrite=True): 84 | """ 85 | 存储文件 86 | :param pkl_path: 87 | :param overwrite: 88 | :return: 89 | """ 90 | if pkl_path and os.path.exists(pkl_path) and not overwrite: 91 | return 92 | if pkl_path: 93 | with open(pkl_path, 'wb') as f: 94 | pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL) 95 | # pickle.dump(vocab, f, protocol=2) # 兼容python2和python3 96 | print("save %s ok." % pkl_path) 97 | 98 | @staticmethod 99 | def show_plt(feature_matrix, labels, image_file='cluster.png'): 100 | """ 101 | Show cluster plt 102 | :param feature_matrix: 103 | :param labels: 104 | :param image_file: 105 | :return: 106 | """ 107 | from sklearn.decomposition import TruncatedSVD 108 | import matplotlib.pyplot as plt 109 | svd = TruncatedSVD() 110 | plot_columns = svd.fit_transform(feature_matrix) 111 | plt.scatter(x=plot_columns[:, 0], y=plot_columns[:, 1], c=labels) 112 | if image_file: 113 | plt.savefig(image_file) 114 | plt.show() 115 | 116 | def tokenize_sentences(self, sentences): 117 | """ 118 | Encoding input text 119 | :param sentences: list of text, eg: [text1, text2, ...] 120 | :return: X_tokens 121 | """ 122 | X_tokens = [' '.join([w for w in self.tokenizer.tokenize(line) if w not in self.stopwords]) for line in 123 | sentences] 124 | return X_tokens 125 | 126 | def train(self, sentences): 127 | """ 128 | Train model and save model 129 | :param sentences: list of text, eg: [text1, text2, ...] 
130 |         :return: feature matrix, cluster labels
131 |         """
132 |         logger.debug('train model')
133 |         X_tokens = self.tokenize_sentences(sentences)
134 |         logger.debug('data tokens top 1: {}'.format(X_tokens[:1]))
135 |         feature = self.feature.fit_transform(X_tokens)
136 |         # fit cluster
137 |         self.model.fit(feature)
138 |         labels = self.model.labels_
139 |         logger.debug('cluster labels:{}'.format(labels))
140 |         output_dir = self.output_dir
141 |         if output_dir:
142 |             os.makedirs(output_dir, exist_ok=True)
143 |             feature_path = os.path.join(output_dir, 'cluster_feature.pkl')
144 |             self.save_pkl(self.feature, feature_path)
145 |             model_path = os.path.join(output_dir, 'cluster_model.pkl')
146 |             self.save_pkl(self.model, model_path)
147 |             logger.info('save done. feature path: {}, model path: {}'.format(feature_path, model_path))
148 | 
149 |         self.is_trained = True
150 |         return feature, labels
151 | 
152 |     def predict(self, X):
153 |         """
154 |         Predict cluster labels for input texts
155 |         :param X: list, input text list, eg: [text1, text2, ...]
156 |         :return: list, predicted cluster label ids
157 |         """
158 |         if not self.is_trained:
159 |             raise ValueError('model is None, run train first.')
160 |         # tokenize text
161 |         X_tokens = self.tokenize_sentences(X)
162 |         # transform
163 |         feat = self.feature.transform(X_tokens)
164 |         return self.model.predict(feat)
165 | 
166 |     def show_clusters(self, feature_matrix, labels, image_file='cluster.png'):
167 |         """
168 |         Show cluster plt image
169 |         :param feature_matrix:
170 |         :param labels:
171 |         :param image_file:
172 |         :return:
173 |         """
174 |         if not self.is_trained:
175 |             raise ValueError('model is None, run train first.')
176 |         self.show_plt(feature_matrix, labels, image_file)
177 | 
178 |     def load_model(self):
179 |         """
180 |         Load model from output_dir
181 |         :return: bool, True if the model was loaded,
182 |             False otherwise
183 |         """
184 |         model_path = os.path.join(self.output_dir, 'cluster_model.pkl')
185 |         if not os.path.exists(model_path):
186 |             raise ValueError("model is not found.
please train and save model first.") 187 | self.model = self.load_pkl(model_path) 188 | feature_path = os.path.join(self.output_dir, 'cluster_feature.pkl') 189 | self.feature = self.load_pkl(feature_path) 190 | self.is_trained = True 191 | logger.info('model loaded {}'.format(self.output_dir)) 192 | return self.is_trained 193 | 194 | 195 | if __name__ == '__main__': 196 | m = TextCluster(output_dir='models/cluster', n_clusters=2) 197 | print(m) 198 | data = [ 199 | 'Student debt to cost Britain billions within decades', 200 | 'Chinese education for TV experiment', 201 | 'Abbott government spends $8 million on higher education', 202 | 'Middle East and Asia boost investment in top level sports', 203 | 'Summit Series look launches HBO Canada sports doc series: Mudhar' 204 | ] 205 | feat, labels = m.train(data) 206 | m.show_clusters(feat, labels, image_file='models/cluster/cluster.png') 207 | m.load_model() 208 | r = m.predict(['Abbott government spends $8 million on higher education media blitz', 209 | 'Middle East and Asia boost investment in top level sports']) 210 | print(r) 211 | -------------------------------------------------------------------------------- /pytextclassifier/textcnn_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: TextCNN Classifier 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import sys 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from loguru import logger 17 | from sklearn import metrics 18 | from sklearn.model_selection import train_test_split 19 | 20 | sys.path.append('..') 21 | from pytextclassifier.base_classifier import ClassifierABC, load_data 22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab 23 | from pytextclassifier.time_util import get_time_spend 24 | 25 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 26 | device = 'cuda' if torch.cuda.is_available() else ( 27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu') 28 | 29 | 30 | def build_dataset(tokenizer, X, y, word_vocab_path, label_vocab_path, max_seq_length=128, 31 | unk_token='[UNK]', pad_token='[PAD]', max_vocab_size=10000): 32 | if os.path.exists(word_vocab_path): 33 | word_id_map = json.load(open(word_vocab_path, 'r', encoding='utf-8')) 34 | else: 35 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1, 36 | unk_token=unk_token, pad_token=pad_token) 37 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 38 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}") 39 | 40 | if os.path.exists(label_vocab_path): 41 | label_id_map = json.load(open(label_vocab_path, 'r', encoding='utf-8')) 42 | else: 43 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))} 44 | label_id_map = {v: k for k, v in id_label_map.items()} 45 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 46 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}") 47 | 48 | def load_dataset(X, y, max_seq_length=128): 49 | contents = [] 50 | for content, label in zip(X, y): 51 | words_line = [] 52 | token = tokenizer(content) 53 | seq_len = len(token) 54 | if max_seq_length: 55 | if len(token) < max_seq_length: 56 | 
token.extend([pad_token] * (max_seq_length - len(token))) 57 | else: 58 | token = token[:max_seq_length] 59 | seq_len = max_seq_length 60 | # word to id 61 | for word in token: 62 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token))) 63 | label_id = label_id_map.get(label) 64 | contents.append((words_line, label_id, seq_len)) 65 | return contents 66 | 67 | dataset = load_dataset(X, y, max_seq_length) 68 | return dataset, word_id_map, label_id_map 69 | 70 | 71 | class DatasetIterater: 72 | def __init__(self, dataset, device, batch_size=32): 73 | self.batch_size = batch_size 74 | self.dataset = dataset 75 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1 76 | self.residue = False # 记录batch数量是否为整数 77 | if len(dataset) % self.n_batches != 0: 78 | self.residue = True 79 | self.index = 0 80 | self.device = device 81 | 82 | def _to_tensor(self, datas): 83 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 84 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 85 | 86 | # pad前的长度(超过max_seq_length的设为max_seq_length) 87 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 88 | return (x, seq_len), y 89 | 90 | def __next__(self): 91 | if self.residue and self.index == self.n_batches: 92 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)] 93 | self.index += 1 94 | batches = self._to_tensor(batches) 95 | return batches 96 | 97 | elif self.index >= self.n_batches: 98 | self.index = 0 99 | raise StopIteration 100 | else: 101 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size] 102 | self.index += 1 103 | batches = self._to_tensor(batches) 104 | return batches 105 | 106 | def __iter__(self): 107 | return self 108 | 109 | def __len__(self): 110 | if self.residue: 111 | return self.n_batches + 1 112 | else: 113 | return self.n_batches 114 | 115 | 116 | def build_iterator(dataset, device, batch_size=32): 117 | return DatasetIterater(dataset, device, batch_size) 118 | 119 | 120 | class TextCNNModel(nn.Module): 121 | """Convolutional Neural Networks for Sentence Classification""" 122 | 123 | def __init__( 124 | self, 125 | vocab_size, 126 | num_classes, 127 | embed_size=200, filter_sizes=(2, 3, 4), num_filters=256, dropout_rate=0.5 128 | ): 129 | """ 130 | Init the TextCNNModel 131 | @param vocab_size: 132 | @param num_classes: 133 | @param embed_size: 134 | @param filter_sizes: 卷积核尺寸 135 | @param num_filters: 卷积核数量(channels数) 136 | @param dropout_rate: 137 | """ 138 | super().__init__() 139 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1) 140 | self.convs = nn.ModuleList( 141 | [nn.Conv2d(1, num_filters, (k, embed_size)) for k in filter_sizes]) 142 | self.dropout = nn.Dropout(dropout_rate) 143 | self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes) 144 | 145 | def conv_and_pool(self, x, conv): 146 | x = F.relu(conv(x)).squeeze(3) 147 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 148 | return x 149 | 150 | def forward(self, x): 151 | out = self.embedding(x[0]) 152 | out = out.unsqueeze(1) 153 | out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1) 154 | out = self.dropout(out) 155 | out = self.fc(out) 156 | return out 157 | 158 | 159 | class TextCNNClassifier(ClassifierABC): 160 | def __init__( 161 | self, 162 | output_dir="outputs", 163 | filter_sizes=(2, 3, 4), num_filters=256, 164 | dropout_rate=0.5, batch_size=64, max_seq_length=128, 165 | embed_size=200, max_vocab_size=10000, 166 | 
unk_token='[UNK]', pad_token='[PAD]', tokenizer=None,
167 |     ):
168 |         """
169 |         Init the TextCNNClassifier
170 |         @param output_dir: dir to save the trained model
171 |         @param filter_sizes: convolution kernel sizes
172 |         @param num_filters: number of convolution kernels (channels)
173 |         @param dropout_rate:
174 |         @param batch_size:
175 |         @param max_seq_length:
176 |         @param embed_size:
177 |         @param max_vocab_size:
178 |         @param unk_token:
179 |         @param pad_token:
180 |         @param tokenizer: tokenizer, defaults to char-level segmentation
181 |         """
182 |         self.output_dir = output_dir
183 |         self.is_trained = False
184 |         self.model = None
185 |         logger.debug(f'device: {device}')
186 |         self.filter_sizes = filter_sizes
187 |         self.num_filters = num_filters
188 |         self.dropout_rate = dropout_rate
189 |         self.batch_size = batch_size
190 |         self.max_seq_length = max_seq_length
191 |         self.embed_size = embed_size
192 |         self.max_vocab_size = max_vocab_size
193 |         self.unk_token = unk_token
194 |         self.pad_token = pad_token
195 |         self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x]  # char-level
196 | 
197 |     def __str__(self):
198 |         return f'TextCNNClassifier instance ({self.model})'
199 | 
200 |     def train(
201 |             self,
202 |             data_list_or_path,
203 |             header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1,
204 |             num_epochs=20, learning_rate=1e-3,
205 |             require_improvement=1000, evaluate_during_training_steps=100
206 |     ):
207 |         """
208 |         Train model with data_list_or_path and save model to output_dir
209 |         @param data_list_or_path:
210 | 
211 |         @param header:
212 |         @param names:
213 |         @param delimiter:
214 |         @param test_size:
215 |         @param num_epochs: number of training epochs
216 |         @param learning_rate: learning rate
217 |         @param require_improvement: stop training early if the dev loss has not improved after this many batches
218 |         @param evaluate_during_training_steps: evaluate the model every this many steps
219 |         @return:
220 |         """
221 |         logger.debug('train model...')
222 |         SEED = 1
223 |         set_seed(SEED)
224 |         # load data
225 |         X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
226 |         output_dir = self.output_dir
227 |         if output_dir:
228 |             os.makedirs(output_dir, exist_ok=True)
229 |         word_vocab_path = os.path.join(output_dir, 'word_vocab.json')
230 |         label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
231 |         save_model_path = os.path.join(output_dir, 'model.pth')
232 | 
233 |         dataset, self.word_id_map, self.label_id_map = build_dataset(
234 |             self.tokenizer, X, y,
235 |             word_vocab_path,
236 |             label_vocab_path,
237 |             max_vocab_size=self.max_vocab_size,
238 |             max_seq_length=self.max_seq_length,
239 |             unk_token=self.unk_token, pad_token=self.pad_token
240 |         )
241 |         train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
242 |         logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
243 |         logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
244 |         train_iter = build_iterator(train_data, device, self.batch_size)
245 |         dev_iter = build_iterator(dev_data, device, self.batch_size)
246 |         # create model
247 |         vocab_size = len(self.word_id_map)
248 |         num_classes = len(self.label_id_map)
249 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
250 |         self.model = TextCNNModel(
251 |             vocab_size, num_classes,
252 |             embed_size=self.embed_size,
253 |             filter_sizes=self.filter_sizes,
254 |             num_filters=self.num_filters,
255 |             dropout_rate=self.dropout_rate
256 |         )
257 |         self.model.to(device)
258 |         # init_network(self.model)
259 |         logger.info(self.model.parameters)
260 |         # train model
261 |         history =
self.train_model_from_data_iterator(save_model_path, train_iter, dev_iter, num_epochs, learning_rate, 262 | require_improvement, evaluate_during_training_steps) 263 | self.is_trained = True 264 | logger.debug('train model done') 265 | return history 266 | 267 | def train_model_from_data_iterator(self, save_model_path, train_iter, dev_iter, 268 | num_epochs=10, learning_rate=1e-3, 269 | require_improvement=1000, evaluate_during_training_steps=100): 270 | history = [] 271 | # train 272 | start_time = time.time() 273 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) 274 | 275 | total_batch = 0 # 记录进行到多少batch 276 | dev_best_loss = 1e10 277 | last_improve = 0 # 记录上次验证集loss下降的batch数 278 | flag = False # 记录是否很久没有效果提升 279 | for epoch in range(num_epochs): 280 | logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs)) 281 | for i, (trains, labels) in enumerate(train_iter): 282 | self.model.train() 283 | outputs = self.model(trains) 284 | loss = F.cross_entropy(outputs, labels) 285 | # compute gradient and do SGD step 286 | optimizer.zero_grad() 287 | loss.backward() 288 | optimizer.step() 289 | if total_batch % evaluate_during_training_steps == 0: 290 | # 输出在训练集和验证集上的效果 291 | y_true = labels.cpu() 292 | y_pred = torch.max(outputs, 1)[1].cpu() 293 | train_acc = metrics.accuracy_score(y_true, y_pred) 294 | if dev_iter is not None: 295 | dev_acc, dev_loss = self.evaluate(dev_iter) 296 | if dev_loss < dev_best_loss: 297 | dev_best_loss = dev_loss 298 | torch.save(self.model.state_dict(), save_model_path) 299 | logger.debug(f'Saved model: {save_model_path}') 300 | improve = '*' 301 | last_improve = total_batch 302 | else: 303 | improve = '' 304 | time_dif = get_time_spend(start_time) 305 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format( 306 | total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve) 307 | else: 308 | time_dif = get_time_spend(start_time) 309 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format( 310 | total_batch, loss.item(), train_acc, time_dif) 311 | logger.debug(msg) 312 | history.append(msg) 313 | self.model.train() 314 | total_batch += 1 315 | if total_batch - last_improve > require_improvement: 316 | # 验证集loss超过1000batch没下降,结束训练 317 | logger.debug("No optimization for a long time, auto-stopping...") 318 | flag = True 319 | break 320 | if flag: 321 | break 322 | return history 323 | 324 | def predict(self, sentences: list): 325 | """ 326 | Predict labels and label probability for sentences. 327 | @param sentences: list, input text list, eg: [text1, text2, ...] 
328 | @return: predict_label, predict_prob 329 | """ 330 | if not self.is_trained: 331 | raise ValueError('model not trained.') 332 | self.model.eval() 333 | 334 | def load_dataset(X, max_seq_length=128): 335 | contents = [] 336 | for content in X: 337 | words_line = [] 338 | token = self.tokenizer(content) 339 | seq_len = len(token) 340 | if max_seq_length: 341 | if len(token) < max_seq_length: 342 | token.extend([self.pad_token] * (max_seq_length - len(token))) 343 | else: 344 | token = token[:max_seq_length] 345 | seq_len = max_seq_length 346 | # word to id 347 | for word in token: 348 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token))) 349 | contents.append((words_line, 0, seq_len)) 350 | return contents 351 | 352 | data = load_dataset(sentences, self.max_seq_length) 353 | data_iter = build_iterator(data, device, self.batch_size) 354 | # predict prob 355 | predict_all = np.array([], dtype=int) 356 | proba_all = np.array([], dtype=float) 357 | with torch.no_grad(): 358 | for texts, _ in data_iter: 359 | outputs = self.model(texts) 360 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy() 361 | pred = np.argmax(logit, axis=1) 362 | proba = np.max(logit, axis=1) 363 | 364 | predict_all = np.append(predict_all, pred) 365 | proba_all = np.append(proba_all, proba) 366 | id_label_map = {v: k for k, v in self.label_id_map.items()} 367 | predict_labels = [id_label_map.get(i) for i in predict_all] 368 | predict_probs = proba_all.tolist() 369 | return predict_labels, predict_probs 370 | 371 | def evaluate_model(self, data_list_or_path, header=None, 372 | names=('labels', 'text'), delimiter='\t'): 373 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter) 374 | self.load_model() 375 | data, word_id_map, label_id_map = build_dataset( 376 | self.tokenizer, X_test, y_test, 377 | self.word_vocab_path, 378 | self.label_vocab_path, 379 | max_vocab_size=self.max_vocab_size, 380 | max_seq_length=self.max_seq_length, 381 | unk_token=self.unk_token, 382 | pad_token=self.pad_token, 383 | ) 384 | data_iter = build_iterator(data, device, self.batch_size) 385 | return self.evaluate(data_iter)[0] 386 | 387 | def evaluate(self, data_iter): 388 | """ 389 | Evaluate model. 
390 | @param data_iter: 391 | @return: accuracy score, loss 392 | """ 393 | if not self.model: 394 | raise ValueError('model not trained.') 395 | self.model.eval() 396 | loss_total = 0.0 397 | predict_all = np.array([], dtype=int) 398 | labels_all = np.array([], dtype=int) 399 | with torch.no_grad(): 400 | for texts, labels in data_iter: 401 | outputs = self.model(texts) 402 | loss = F.cross_entropy(outputs, labels) 403 | loss_total += loss 404 | labels = labels.cpu().numpy() 405 | predic = torch.max(outputs, 1)[1].cpu().numpy() 406 | labels_all = np.append(labels_all, labels) 407 | predict_all = np.append(predict_all, predic) 408 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}") 409 | acc = metrics.accuracy_score(labels_all, predict_all) 410 | return acc, loss_total / len(data_iter) 411 | 412 | def load_model(self): 413 | """ 414 | Load model from output_dir 415 | @return: 416 | """ 417 | model_path = os.path.join(self.output_dir, 'model.pth') 418 | if os.path.exists(model_path): 419 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json') 420 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json') 421 | self.word_id_map = load_vocab(self.word_vocab_path) 422 | self.label_id_map = load_vocab(self.label_vocab_path) 423 | vocab_size = len(self.word_id_map) 424 | num_classes = len(self.label_id_map) 425 | self.model = TextCNNModel( 426 | vocab_size, num_classes, 427 | embed_size=self.embed_size, 428 | filter_sizes=self.filter_sizes, 429 | num_filters=self.num_filters, 430 | dropout_rate=self.dropout_rate 431 | ) 432 | self.model.load_state_dict(torch.load(model_path, map_location=device)) 433 | self.model.to(device) 434 | self.is_trained = True 435 | else: 436 | logger.error(f'{model_path} not exists.') 437 | self.is_trained = False 438 | return self.is_trained 439 | 440 | 441 | if __name__ == '__main__': 442 | parser = argparse.ArgumentParser(description='Text Classification') 443 | parser.add_argument('--output_dir', default='models/textcnn', type=str, help='save model dir') 444 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'), 445 | type=str, help='sample data file path') 446 | args = parser.parse_args() 447 | print(args) 448 | # create model 449 | m = TextCNNClassifier(output_dir=args.output_dir) 450 | # train model 451 | m.train(data_list_or_path=args.data_path, num_epochs=3) 452 | # load trained model and predict 453 | m.load_model() 454 | print('best model loaded from file, and predict') 455 | X, y, _ = load_data(args.data_path) 456 | X = X[:5] 457 | y = y[:5] 458 | predict_labels, predict_probs = m.predict(X) 459 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y): 460 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth) 461 | -------------------------------------------------------------------------------- /pytextclassifier/textrnn_classifier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: TextRNN Attention Classifier 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import sys 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from loguru import logger 17 | from sklearn import metrics 18 | from sklearn.model_selection import train_test_split 19 | 20 | sys.path.append('..') 21 | from 
pytextclassifier.base_classifier import ClassifierABC, load_data 22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab 23 | from pytextclassifier.time_util import get_time_spend 24 | 25 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 26 | device = 'cuda' if torch.cuda.is_available() else ( 27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu') 28 | 29 | 30 | def build_dataset(tokenizer, X, y, word_vocab_path, label_vocab_path, max_seq_length=128, 31 | unk_token='[UNK]', pad_token='[PAD]', max_vocab_size=10000): 32 | if os.path.exists(word_vocab_path): 33 | word_id_map = json.load(open(word_vocab_path, 'r', encoding='utf-8')) 34 | else: 35 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1, 36 | unk_token=unk_token, pad_token=pad_token) 37 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 38 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}") 39 | 40 | if os.path.exists(label_vocab_path): 41 | label_id_map = json.load(open(label_vocab_path, 'r', encoding='utf-8')) 42 | else: 43 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))} 44 | label_id_map = {v: k for k, v in id_label_map.items()} 45 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4) 46 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}") 47 | 48 | def load_dataset(X, y, max_seq_length=128): 49 | contents = [] 50 | for content, label in zip(X, y): 51 | words_line = [] 52 | token = tokenizer(content) 53 | seq_len = len(token) 54 | if max_seq_length: 55 | if len(token) < max_seq_length: 56 | token.extend([pad_token] * (max_seq_length - len(token))) 57 | else: 58 | token = token[:max_seq_length] 59 | seq_len = max_seq_length 60 | # word to id 61 | for word in token: 62 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token))) 63 | label_id = label_id_map.get(label) 64 | contents.append((words_line, label_id, seq_len)) 65 | return contents 66 | 67 | dataset = load_dataset(X, y, max_seq_length) 68 | return dataset, word_id_map, label_id_map 69 | 70 | 71 | class DatasetIterater: 72 | def __init__(self, dataset, device, batch_size=32): 73 | self.batch_size = batch_size 74 | self.dataset = dataset 75 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1 76 | self.residue = False # 记录batch数量是否为整数 77 | if len(dataset) % self.n_batches != 0: 78 | self.residue = True 79 | self.index = 0 80 | self.device = device 81 | 82 | def _to_tensor(self, datas): 83 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 84 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 85 | 86 | # pad前的长度(超过max_seq_length的设为max_seq_length) 87 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 88 | return (x, seq_len), y 89 | 90 | def __next__(self): 91 | if self.residue and self.index == self.n_batches: 92 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)] 93 | self.index += 1 94 | batches = self._to_tensor(batches) 95 | return batches 96 | 97 | elif self.index >= self.n_batches: 98 | self.index = 0 99 | raise StopIteration 100 | else: 101 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size] 102 | self.index += 1 103 | batches = self._to_tensor(batches) 104 | return batches 105 | 106 | def __iter__(self): 107 | return self 108 | 109 | 
def __len__(self): 110 | if self.residue: 111 | return self.n_batches + 1 112 | else: 113 | return self.n_batches 114 | 115 | 116 | def build_iterator(dataset, device, batch_size=32): 117 | return DatasetIterater(dataset, device, batch_size) 118 | 119 | 120 | class TextRNNAttModel(nn.Module): 121 | """Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification""" 122 | 123 | def __init__( 124 | self, 125 | vocab_size, 126 | num_classes, 127 | embed_size=200, 128 | hidden_size=128, 129 | num_layers=2, 130 | dropout_rate=0.5, 131 | ): 132 | """ 133 | TextRNN-Att model 134 | @param vocab_size: 135 | @param num_classes: 136 | @param embed_size: 字向量维度 137 | @param hidden_size: lstm隐藏层 138 | @param num_layers: lstm层数 139 | @param dropout_rate: 140 | """ 141 | super().__init__() 142 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1) 143 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, 144 | bidirectional=True, batch_first=True, dropout=dropout_rate) 145 | self.tanh1 = nn.Tanh() 146 | # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2)) 147 | self.w = nn.Parameter(torch.zeros(hidden_size * 2)) 148 | self.tanh2 = nn.Tanh() 149 | self.fc1 = nn.Linear(hidden_size * 2, int(hidden_size / 2)) 150 | self.fc2 = nn.Linear(int(hidden_size / 2), num_classes) 151 | 152 | def forward(self, x): 153 | emb = self.embedding(x[0]) # [batch_size, seq_len, embeding]=[128, 32, 300] 154 | H, _ = self.lstm(emb) # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256] 155 | 156 | M = self.tanh1(H) # [128, 32, 256] 157 | # M = torch.tanh(torch.matmul(H, self.u)) 158 | alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1) # [128, 32, 1] 159 | out = H * alpha # [128, 32, 256] 160 | out = torch.sum(out, 1) # [128, 256] 161 | out = F.relu(out) 162 | out = self.fc1(out) 163 | out = self.fc2(out) # [128, 64] 164 | return out 165 | 166 | 167 | class TextRNNClassifier(ClassifierABC): 168 | def __init__( 169 | self, 170 | output_dir="outputs", 171 | hidden_size=128, 172 | num_layers=2, 173 | dropout_rate=0.5, batch_size=64, max_seq_length=128, 174 | embed_size=200, max_vocab_size=10000, unk_token='[UNK]', 175 | pad_token='[PAD]', tokenizer=None, 176 | ): 177 | """ 178 | Init the TextRNNClassifier 179 | @param output_dir: 模型保存路径 180 | @param hidden_size: lstm隐藏层 181 | @param num_layers: lstm层数 182 | @param dropout_rate: 183 | @param batch_size: 184 | @param max_seq_length: 185 | @param embed_size: 186 | @param max_vocab_size: 187 | @param unk_token: 188 | @param pad_token: 189 | @param tokenizer: 切词器,默认为字粒度切分 190 | """ 191 | self.output_dir = output_dir 192 | self.is_trained = False 193 | self.model = None 194 | logger.debug(f'device: {device}') 195 | self.hidden_size = hidden_size 196 | self.num_layers = num_layers 197 | self.dropout_rate = dropout_rate 198 | self.batch_size = batch_size 199 | self.max_seq_length = max_seq_length 200 | self.embed_size = embed_size 201 | self.max_vocab_size = max_vocab_size 202 | self.unk_token = unk_token 203 | self.pad_token = pad_token 204 | self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x] # char-level 205 | 206 | def __str__(self): 207 | return f'TextRNNClassifier instance ({self.model})' 208 | 209 | def train( 210 | self, 211 | data_list_or_path, 212 | header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1, 213 | num_epochs=20, learning_rate=1e-3, 214 | require_improvement=1000, evaluate_during_training_steps=100 215 | ): 216 | """ 
217 |         Train model with data_list_or_path and save model to output_dir
218 |         @param data_list_or_path:
219 |         @param header:
220 |         @param names:
221 |         @param delimiter:
222 |         @param test_size:
223 |         @param num_epochs: number of training epochs
224 |         @param learning_rate: learning rate
225 |         @param require_improvement: stop training early if the dev loss has not improved after this many batches
226 |         @param evaluate_during_training_steps: evaluate the model every this many steps
227 |         @return:
228 |         """
229 |         logger.debug('train model...')
230 |         SEED = 1
231 |         set_seed(SEED)
232 |         # load data
233 |         X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
234 |         output_dir = self.output_dir
235 |         if output_dir:
236 |             os.makedirs(output_dir, exist_ok=True)
237 |         word_vocab_path = os.path.join(output_dir, 'word_vocab.json')
238 |         label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
239 |         save_model_path = os.path.join(output_dir, 'model.pth')
240 | 
241 |         dataset, self.word_id_map, self.label_id_map = build_dataset(
242 |             self.tokenizer, X, y,
243 |             word_vocab_path,
244 |             label_vocab_path,
245 |             max_vocab_size=self.max_vocab_size,
246 |             max_seq_length=self.max_seq_length,
247 |             unk_token=self.unk_token, pad_token=self.pad_token
248 |         )
249 |         train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
250 |         logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
251 |         logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
252 |         train_iter = build_iterator(train_data, device, self.batch_size)
253 |         dev_iter = build_iterator(dev_data, device, self.batch_size)
254 |         # create model
255 |         vocab_size = len(self.word_id_map)
256 |         num_classes = len(self.label_id_map)
257 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
258 |         self.model = TextRNNAttModel(
259 |             vocab_size, num_classes,
260 |             embed_size=self.embed_size,
261 |             hidden_size=self.hidden_size,
262 |             num_layers=self.num_layers,
263 |             dropout_rate=self.dropout_rate
264 |         )
265 |         self.model.to(device)
266 |         # init_network(self.model)
267 |         logger.info(self.model.parameters)
268 |         # train model
269 |         history = self.train_model_from_data_iterator(save_model_path, train_iter, dev_iter, num_epochs, learning_rate,
270 |                                                        require_improvement, evaluate_during_training_steps)
271 |         self.is_trained = True
272 |         logger.debug('train model done')
273 |         return history
274 | 
275 |     def train_model_from_data_iterator(self, save_model_path, train_iter, dev_iter,
276 |                                        num_epochs=10, learning_rate=1e-3,
277 |                                        require_improvement=1000, evaluate_during_training_steps=100):
278 |         history = []
279 |         # train
280 |         start_time = time.time()
281 |         optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
282 | 
283 |         total_batch = 0  # number of batches processed so far
284 |         dev_best_loss = 1e10
285 |         last_improve = 0  # batch index of the last dev loss improvement
286 |         flag = False  # whether training has gone too long without improvement
287 |         for epoch in range(num_epochs):
288 |             logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs))
289 |             for i, (trains, labels) in enumerate(train_iter):
290 |                 self.model.train()
291 |                 outputs = self.model(trains)
292 |                 loss = F.cross_entropy(outputs, labels)
293 |                 # compute gradient and do SGD step
294 |                 optimizer.zero_grad()
295 |                 loss.backward()
296 |                 optimizer.step()
297 |                 if total_batch % evaluate_during_training_steps == 0:
298 |                     # report metrics on the training and dev sets
299 |                     y_true = labels.cpu()
300 |                     y_pred = torch.max(outputs, 1)[1].cpu()
301 |                     train_acc = metrics.accuracy_score(y_true, y_pred)
302 |                     if dev_iter is not None:
303 |                         dev_acc, dev_loss =
self.evaluate(dev_iter) 304 | if dev_loss < dev_best_loss: 305 | dev_best_loss = dev_loss 306 | torch.save(self.model.state_dict(), save_model_path) 307 | logger.debug(f'Saved model: {save_model_path}') 308 | improve = '*' 309 | last_improve = total_batch 310 | else: 311 | improve = '' 312 | time_dif = get_time_spend(start_time) 313 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format( 314 | total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve) 315 | else: 316 | time_dif = get_time_spend(start_time) 317 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format( 318 | total_batch, loss.item(), train_acc, time_dif) 319 | logger.debug(msg) 320 | history.append(msg) 321 | self.model.train() 322 | total_batch += 1 323 | if total_batch - last_improve > require_improvement: 324 | # 验证集loss超过1000batch没下降,结束训练 325 | logger.debug("No optimization for a long time, auto-stopping...") 326 | flag = True 327 | break 328 | if flag: 329 | break 330 | return history 331 | 332 | def predict(self, sentences: list): 333 | """ 334 | Predict labels and label probability for sentences. 335 | @param sentences: list, input text list, eg: [text1, text2, ...] 336 | @return: predict_label, predict_prob 337 | """ 338 | if not self.is_trained: 339 | raise ValueError('model not trained.') 340 | self.model.eval() 341 | 342 | def load_dataset(X, max_seq_length=128): 343 | contents = [] 344 | for content in X: 345 | words_line = [] 346 | token = self.tokenizer(content) 347 | seq_len = len(token) 348 | if max_seq_length: 349 | if len(token) < max_seq_length: 350 | token.extend([self.pad_token] * (max_seq_length - len(token))) 351 | else: 352 | token = token[:max_seq_length] 353 | seq_len = max_seq_length 354 | # word to id 355 | for word in token: 356 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token))) 357 | contents.append((words_line, 0, seq_len)) 358 | return contents 359 | 360 | data = load_dataset(sentences, self.max_seq_length) 361 | data_iter = build_iterator(data, device, self.batch_size) 362 | # predict prob 363 | predict_all = np.array([], dtype=int) 364 | proba_all = np.array([], dtype=float) 365 | with torch.no_grad(): 366 | for texts, _ in data_iter: 367 | outputs = self.model(texts) 368 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy() 369 | pred = np.argmax(logit, axis=1) 370 | proba = np.max(logit, axis=1) 371 | 372 | predict_all = np.append(predict_all, pred) 373 | proba_all = np.append(proba_all, proba) 374 | id_label_map = {v: k for k, v in self.label_id_map.items()} 375 | predict_labels = [id_label_map.get(i) for i in predict_all] 376 | predict_probs = proba_all.tolist() 377 | return predict_labels, predict_probs 378 | 379 | def evaluate_model(self, data_list_or_path, header=None, 380 | names=('labels', 'text'), delimiter='\t'): 381 | """ 382 | Evaluate model. 
383 | @param data_list_or_path: 384 | @param header: 385 | @param names: 386 | @param delimiter: 387 | @return: 388 | """ 389 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter) 390 | self.load_model() 391 | data, word_id_map, label_id_map = build_dataset( 392 | self.tokenizer, X_test, y_test, 393 | self.word_vocab_path, 394 | self.label_vocab_path, 395 | max_vocab_size=self.max_vocab_size, 396 | max_seq_length=self.max_seq_length, 397 | unk_token=self.unk_token, 398 | pad_token=self.pad_token, 399 | ) 400 | data_iter = build_iterator(data, device, self.batch_size) 401 | return self.evaluate(data_iter)[0] 402 | 403 | def evaluate(self, data_iter): 404 | """ 405 | Evaluate model. 406 | @param data_iter: 407 | @return: accuracy score, loss 408 | """ 409 | if not self.model: 410 | raise ValueError('model not trained.') 411 | self.model.eval() 412 | loss_total = 0.0 413 | predict_all = np.array([], dtype=int) 414 | labels_all = np.array([], dtype=int) 415 | with torch.no_grad(): 416 | for texts, labels in data_iter: 417 | outputs = self.model(texts) 418 | loss = F.cross_entropy(outputs, labels) 419 | loss_total += loss 420 | labels = labels.cpu().numpy() 421 | predic = torch.max(outputs, 1)[1].cpu().numpy() 422 | labels_all = np.append(labels_all, labels) 423 | predict_all = np.append(predict_all, predic) 424 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}") 425 | acc = metrics.accuracy_score(labels_all, predict_all) 426 | return acc, loss_total / len(data_iter) 427 | 428 | def load_model(self): 429 | """ 430 | Load model from output_dir 431 | @return: 432 | """ 433 | model_path = os.path.join(self.output_dir, 'model.pth') 434 | if os.path.exists(model_path): 435 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json') 436 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json') 437 | self.word_id_map = load_vocab(self.word_vocab_path) 438 | self.label_id_map = load_vocab(self.label_vocab_path) 439 | vocab_size = len(self.word_id_map) 440 | num_classes = len(self.label_id_map) 441 | self.model = TextRNNAttModel( 442 | vocab_size, num_classes, 443 | embed_size=self.embed_size, 444 | hidden_size=self.hidden_size, 445 | num_layers=self.num_layers, 446 | dropout_rate=self.dropout_rate 447 | ) 448 | self.model.load_state_dict(torch.load(model_path, map_location=device)) 449 | self.model.to(device) 450 | self.is_trained = True 451 | else: 452 | logger.error(f'{model_path} not exists.') 453 | self.is_trained = False 454 | return self.is_trained 455 | 456 | 457 | if __name__ == '__main__': 458 | parser = argparse.ArgumentParser(description='Text Classification') 459 | parser.add_argument('--output_dir', default='models/textrnn', type=str, help='save model dir') 460 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'), 461 | type=str, help='sample data file path') 462 | args = parser.parse_args() 463 | print(args) 464 | # create model 465 | m = TextRNNClassifier(args.output_dir) 466 | # train model 467 | m.train(data_list_or_path=args.data_path, num_epochs=3) 468 | # load trained the best model and predict 469 | m.load_model() 470 | print('best model loaded from file, and predict') 471 | X, y, _ = load_data(args.data_path) 472 | X = X[:5] 473 | y = y[:5] 474 | predict_labels, predict_probs = m.predict(X) 475 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y): 476 | print(text, 'pred:', pred_label, 
pred_prob, ' truth:', y_truth) 477 | -------------------------------------------------------------------------------- /pytextclassifier/time_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import time 7 | from datetime import timedelta 8 | 9 | 10 | def get_time_spend(start_time): 11 | """Get the elapsed time since start_time""" 12 | end_time = time.time() 13 | time_dif = end_time - start_time 14 | return timedelta(seconds=int(round(time_dif))) 15 | 16 | 17 | def init_network(model, method='xavier', exclude='embedding'): 18 | """Initialize network weights, xavier by default""" 19 | import torch.nn as nn 20 | for name, w in model.named_parameters(): 21 | if exclude not in name: 22 | if 'weight' in name: 23 | if method == 'xavier': 24 | nn.init.xavier_normal_(w) 25 | elif method == 'kaiming': 26 | nn.init.kaiming_normal_(w) 27 | else: 28 | nn.init.normal_(w) 29 | if 'bias' in name: 30 | nn.init.constant_(w, 0) 31 | -------------------------------------------------------------------------------- /pytextclassifier/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: Tokenization 5 | """ 6 | import re 7 | 8 | import jieba 9 | 10 | 11 | def tokenize_words(text): 12 | """Word segmentation""" 13 | output = [] 14 | sentences = split_2_short_text(text, include_symbol=True) 15 | for sentence, idx in sentences: 16 | if is_any_chinese_string(sentence): 17 | output.extend(jieba.lcut(sentence)) 18 | else: 19 | output.extend(whitespace_tokenize(sentence)) 20 | return output 21 | 22 | 23 | class Tokenizer(object): 24 | """Full tokenization for mixed Chinese and English text.""" 25 | 26 | def __init__(self, lower=True): 27 | self.lower = lower 28 | 29 | def tokenize(self, text): 30 | """Tokenizes a piece of text.""" 31 | res = [] 32 | if len(text) == 0: 33 | return res 34 | 35 | if self.lower: 36 | text = text.lower() 37 | # handles both multilingual and Chinese text 38 | res = tokenize_words(text) 39 | return res 40 | 41 | 42 | def split_2_short_text(text, include_symbol=True): 43 | """ 44 | Split a long text into short segments 45 | :param text: str 46 | :param include_symbol: bool 47 | :return: list of (sentence, start_idx) tuples 48 | """ 49 | re_han = re.compile("([\u4E00-\u9Fa5a-zA-Z0-9+#&]+)", re.U) 50 | result = [] 51 | blocks = re_han.split(text) 52 | start_idx = 0 53 | for blk in blocks: 54 | if not blk: 55 | continue 56 | if include_symbol: 57 | result.append((blk, start_idx)) 58 | else: 59 | if re_han.match(blk): 60 | result.append((blk, start_idx)) 61 | start_idx += len(blk) 62 | return result 63 | 64 | 65 | def is_chinese(uchar): 66 | """Check whether a unicode character is a Chinese character""" 67 | return '\u4e00' <= uchar <= '\u9fa5' 68 | 69 | 70 | def is_all_chinese_string(string): 71 | """Check whether the string consists entirely of Chinese characters""" 72 | return all(is_chinese(c) for c in string) 73 | 74 | 75 | def is_any_chinese_string(string): 76 | """Check whether the string contains any Chinese character""" 77 | return any(is_chinese(c) for c in string) 78 | 79 | 80 | def whitespace_tokenize(text): 81 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 82 | text = text.strip() 83 | if not text: 84 | return [] 85 | tokens = text.split() 86 | return tokens 87 | 88 | 89 | if __name__ == '__main__': 90 | a = '晚上一个人好孤单,想:找附近的人陪陪我.' 91 | b = "unlabeled example, aug_copy_num is the index of the generated augmented. you don't know." 
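# Illustrative note (expected behaviour inferred from tokenize_words above, not captured output):
# the Chinese span in `a` should be segmented with jieba, while the English span in `b` falls back
# to whitespace splitting, so tokenize(a + b) returns one flat word list covering the concatenated text.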
92 | 93 | t = Tokenizer() 94 | word_list_a = t.tokenize(a + b) 95 | print('VocabTokenizer-word:', word_list_a) 96 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | jieba 3 | scikit-learn 4 | pandas 5 | numpy 6 | transformers -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | python_functions=test_ 3 | 4 | codestyle_max_line_length = 119 5 | 6 | log_cli = true 7 | log_cli_level = WARNING 8 | 9 | [metadata] 10 | description-file = README.md 11 | license_file = LICENSE 12 | 13 | [pycodestyle] 14 | max-line-length = 119 15 | 16 | [flake8] 17 | max-line-length = 119 18 | ignore = E203 , W503, F401 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | from setuptools import setup, find_packages 7 | 8 | __version__ = '1.4.0' 9 | 10 | with open('README.md', 'r', encoding='utf-8') as f: 11 | readme = f.read() 12 | 13 | setup( 14 | name='pytextclassifier', 15 | version=__version__, 16 | description='Text Classifier, Text Classification', 17 | long_description=readme, 18 | long_description_content_type='text/markdown', 19 | author='XuMing', 20 | author_email='xuming624@qq.com', 21 | url='https://github.com/shibing624/pytextclassifier', 22 | license='Apache 2.0', 23 | classifiers=[ 24 | 'Intended Audience :: Science/Research', 25 | 'Operating System :: OS Independent', 26 | 'License :: OSI Approved :: Apache Software License', 27 | 'Programming Language :: Python', 28 | 'Programming Language :: Python :: 2.7', 29 | 'Programming Language :: Python :: 3', 30 | 'Topic :: Text Processing :: Linguistic', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | ], 33 | keywords='pytextclassifier,textclassifier,classifier,textclassification', 34 | install_requires=[ 35 | "loguru", 36 | "jieba", 37 | "scikit-learn", 38 | "pandas", 39 | "numpy", 40 | "transformers", 41 | ], 42 | packages=find_packages(exclude=['tests']), 43 | package_dir={'pytextclassifier': 'pytextclassifier'}, 44 | package_data={ 45 | 'pytextclassifier': ['*.*', '*.txt', '../examples/thucnews_train_1w.txt'], 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /tests/test_bert_onnx_bs_qps.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | 7 | import os 8 | import shutil 9 | import sys 10 | import time 11 | import unittest 12 | 13 | import numpy as np 14 | import torch 15 | from loguru import logger 16 | 17 | sys.path.append('..') 18 | from pytextclassifier import BertClassifier 19 | 20 | 21 | class ModelSpeedTestCase(unittest.TestCase): 22 | 23 | def test_classifier_diff_batch_size(self): 24 | # Helper function to calculate QPS and average per-request latency 25 | def calculate_metrics(times): 26 | completion_times = np.array(times) 27 | total_requests = 300 28 | 29 | # Average queries per second (QPS): total requests divided by total elapsed time 30 | average_qps = total_requests / np.sum(completion_times) 31 | latency = np.sum(completion_times) / total_requests 32 | 33 | # Return all computed metrics 34 | 
return { 35 | 'total_requests': total_requests, # total number of requests 36 | 'latency': latency, # average completion time per request 37 | 'average_qps': average_qps, # average queries per second 38 | } 39 | 40 | # Train the model once 41 | def train_model(output_dir='models/bert-chinese-v1'): 42 | m = BertClassifier(output_dir=output_dir, num_classes=2, 43 | model_type='bert', model_name='bert-base-chinese', num_epochs=1) 44 | data = [ 45 | ('education', '名师指导托福语法技巧:名词的复数形式'), 46 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 47 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 48 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 49 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 50 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 51 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 52 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 53 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 54 | ('sports', '米兰客场8战不败国米10年连胜1'), 55 | ('sports', '米兰客场8战不败国米10年连胜2'), 56 | ('sports', '米兰客场8战不败国米10年连胜3'), 57 | ('sports', '米兰客场8战不败国米10年连胜4'), 58 | ('sports', '米兰客场8战不败国米10年连胜5'), 59 | ] 60 | m.train(data * 10) 61 | m.load_model() 62 | return m 63 | 64 | # Evaluate performance for a given batch size 65 | def evaluate_performance(m, batch_size): 66 | samples = ['名师指导托福语法技巧', 67 | '米兰客场8战不败', 68 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100 69 | 70 | batch_times = [] 71 | 72 | for i in range(0, len(samples), batch_size): 73 | batch_samples = samples[i:i + batch_size] 74 | start_time = time.time() 75 | m.predict(batch_samples) 76 | end_time = time.time() 77 | 78 | batch_times.append(end_time - start_time) 79 | 80 | metrics = calculate_metrics(batch_times) 81 | return metrics 82 | 83 | # Convert the model to ONNX format 84 | def convert_model_to_onnx(m, model_dir='models/bert-chinese-v1'): 85 | save_onnx_dir = os.path.join(model_dir, 'onnx') 86 | if os.path.exists(save_onnx_dir): 87 | shutil.rmtree(save_onnx_dir) 88 | m.model.convert_to_onnx(save_onnx_dir) 89 | shutil.copy(m.label_vocab_path, save_onnx_dir) 90 | return save_onnx_dir 91 | 92 | # Main function 93 | batch_sizes = [1, 8, 16, 32, 64, 128] # Modify these values as appropriate 94 | model_dir = 'models/bert-chinese-v1' 95 | # Train the model once 96 | model = train_model(model_dir) 97 | # Convert to ONNX 98 | onnx_model_path = convert_model_to_onnx(model) 99 | del model 100 | torch.cuda.empty_cache() 101 | 102 | # Evaluate Standard BERT model performance 103 | for batch_size in batch_sizes: 104 | model = BertClassifier(output_dir=model_dir, num_classes=2, model_type='bert', 105 | model_name=model_dir, 106 | args={"eval_batch_size": batch_size, "onnx": False}) 107 | model.load_model() 108 | metrics = evaluate_performance(model, batch_size) 109 | logger.info( 110 | f'Standard BERT model - Batch size: {batch_size}, total_requests: {metrics["total_requests"]}, ' 111 | f'Average QPS: {metrics["average_qps"]:.2f}, Average Latency: {metrics["latency"]:.4f}') 112 | del model 113 | torch.cuda.empty_cache() 114 | 115 | # Load and evaluate ONNX model performance 116 | for batch_size in batch_sizes: 117 | onnx_model = BertClassifier(output_dir=onnx_model_path, num_classes=2, model_type='bert', 118 | model_name=onnx_model_path, 119 | args={"eval_batch_size": batch_size, "onnx": True}) 120 | onnx_model.load_model() 121 | metrics = evaluate_performance(onnx_model, batch_size) 122 | logger.info( 123 | f'ONNX model - Batch size: {batch_size}, total_requests: {metrics["total_requests"]}, ' 124 | f'Average QPS: {metrics["average_qps"]:.2f}, Average Latency: {metrics["latency"]:.4f}') 125 | del onnx_model 126 | torch.cuda.empty_cache() 127 | 128 | # Clean up 129 | shutil.rmtree('models') 130 | 131 | 132 | if 
__name__ == '__main__': 133 | unittest.main() 134 | -------------------------------------------------------------------------------- /tests/test_bert_onnx_speed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os.path 7 | import shutil 8 | import sys 9 | import time 10 | import unittest 11 | 12 | import torch 13 | 14 | sys.path.append('..') 15 | from pytextclassifier import BertClassifier 16 | 17 | 18 | class ModelSpeedTestCase(unittest.TestCase): 19 | def test_classifier(self): 20 | m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2, 21 | model_type='bert', model_name='bert-base-chinese', num_epochs=1) 22 | data = [ 23 | ('education', '名师指导托福语法技巧:名词的复数形式'), 24 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'), 25 | ('education', '公务员考虑越来越吃香,这是怎么回事?'), 26 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'), 27 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'), 28 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'), 29 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'), 30 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), 31 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), 32 | ('sports', '米兰客场8战不败国米10年连胜1'), 33 | ('sports', '米兰客场8战不败国米10年连胜2'), 34 | ('sports', '米兰客场8战不败国米10年连胜3'), 35 | ('sports', '米兰客场8战不败国米10年连胜4'), 36 | ('sports', '米兰客场8战不败国米10年连胜5'), 37 | ] 38 | m.train(data * 10) 39 | m.load_model() 40 | 41 | samples = ['名师指导托福语法技巧', 42 | '米兰客场8战不败', 43 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100 44 | 45 | start_time = time.time() 46 | predict_label_bert, predict_proba_bert = m.predict(samples) 47 | print(f'predict_label_bert size: {len(predict_label_bert)}') 48 | self.assertEqual(len(predict_label_bert), 300) 49 | end_time = time.time() 50 | elapsed_time_bert = end_time - start_time 51 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds') 52 | 53 | # convert to onnx, and load onnx model to predict, speed up 10x 54 | save_onnx_dir = 'models/bert-chinese-v1/onnx' 55 | m.model.convert_to_onnx(save_onnx_dir) 56 | # copy label_vocab.json to save_onnx_dir 57 | if os.path.exists(m.label_vocab_path): 58 | shutil.copy(m.label_vocab_path, save_onnx_dir) 59 | 60 | # Manually delete the model and clear CUDA cache 61 | del m 62 | torch.cuda.empty_cache() 63 | 64 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir, 65 | args={"onnx": True}) 66 | m.load_model() 67 | start_time = time.time() 68 | predict_label_bert, predict_proba_bert = m.predict(samples) 69 | print(f'predict_label_bert size: {len(predict_label_bert)}') 70 | end_time = time.time() 71 | elapsed_time_onnx = end_time - start_time 72 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds') 73 | 74 | self.assertEqual(len(predict_label_bert), 300) 75 | shutil.rmtree('models') 76 | 77 | 78 | if __name__ == '__main__': 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /tests/test_fasttext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import os 7 | import unittest 8 | 9 | import sys 10 | 11 | sys.path.append('..') 12 | from pytextclassifier import FastTextClassifier 13 | import torch 14 | 15 | 16 | class SaveModelTestCase(unittest.TestCase): 17 | def test_classifier(self): 18 | m = FastTextClassifier(output_dir='models/fasttext') 19 | data = [ 20 | ('education', 'Student debt to cost 
Britain billions within decades'), 21 | ('education', 'Chinese education for TV experiment'), 22 | ('sports', 'Middle East and Asia boost investment in top level sports'), 23 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 24 | ] 25 | m.train(data) 26 | m.load_model() 27 | samples = ['Abbott government spends $8 million on higher education media blitz', 28 | 'Middle East and Asia boost investment in top level sports'] 29 | r, p = m.predict(samples) 30 | print(r, p) 31 | print('-' * 20) 32 | torch.save(m.model, 'models/model.pkl') 33 | model = torch.load('models/model.pkl') 34 | m.model = model 35 | r1, p1 = m.predict(samples) 36 | print(r1, p1) 37 | self.assertEqual(r, r1) 38 | self.assertEqual(p, p1) 39 | 40 | os.remove('models/model.pkl') 41 | import shutil 42 | shutil.rmtree('models') 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/test_lr_classification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | 7 | import unittest 8 | 9 | import sys 10 | 11 | sys.path.append('..') 12 | from pytextclassifier import ClassicClassifier 13 | 14 | 15 | class BaseTestCase(unittest.TestCase): 16 | def test_classifier(self): 17 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr') 18 | data = [ 19 | ('education', 'Student debt to cost Britain billions within decades'), 20 | ('education', 'Chinese education for TV experiment'), 21 | ('sports', 'Middle East and Asia boost investment in top level sports'), 22 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 23 | ] 24 | m.train(data) 25 | m.load_model() 26 | r, _ = m.predict(['Abbott government spends $8 million on higher education media blitz', 27 | 'Middle East and Asia boost investment in top level sports']) 28 | print(r) 29 | self.assertEqual(r[0], 'education') 30 | self.assertEqual(r[1], 'sports') 31 | test_data = [ 32 | ('education', 'Abbott government spends $8 million on higher education media blitz'), 33 | ('sports', 'Middle East and Asia boost investment in top level sports'), 34 | ] 35 | acc_score = m.evaluate_model(test_data) 36 | print(acc_score) # 1.0 37 | self.assertEqual(acc_score, 1.0) 38 | import shutil 39 | shutil.rmtree('models') 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /tests/test_lr_vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | 7 | import unittest 8 | 9 | import sys 10 | 11 | sys.path.append('..') 12 | from pytextclassifier import ClassicClassifier 13 | 14 | 15 | class VecTestCase(unittest.TestCase): 16 | def setUp(self): 17 | m = ClassicClassifier('models/lr') 18 | data = [ 19 | ('education', 'Student debt to cost Britain billions within decades'), 20 | ('education', 'Chinese education for TV experiment'), 21 | ('sports', 'Middle East and Asia boost investment in top level sports'), 22 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar') 23 | ] 24 | m.train(data) 25 | print('model trained:', m) 26 | 27 | def test_classifier(self): 28 | new_m = ClassicClassifier('models/lr') 29 | new_m.load_model() 30 | r, _ = new_m.predict(['Abbott 
government spends $8 million on higher education media blitz', 31 | 'Middle East and Asia boost investment in top level sports']) 32 | print(r) 33 | self.assertTrue(r[0] == 'education') 34 | 35 | def test_vec(self): 36 | new_m = ClassicClassifier('models/lr') 37 | new_m.load_model() 38 | print(new_m.feature.get_feature_names()) 39 | print('feature name size:', len(new_m.feature.get_feature_names())) 40 | self.assertTrue(len(new_m.feature.get_feature_names()) > 0) 41 | 42 | def test_stopwords(self): 43 | new_m = ClassicClassifier('models/lr') 44 | new_m.load_model() 45 | stopwords = new_m.stopwords 46 | print(len(stopwords)) 47 | self.assertTrue(len(stopwords) > 0) 48 | 49 | @classmethod 50 | def tearDownClass(cls): 51 | import shutil 52 | 53 | shutil.rmtree('models') 54 | print('remove dir: models') 55 | 56 | 57 | if __name__ == '__main__': 58 | unittest.main() 59 | --------------------------------------------------------------------------------