├── .gitignore
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── _config.yml
├── docs
│   ├── Text classification.png
│   ├── cluster_train_seg_samples.png
│   ├── img.png
│   ├── logo.png
│   ├── logo.svg
│   └── wechat.jpeg
├── examples
│   ├── albert_classification_zh_demo.py
│   ├── baidu_extract_2020_train.csv
│   ├── bert_classification_en_demo.py
│   ├── bert_classification_tnews_demo.py
│   ├── bert_classification_zh_demo.py
│   ├── bert_hierarchical_classification_zh_demo.py
│   ├── bert_multilabel_classification_en_demo.py
│   ├── bert_multilabel_classification_zh_demo.py
│   ├── cluster_demo.py
│   ├── fasttext_classification_demo.py
│   ├── lr_classification_demo.py
│   ├── lr_en_classification_demo.py
│   ├── multilabel_jd_comments.csv
│   ├── my_vectorizer_demo.py
│   ├── onnx_predict_demo.py
│   ├── onnx_xlnet_predict_demo.py
│   ├── random_forest_classification_demo.py
│   ├── textcnn_classification_demo.py
│   ├── textrnn_classification_demo.py
│   ├── thucnews_train_10w.txt
│   ├── thucnews_train_1w.txt
│   └── visual_feature_importance.ipynb
├── pytextclassifier
│   ├── __init__.py
│   ├── base_classifier.py
│   ├── bert_classfication_utils.py
│   ├── bert_classification_model.py
│   ├── bert_classifier.py
│   ├── bert_multi_label_classification_model.py
│   ├── classic_classifier.py
│   ├── data_helper.py
│   ├── fasttext_classifier.py
│   ├── stopwords.txt
│   ├── textcluster.py
│   ├── textcnn_classifier.py
│   ├── textrnn_classifier.py
│   ├── time_util.py
│   └── tokenizer.py
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── test_bert_onnx_bs_qps.py
    ├── test_bert_onnx_speed.py
    ├── test_fasttext.py
    ├── test_lr_classification.py
    └── test_lr_vec.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 | *.pyc
91 | .idea
92 | .idea/
93 | *.iml
94 | .vscode
95 | *.bin
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Xu"
5 | given-names: "Ming"
6 | title: "Pytextclassifier: Text classifier toolkit for NLP"
7 | url: "https://github.com/shibing624/pytextclassifier"
8 | date-released: 2021-10-29
9 | version: 1.1.4
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | We are happy to accept your contributions to make `pytextclassifier` better and more awesome! To avoid unnecessary work on either
4 | side, please stick to the following process:
5 |
6 | 1. Check if there is already [an issue](https://github.com/shibing624/pytextclassifier/issues) for your concern.
7 | 2. If there is not, open a new one to start a discussion. We hate to close finished PRs!
8 | 3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the
9 | commit guidelines below.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
6 |
7 | -----------------
8 |
9 | # PyTextClassifier: Python Text Classifier
10 | [](https://badge.fury.io/py/pytextclassifier)
11 | [](https://pepy.tech/project/pytextclassifier)
12 | [](CONTRIBUTING.md)
13 | [](LICENSE)
14 | [](requirements.txt)
15 | [](https://github.com/shibing624/pytextclassifier/issues)
16 | [](#Contact)
17 |
18 |
19 | ## Introduction
20 | PyTextClassifier: Python Text Classifier. It can be applied to fields such as sentiment polarity analysis and text risk classification,
21 | and it supports multiple classification and clustering algorithms.
22 |
23 | **pytextclassifier** is an open-source Python toolkit for text classification. Its goal is to implement
24 | text analysis algorithms that are ready for use in production environments.
25 |
26 | A text classifier toolkit that provides multiple text classification and clustering algorithms, supports sentence- and document-level classification tasks, and covers binary, multi-class, multi-label, hierarchical classification and Kmeans clustering, ready to use out of the box. Developed in Python 3.
27 |
28 | **Guide**
29 |
30 | - [Feature](#Feature)
31 | - [Install](#install)
32 | - [Usage](#usage)
33 | - [Dataset](#Dataset)
34 | - [Contact](#Contact)
35 | - [Citation](#Citation)
36 | - [Reference](#reference)
37 |
38 | ## Feature
39 |
40 | **pytextclassifier** features clear algorithm implementations,
41 | high performance, and customizable corpora.
42 |
43 | Functions:
44 | ### Classifier
45 | - [x] LogisticRegression
46 | - [x] Random Forest
47 | - [x] Decision Tree
48 | - [x] K-Nearest Neighbours
49 | - [x] Naive Bayes
50 | - [x] XGBoost
51 | - [x] Support Vector Machine (SVM)
52 | - [x] TextCNN
53 | - [x] TextRNN
54 | - [x] FastText
55 | - [x] BERT
56 |
57 | ### Cluster
58 | - [x] MiniBatchKMeans
59 |
60 | While providing rich functionality, **pytextclassifier**'s internal modules remain loosely coupled, models are loaded lazily, dictionaries are exposed, and the toolkit stays easy to use.
61 |
62 | ## Install
63 |
64 | - Requirements and Installation
65 |
66 | ```
67 | pip3 install torch # conda install pytorch
68 | pip3 install pytextclassifier
69 | ```
70 |
71 | or
72 |
73 | ```
74 | git clone https://github.com/shibing624/pytextclassifier.git
75 | cd pytextclassifier
76 | python3 setup.py install
77 | ```
78 |
79 |
80 | ## Usage
81 | ### Text Classifier
82 |
83 | ### English Text Classifier
84 |
85 | Includes model training, saving, prediction and evaluation, for example [examples/lr_en_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_en_classification_demo.py):
86 |
87 | ```python
88 | import sys
89 |
90 | sys.path.append('..')
91 | from pytextclassifier import ClassicClassifier
92 |
93 | if __name__ == '__main__':
94 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
95 | # ClassicClassifier support model_name:lr, random_forest, decision_tree, knn, bayes, svm, xgboost
96 | print(m)
97 | data = [
98 | ('education', 'Student debt to cost Britain billions within decades'),
99 | ('education', 'Chinese education for TV experiment'),
100 | ('sports', 'Middle East and Asia boost investment in top level sports'),
101 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
102 | ]
103 | # train and save best model
104 | m.train(data)
105 | # load best model from model_dir
106 | m.load_model()
107 | predict_label, predict_proba = m.predict([
108 | 'Abbott government spends $8 million on higher education media blitz'])
109 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
110 |
111 | test_data = [
112 | ('education', 'Abbott government spends $8 million on higher education media blitz'),
113 | ('sports', 'Middle East and Asia boost investment in top level sports'),
114 | ]
115 | acc_score = m.evaluate_model(test_data)
116 | print(f'acc_score: {acc_score}')
117 | ```
118 |
119 | output:
120 |
121 | ```
122 | ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
123 | predict_label: ['education'], predict_proba: [0.5378236358492112]
124 | acc_score: 1.0
125 | ```
126 |
127 | ### Chinese Text Classifier(中文文本分类)
128 |
129 | Text classification compatible with Chinese and English corpora.
130 |
131 | example [examples/lr_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_classification_demo.py)
132 |
133 | ```python
134 | import sys
135 |
136 | sys.path.append('..')
137 | from pytextclassifier import ClassicClassifier
138 |
139 | if __name__ == '__main__':
140 | m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
141 |     # Classic classification methods; supported models: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
142 | data = [
143 | ('education', '名师指导托福语法技巧:名词的复数形式'),
144 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
145 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
146 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
147 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
148 | ('sports', '米兰客场8战不败国米10年连胜'),
149 | ]
150 | m.train(data)
151 | print(m)
152 | # load best model from model_dir
153 | m.load_model()
154 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
155 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
156 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
157 |
158 | test_data = [
159 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
160 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
161 | ]
162 | acc_score = m.evaluate_model(test_data)
163 | print(f'acc_score: {acc_score}') # 1.0
164 |
165 | #### train model with 1w data
166 | print('-' * 42)
167 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
168 | data_file = 'thucnews_train_1w.txt'
169 | m.train(data_file)
170 | m.load_model()
171 | predict_label, predict_proba = m.predict(
172 | ['顺义北京苏活88平米起精装房在售',
173 | '美EB-5项目“15日快速移民”将推迟'])
174 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
175 | ```
176 |
177 | output:
178 |
179 | ```
180 | ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
181 | predict_label: ['education' 'sports'], predict_proba: [0.5, 0.598941806741534]
182 | acc_score: 1.0
183 | ------------------------------------------
184 | predict_label: ['realty' 'education'], predict_proba: [0.7302956923617372, 0.2565005445322923]
185 | ```
186 |
187 | ### Visual Feature Importance
188 |
189 | Show the model's feature weights and the word weights for a prediction, for example [examples/visual_feature_importance.ipynb](https://github.com/shibing624/pytextclassifier/blob/master/examples/visual_feature_importance.ipynb)
190 |
191 | ```python
192 | import sys
193 |
194 | sys.path.append('..')
195 | from pytextclassifier import ClassicClassifier
196 | import jieba
197 |
198 | tc = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
199 | data = [
200 | ('education', '名师指导托福语法技巧:名词的复数形式'),
201 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
202 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
203 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
204 | ('sports', '米兰客场8战不败国米10年连胜')
205 | ]
206 | tc.train(data)
207 | import eli5
208 |
209 | infer_data = ['高考指导托福语法技巧国际认可',
210 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']
211 | eli5.show_weights(tc.model, vec=tc.feature)
212 | seg_infer_data = [' '.join(jieba.lcut(i)) for i in infer_data]
213 | eli5.show_prediction(tc.model, seg_infer_data[0], vec=tc.feature,
214 | target_names=['education', 'sports'])
215 | ```
216 |
217 | output:
218 |
219 | 
220 |
221 | ### Deep Classification Models
222 |
223 | This project supports the following deep classification models: FastText, TextCNN, TextRNN and BERT. `import` the corresponding class to call a model:
224 | ```python
225 | from pytextclassifier import FastTextClassifier, TextCNNClassifier, TextRNNClassifier, BertClassifier
226 | ```
227 |
228 | The FastText model is used as the example below; the other models are used in a similar way.
229 |
230 | ### FastText Model
231 |
232 | Example of training and predicting with the `FastText` model: [examples/fasttext_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/fasttext_classification_demo.py)
233 |
234 | ```python
235 | import sys
236 |
237 | sys.path.append('..')
238 | from pytextclassifier import FastTextClassifier, load_data
239 |
240 | if __name__ == '__main__':
241 | m = FastTextClassifier(output_dir='models/fasttext-toy')
242 | data = [
243 | ('education', '名师指导托福语法技巧:名词的复数形式'),
244 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
245 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
246 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
247 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
248 | ('sports', '米兰客场8战不败保持连胜'),
249 | ]
250 | m.train(data, num_epochs=3)
251 | print(m)
252 | # load trained best model
253 | m.load_model()
254 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
255 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
256 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
257 | test_data = [
258 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
259 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
260 | ]
261 | acc_score = m.evaluate_model(test_data)
262 | print(f'acc_score: {acc_score}') # 1.0
263 |
264 | #### train model with 1w data
265 | print('-' * 42)
266 | data_file = 'thucnews_train_1w.txt'
267 | m = FastTextClassifier(output_dir='models/fasttext')
268 | m.train(data_file, names=('labels', 'text'), num_epochs=3)
269 | # load best trained model from model_dir
270 | m.load_model()
271 | predict_label, predict_proba = m.predict(
272 | ['顺义北京苏活88平米起精装房在售',
273 | '美EB-5项目“15日快速移民”将推迟']
274 | )
275 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
276 | x, y, df = load_data(data_file)
277 | test_data = df[:100]
278 | acc_score = m.evaluate_model(test_data)
279 | print(f'acc_score: {acc_score}')
280 | ```
281 |
282 | ### BERT-based Models
283 |
284 | #### Multi-class Model
285 | Training and predicting with the `BERT` multi-class model, example [examples/bert_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_classification_zh_demo.py)
286 |
287 | ```python
288 | import sys
289 |
290 | sys.path.append('..')
291 | from pytextclassifier import BertClassifier
292 |
293 | if __name__ == '__main__':
294 | m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2,
295 | model_type='bert', model_name='bert-base-chinese', num_epochs=2)
296 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
297 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
298 | data = [
299 | ('education', '名师指导托福语法技巧:名词的复数形式'),
300 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
301 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
302 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
303 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
304 | ('sports', '米兰客场8战不败国米10年连胜'),
305 | ]
306 | m.train(data)
307 | print(m)
308 | # load trained best model from model_dir
309 | m.load_model()
310 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
311 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
312 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
313 |
314 | test_data = [
315 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
316 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
317 | ]
318 | acc_score = m.evaluate_model(test_data)
319 | print(f'acc_score: {acc_score}')
320 |
321 | # train model with 1w data file and 10 classes
322 | print('-' * 42)
323 | m = BertClassifier(output_dir='models/bert-chinese', num_classes=10,
324 | model_type='bert', model_name='bert-base-chinese', num_epochs=2,
325 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, })
326 | data_file = 'thucnews_train_1w.txt'
327 |     # If the training data exceeds one million rows, lazy_loading mode is recommended to reduce memory usage
328 | m.train(data_file, test_size=0, names=('labels', 'text'))
329 | m.load_model()
330 | predict_label, predict_proba = m.predict(
331 | ['顺义北京苏活88平米起精装房在售',
332 | '美EB-5项目“15日快速移民”将推迟',
333 | '恒生AH溢指收平 A股对H股折价1.95%'])
334 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
335 | ```
336 | Note: if the training data exceeds one million rows, use lazy_loading mode to reduce memory usage.
337 |
338 | #### Multi-label Model
339 | Classification tasks can be divided into multi-class and multi-label classification. In multi-class classification the labels are mutually exclusive, whereas in multi-label classification they are not.
340 |
341 | Intuitively, in multi-label classification one sample can carry several labels at the same time:
342 | a song can be tagged both pop and upbeat, and a movie can be tagged action, comedy and funny all at once.
343 |
344 | Training and predicting with the `BERT` multi-label model, example [examples/bert_multilabel_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_multilabel_classification_zh_demo.py)
345 |
346 | ```python
347 | import sys
348 | import pandas as pd
349 |
350 | sys.path.append('..')
351 | from pytextclassifier import BertClassifier
352 |
353 |
354 | def load_jd_data(file_path):
355 | """
356 | Load jd data from file.
357 | @param file_path:
358 | format: content,其他,互联互通,产品功耗,滑轮提手,声音,APP操控性,呼吸灯,外观,底座,制热范围,遥控器电池,味道,制热效果,衣物烘干,体积大小
359 | @return:
360 | """
361 | data = []
362 | with open(file_path, 'r', encoding='utf-8') as f:
363 | for line in f:
364 | line = line.strip()
365 | if line.startswith('#'):
366 | continue
367 | if not line:
368 | continue
369 | terms = line.split(',')
370 | if len(terms) != 16:
371 | continue
372 | val = [int(i) for i in terms[1:]]
373 | data.append([terms[0], val])
374 | return data
375 |
376 |
377 | if __name__ == '__main__':
378 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
379 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
380 | m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15,
381 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)
382 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.
383 | train_data = [
384 | ["一个小时房间仍然没暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
385 | ["耗电情况:这个没有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
386 | ]
387 | data = load_jd_data('multilabel_jd_comments.csv')
388 | train_data.extend(data)
389 | print(train_data[:5])
390 | train_df = pd.DataFrame(train_data, columns=["text", "labels"])
391 |
392 | print(train_df.head())
393 | m.train(train_df)
394 | print(m)
395 | # Evaluate the model
396 | acc_score = m.evaluate_model(train_df[:20])
397 | print(f'acc_score: {acc_score}')
398 |
399 | # load trained best model from model_dir
400 | m.load_model()
401 | predict_label, predict_proba = m.predict(['一个小时房间仍然没暖和', '耗电情况:这个没有注意'])
402 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
403 | ```
404 |
405 | #### Hierarchical Classification Model
406 | For **hierarchical label classification tasks**, such as industry classification (a first-level industry split into second- and third-level sub-industries) or product classification, the multi-label model can be used by converting the hierarchical labels into a multi-label form,
407 | example [examples/bert_hierarchical_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_hierarchical_classification_zh_demo.py)
408 |
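The conversion itself can be as simple as treating every node on the label path (the parent and each `parent##child` node) as one label and multi-hot encoding it. Below is a minimal sketch of that idea, assuming comma-separated hierarchical label strings like those in the demo data; it is illustrative only, not the project's internal implementation.

```python
# Hypothetical helper: turn hierarchical label strings such as
# "产品行为,产品行为##发布" into the multi-hot lists that
# BertClassifier(multi_label=True) expects.
samples = [
    ["马国明承认与黄心颖分手,女方被娱乐圈封杀,现如今朋友关系", "人生,人生##分手"],
    ["RedmiBook14集显版明天首发:酷睿i5+8G内存3799元", "产品行为,产品行为##发布"],
]

# Collect every label node (parents and parent##child paths) into a vocabulary
label_vocab = sorted({lab for _, labels in samples for lab in labels.split(',')})
label2id = {lab: i for i, lab in enumerate(label_vocab)}


def to_multi_hot(labels):
    """Encode one comma-separated hierarchical label string as a multi-hot list."""
    vec = [0] * len(label2id)
    for lab in labels.split(','):
        vec[label2id[lab]] = 1
    return vec


train_data = [[text, to_multi_hot(labels)] for text, labels in samples]
print(train_data[0][1])  # e.g. [0, 0, 1, 1]
```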
409 |
410 | #### ONNX Inference Acceleration
411 |
412 | A trained model can be exported to ONNX format for faster inference, or for deployment and invocation in other environments such as C++.
413 |
414 | - On GPU, exporting the model to ONNX and running inference with it gives a speedup of more than 10x; install the `onnxruntime-gpu` library: `pip install onnxruntime-gpu`
415 | - On CPU, exporting the model to ONNX and running inference with it gives a speedup of more than 6x; install the `onnxruntime` library: `pip install onnxruntime`
416 |
417 | Example: [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)
418 |
419 | ```python
420 | import os
421 | import shutil
422 | import sys
423 | import time
424 |
425 | import torch
426 |
427 | sys.path.append('..')
428 | from pytextclassifier import BertClassifier
429 |
430 | m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
431 | model_type='bert', model_name='bert-base-chinese', num_epochs=1)
432 | data = [
433 | ('education', '名师指导托福语法技巧:名词的复数形式'),
434 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
435 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
436 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
437 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
438 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
439 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
440 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
441 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
442 | ('sports', '米兰客场8战不败国米10年连胜1'),
443 | ('sports', '米兰客场8战不败国米10年连胜2'),
444 | ('sports', '米兰客场8战不败国米10年连胜3'),
445 | ('sports', '米兰客场8战不败国米10年连胜4'),
446 | ('sports', '米兰客场8战不败国米10年连胜5'),
447 | ]
448 | m.train(data * 10)
449 | m.load_model()
450 |
451 | samples = ['名师指导托福语法技巧',
452 | '米兰客场8战不败',
453 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100
454 |
455 | start_time = time.time()
456 | predict_label_bert, predict_proba_bert = m.predict(samples)
457 | print(f'predict_label_bert size: {len(predict_label_bert)}')
458 | end_time = time.time()
459 | elapsed_time_bert = end_time - start_time
460 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')
461 |
462 | # convert to onnx, and load onnx model to predict, speed up 10x
463 | save_onnx_dir = 'models/bert-chinese-v1/onnx'
464 | m.model.convert_to_onnx(save_onnx_dir)
465 | # copy label_vocab.json to save_onnx_dir
466 | if os.path.exists(m.label_vocab_path):
467 | shutil.copy(m.label_vocab_path, save_onnx_dir)
468 |
469 | # Manually delete the model and clear CUDA cache
470 | del m
471 | torch.cuda.empty_cache()
472 |
473 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
474 | args={"onnx": True})
475 | m.load_model()
476 | start_time = time.time()
477 | predict_label_bert, predict_proba_bert = m.predict(samples)
478 | print(f'predict_label_bert size: {len(predict_label_bert)}')
479 | end_time = time.time()
480 | elapsed_time_onnx = end_time - start_time
481 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
482 | ```
483 |
484 | #### Inference Latency Benchmark
485 | Benchmark script: [tests/test_bert_onnx_bs_qps.py](https://github.com/shibing624/pytextclassifier/blob/master/tests/test_bert_onnx_bs_qps.py)
486 |
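The tables below report average QPS and average latency at each batch size. As a rough illustration of how such numbers can be collected with the `m.predict` API from the example above, here is a minimal sketch, under the assumption that QPS means samples processed per second and latency is the average time per sample; the actual measurement lives in `tests/test_bert_onnx_bs_qps.py` and may differ in detail.

```python
import time


def benchmark(predict_fn, samples, batch_size, rounds=3):
    """Illustrative only: return (average QPS, average per-sample latency) for one batch size."""
    totals = []
    for _ in range(rounds):
        start = time.time()
        for i in range(0, len(samples), batch_size):
            predict_fn(samples[i:i + batch_size])
        totals.append(time.time() - start)
    avg_total = sum(totals) / len(totals)
    return len(samples) / avg_total, avg_total / len(samples)

# qps, latency = benchmark(m.predict, samples, batch_size=32)
```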
487 |
488 | ##### GPU (Tesla T4)
489 |
490 | | Model Type | Batch Size | Average QPS | Average Latency (s) |
491 | |--------------------|-------------|-------------|---------------------|
492 | | **Standard BERT** | 1 | 9.67 | 0.1034 |
493 | | | 8 | 34.85 | 0.0287 |
494 | | | 16 | 42.23 | 0.0237 |
495 | | | 32 | 46.79 | 0.0214 |
496 | | | 64 | 48.79 | 0.0205 |
497 | | | 128 | **50.15** | 0.0199 |
498 | | **ONNX Model** | 1 | 121.89 | 0.0082 |
499 | | | 8 | 123.38 | 0.0081 |
500 | | | 16 | 132.26 | 0.0076 |
501 | | | 32 | 128.33 | 0.0078 |
502 | | | 64 | **134.59** | 0.0074 |
503 | | | 128 | 128.94 | 0.0078 |
504 |
505 | ##### CPU (10-core Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz)
506 |
507 | | Model Type | Batch Size | Average QPS | Average Latency (s) |
508 | |--------------------|-------------|-------------|---------------------|
509 | | **Standard BERT** | 1 | 4.87 | 0.2053 |
510 | | | 8 | **9.21** | 0.1086 |
511 | | | 16 | 7.59 | 0.1318 |
512 | | | 32 | 7.48 | 0.1337 |
513 | | | 64 | 7.01 | 0.1426 |
514 | | | 128 | 6.34 | 0.1576 |
515 | | **ONNX Model** | 1 | **65.25** | 0.0153 |
516 | | | 8 | 52.93 | 0.0189 |
517 | | | 16 | 56.99 | 0.0175 |
518 | | | 32 | 55.03 | 0.0182 |
519 | | | 64 | 56.23 | 0.0178 |
520 | | | 128 | 46.22 | 0.0216 |
521 |
522 |
523 | ## Evaluation
524 |
525 | ### Dataset
526 |
527 | 1. THUCNews Chinese text dataset (1.56GB): official [download page](http://thuctc.thunlp.org/). A 100k-sample, 10-class subset of THUCNews (6MB) is included at [examples/thucnews_train_10w.txt](https://github.com/shibing624/pytextclassifier/blob/master/examples/thucnews_train_10w.txt).
528 | 2. TNEWS Toutiao Chinese news (short text) classification dataset: Short Text Classification for News. The dataset (5.1MB) comes from the news section of Toutiao and covers 15 news categories, including travel, education, finance and military. Download: [tnews_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip)
529 |
530 | ### Evaluation Result
531 | Evaluated on the THUCNews 10-class Chinese text dataset (6MB); results on the test set are as follows:
532 |
533 | | Model       | acc        | Note                                           |
534 | |-------------|------------|------------------------------------------------|
535 | | LR          | 0.8803     | Logistic Regression                            |
536 | | TextCNN     | 0.8809     | Kim 2014's classic CNN for text classification |
537 | | TextRNN_Att | 0.9022     | BiLSTM + Attention                             |
538 | | FastText    | 0.9177     | bow + bigram + trigram, surprisingly good      |
539 | | DPCNN       | 0.9125     | Deep Pyramid CNN                               |
540 | | Transformer | 0.8991     | relatively poor results                        |
541 | | BERT-base   | **0.9483** | BERT + fc                                      |
542 | | ERNIE       | 0.9461     | slightly worse than BERT                       |
543 |
544 | Evaluated on the TNEWS Chinese short-text news classification dataset; results on the dev set are as follows:
545 |
546 | | Model     | acc        | Note                                                                                     |
547 | |-----------|------------|------------------------------------------------------------------------------------------|
548 | | BERT-base | **0.5660** | this project's implementation                                                              |
549 | | BERT-base | 0.5609     | CLUE Benchmark Leaderboard result [CLUEbenchmark](https://github.com/CLUEbenchmark/CLUE)   |
550 |
551 | - All results above are classification accuracy.
552 | - The THUCNews results can be reproduced with the model demos under `examples` using the `examples/thucnews_train_10w.txt` data.
553 | - The TNEWS results can be reproduced by downloading the TNEWS dataset and running `examples/bert_classification_tnews_demo.py`.
554 |
555 | ### Command-line Usage
556 |
557 | Command-line scripts are provided for the classification models; file tree:
558 | ```bash
559 | pytextclassifier
560 | ├── bert_classifier.py
561 | ├── fasttext_classifier.py
562 | ├── classic_classifier.py
563 | ├── textcnn_classifier.py
564 | └── textrnn_classifier.py
565 | ```
566 |
567 | Each file corresponds to one model; the models are fully independent, can be run directly, and are easy to modify. Parameters such as `--data_path` can be set via `argparse`.
568 |
569 | Train the fasttext model directly from the terminal:
570 | ```bash
571 | python -m pytextclassifier.fasttext_classifier -h
572 | ```
573 |
574 | ## Text Cluster
575 |
576 |
577 | Text clustering, for example [examples/cluster_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/cluster_demo.py)
578 |
579 | ```python
580 | import sys
581 |
582 | sys.path.append('..')
583 | from pytextclassifier.textcluster import TextCluster
584 |
585 | if __name__ == '__main__':
586 | m = TextCluster(output_dir='models/cluster-toy', n_clusters=2)
587 | print(m)
588 | data = [
589 | 'Student debt to cost Britain billions within decades',
590 | 'Chinese education for TV experiment',
591 | 'Abbott government spends $8 million on higher education',
592 | 'Middle East and Asia boost investment in top level sports',
593 | 'Summit Series look launches HBO Canada sports doc series: Mudhar'
594 | ]
595 | m.train(data)
596 | m.load_model()
597 | r = m.predict(['Abbott government spends $8 million on higher education media blitz',
598 | 'Middle East and Asia boost investment in top level sports'])
599 | print(r)
600 |
601 | ########### load chinese train data from 1w data file
602 | from sklearn.feature_extraction.text import TfidfVectorizer
603 |
604 | tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10)
605 | data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1)
606 | feature, labels = tcluster.train(data[:5000])
607 | tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png')
608 | r = tcluster.predict(data[:30])
609 | print(r)
610 | ```
611 |
612 | output:
613 |
614 | ```
615 | TextCluster instance (MiniBatchKMeans(n_clusters=2, n_init=10), , TfidfVectorizer(ngram_range=(1, 2)))
616 | [1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 9 1 1 8 1 1 9 1]
617 | ```
618 | clustering plot image:
619 |
620 | 
621 |
622 |
623 | ## Contact
624 |
625 | - Issues (suggestions): [](https://github.com/shibing624/pytextclassifier/issues)
626 | - Email me: xuming: xuming624@qq.com
627 | - WeChat me: add my *WeChat ID: xuming624* to join the Python-NLP discussion group, with the note *name-company-NLP*.
628 |
629 |
630 |
631 | ## Citation
632 |
633 | If you use pytextclassifier in your research, please cite it in the following format:
634 |
635 | APA:
636 | ```latex
637 | Xu, M. Pytextclassifier: Text classifier toolkit for NLP (Version 1.2.0) [Computer software]. https://github.com/shibing624/pytextclassifier
638 | ```
639 |
640 | BibTeX:
641 | ```latex
642 | @misc{Pytextclassifier,
643 | title={Pytextclassifier: Text classifier toolkit for NLP},
644 | author={Xu Ming},
645 | year={2022},
646 | howpublished={\url{https://github.com/shibing624/pytextclassifier}},
647 | }
648 | ```
649 |
650 |
651 | ## License
652 |
653 |
654 | Licensed under [The Apache License 2.0](LICENSE), free for commercial use. Please include a link to **pytextclassifier** and the license in your product description.
655 |
656 |
657 | ## Contribute
658 | The project code is still rough; if you have improvements, you are welcome to contribute them back to this project. Before submitting, please note the following two points:
659 |
660 | - Add corresponding unit tests in `tests`
661 | - Run all unit tests with `python setup.py test` and make sure they all pass
662 |
663 | After that, you can submit a PR.
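As a starting point, a new test under `tests` can follow the standard `unittest` pattern. This is a minimal sketch with assumed file and test names, not an existing project test:

```python
# tests/test_my_feature.py -- hypothetical file name, adjust to your change
import sys
import unittest

sys.path.append('..')
from pytextclassifier import ClassicClassifier


class MyFeatureTestCase(unittest.TestCase):
    def test_lr_train_and_predict(self):
        # Train a tiny LR classifier and check that predict returns one label per input
        m = ClassicClassifier(output_dir='models/lr-test', model_name_or_model='lr')
        data = [
            ('education', 'Student debt to cost Britain billions within decades'),
            ('education', 'Chinese education for TV experiment'),
            ('sports', 'Middle East and Asia boost investment in top level sports'),
            ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar'),
        ]
        m.train(data)
        m.load_model()
        labels, probas = m.predict(['Abbott government spends $8 million on higher education'])
        self.assertEqual(len(labels), 1)


if __name__ == '__main__':
    unittest.main()
```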
664 |
665 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/docs/Text classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/Text classification.png
--------------------------------------------------------------------------------
/docs/cluster_train_seg_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/cluster_train_seg_samples.png
--------------------------------------------------------------------------------
/docs/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/img.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/logo.png
--------------------------------------------------------------------------------
/docs/logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/wechat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shibing624/pytextclassifier/381042e53a8af816e0b5d34f3e1888a80c513401/docs/wechat.jpeg
--------------------------------------------------------------------------------
/examples/albert_classification_zh_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import BertClassifier
10 |
11 | if __name__ == '__main__':
12 | m = BertClassifier(output_dir='models/albert-chinese-toy', num_classes=2,
13 | model_type='albert', model_name='uer/albert-base-chinese-cluecorpussmall', num_epochs=2)
14 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
15 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
16 | data = [
17 | ('education', '名师指导托福语法技巧:名词的复数形式'),
18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
20 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
21 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
22 | ('sports', '米兰客场8战不败国米10年连胜'),
23 | ]
24 | m.train(data)
25 | print(m)
26 | # load trained best model from output_dir
27 | m.load_model()
28 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
29 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
30 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
31 |
32 | test_data = [
33 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
34 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
35 | ]
36 | acc_score = m.evaluate_model(test_data)
37 | print(f'acc_score: {acc_score}') # 1.0
38 |
39 | # train model with 1w data file and 10 classes
40 | print('-' * 42)
41 | m = BertClassifier(output_dir='models/albert-chinese', num_classes=10,
42 | model_type='albert', model_name='uer/albert-base-chinese-cluecorpussmall', num_epochs=2,
43 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, })
44 | data_file = 'thucnews_train_1w.txt'
45 |     # If the training data exceeds one million rows, lazy_loading mode is recommended to reduce memory usage
46 | m.train(data_file, test_size=0, names=('labels', 'text'))
47 | m.load_model()
48 | predict_label, predict_proba = m.predict(
49 | ['顺义北京苏活88平米起精装房在售',
50 | '美EB-5项目“15日快速移民”将推迟',
51 | '恒生AH溢指收平 A股对H股折价1.95%'])
52 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
53 |
--------------------------------------------------------------------------------
/examples/bert_classification_en_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import BertClassifier
10 |
11 | if __name__ == '__main__':
12 | m = BertClassifier(output_dir='models/bert-english-toy', num_classes=2,
13 | model_type='bert', model_name='bert-base-uncased', num_epochs=2)
14 | data = [
15 | ('education', 'Student debt to cost Britain billions within decades'),
16 | ('education', 'Chinese education for TV experiment'),
17 | ('sports', 'Middle East and Asia boost investment in top level sports'),
18 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
19 | ]
20 | m.train(data)
21 | print(m)
22 | # load trained best model
23 | m.load_model()
24 | predict_label, predict_proba = m.predict(['Abbott government spends $8 million on higher education media blitz',
25 | 'Middle East and Asia boost investment in top level sports'])
26 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
27 |
28 | test_data = [
29 | ('education', 'Abbott government spends $8 million on higher education media blitz'),
30 | ('sports', 'Middle East and Asia boost investment in top level sports'),
31 | ]
32 | acc_score = m.evaluate_model(test_data)
33 | print(f'acc_score: {acc_score}')
34 |
--------------------------------------------------------------------------------
/examples/bert_classification_tnews_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: Evaluate the model on the ByteDance TNEWS news classification benchmark dataset
5 | """
6 | import sys
7 | import json
8 |
9 | sys.path.append('..')
10 | from pytextclassifier import BertClassifier
11 |
12 | def convert_json_to_csv(path):
13 | lst = []
14 | with open(path, 'r', encoding='utf-8') as f:
15 | for line in f:
16 | d = json.loads(line.strip('\n'))
17 | lst.append((d['label_desc'], d['sentence']))
18 | return lst
19 |
20 |
21 | if __name__ == '__main__':
22 | # train model with TNEWS data file
23 | train_file = './TNEWS/train.json'
24 | dev_file = './TNEWS/dev.json'
25 | train_data = convert_json_to_csv(train_file)[:None]
26 | dev_data = convert_json_to_csv(dev_file)[:None]
27 | print('train_data head:', train_data[:10])
28 |
29 | m = BertClassifier(output_dir='models/bert-tnews', num_classes=15,
30 | model_type='bert', model_name='bert-base-chinese',
31 | batch_size=32, num_epochs=5)
32 | m.train(train_data, test_size=0)
33 | m.load_model()
34 | # {"label": "102", "label_desc": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物", "keywords": "江疏影,美少女,经纪人,甜甜圈"}
35 | # {"label": "110", "label_desc": "news_military", "sentence": "以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击", "keywords": "伊朗,圣城军,叙利亚,以色列国防军,以色列"}
36 | predict_label, predict_proba = m.predict(
37 | ['江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物',
38 | '以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击',
39 | ])
40 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
41 |
42 | score = m.evaluate_model(dev_data)
43 | print(f'score: {score}') # acc: 0.5643
44 |
45 |
--------------------------------------------------------------------------------
/examples/bert_classification_zh_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import BertClassifier
10 |
11 | if __name__ == '__main__':
12 | m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2,
13 | model_type='bert', model_name='bert-base-chinese', num_epochs=2)
14 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
15 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
16 | data = [
17 | ('education', '名师指导托福语法技巧:名词的复数形式'),
18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
20 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
21 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
22 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
23 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
24 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
25 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
26 | ('sports', '米兰客场8战不败国米10年连胜1'),
27 | ('sports', '米兰客场8战不败国米10年连胜2'),
28 | ('sports', '米兰客场8战不败国米10年连胜3'),
29 | ('sports', '米兰客场8战不败国米10年连胜4'),
30 | ('sports', '米兰客场8战不败国米10年连胜5'),
31 | ]
32 | m.train(data)
33 | print(m)
34 | # load trained best model from output_dir
35 | m.load_model()
36 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
37 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
38 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
39 |
40 | test_data = [
41 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
42 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
43 | ]
44 | acc_score = m.evaluate_model(test_data)
45 | print(f'acc_score: {acc_score}') # 1.0
46 |
47 | # train model with 1w data file and 10 classes
48 | print('-' * 42)
49 | m = BertClassifier(output_dir='models/bert-chinese', num_classes=10,
50 | model_type='bert', model_name='bert-base-chinese', num_epochs=2,
51 | args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, })
52 | data_file = 'thucnews_train_1w.txt'
53 |     # If the training data exceeds one million rows, lazy_loading mode is recommended to reduce memory usage
54 | m.train(data_file, test_size=0, names=('labels', 'text'))
55 | m.load_model()
56 | predict_label, predict_proba = m.predict(
57 | ['顺义北京苏活88平米起精装房在售',
58 | '美EB-5项目“15日快速移民”将推迟',
59 | '恒生AH溢指收平 A股对H股折价1.95%'])
60 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
61 |
--------------------------------------------------------------------------------
/examples/bert_hierarchical_classification_zh_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 | import pandas as pd
8 |
9 | sys.path.append('..')
10 | from pytextclassifier import BertClassifier
11 |
12 |
13 | def load_baidu_data(file_path):
14 | """
15 | Load baidu data from file.
16 | @param file_path:
17 | format: content,labels
18 | @return:
19 | """
20 | data = []
21 | with open(file_path, 'r', encoding='utf-8') as f:
22 | for line in f:
23 | line = line.strip()
24 | if line.startswith('#'):
25 | continue
26 | if not line:
27 | continue
28 | terms = line.split('\t')
29 | if len(terms) != 2:
30 | continue
31 | data.append([terms[0], terms[1]])
32 | return data
33 |
34 |
35 | if __name__ == '__main__':
36 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
37 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
38 | m = BertClassifier(output_dir='models/hierarchical-bert-zh-model', num_classes=34,
39 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)
40 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.
41 | train_data = [
42 | ["马国明承认与黄心颖分手,女方被娱乐圈封杀,现如今朋友关系", "人生,人生##分手"],
43 | ["RedmiBook14集显版明天首发:酷睿i5+8G内存3799元", "产品行为,产品行为##发布"],
44 | ]
45 | data = load_baidu_data('baidu_extract_2020_train.csv')
46 | train_data.extend(data)
47 | print(train_data[:5])
48 | train_df = pd.DataFrame(train_data, columns=["text", "labels"])
49 |
50 | print(train_df.head())
51 | m.train(train_df)
52 | print(m)
53 | # Evaluate the model
54 | acc_score = m.evaluate_model(train_df[:20])
55 | print(f'acc_score: {acc_score}')
56 |
57 | # load trained best model from output_dir
58 | m.load_model()
59 | predict_label, predict_proba = m.predict([
60 | '马国明承认与黄心颖分手,女方被娱乐圈封杀', 'RedmiBook14集显版明天首发'])
61 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
62 |
--------------------------------------------------------------------------------
/examples/bert_multilabel_classification_en_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 | import pandas as pd
8 |
9 | sys.path.append('..')
10 | from pytextclassifier import BertClassifier
11 |
12 | if __name__ == '__main__':
13 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
14 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
15 | m = BertClassifier(output_dir='models/multilabel-bert-en-toy', num_classes=6,
16 | model_type='bert', model_name='bert-base-uncased', num_epochs=2, multi_label=True)
17 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.
18 | train_data = [
19 | ["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]],
20 | ["This is another example sentence. ", [0, 1, 1, 0, 0, 0]],
21 | ["This is the third example sentence. ", [1, 1, 1, 0, 0, 0]],
22 | ]
23 | train_df = pd.DataFrame(train_data, columns=["text", "labels"])
24 |
25 | eval_data = [
26 | ["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]],
27 | ["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]],
28 | ]
29 | eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])
30 | print(train_df.head())
31 |     m.train(train_df)
32 | print(m)
33 | # Evaluate the model
34 | acc_score = m.evaluate_model(eval_df)
35 | print(f'acc_score: {acc_score}')
36 |
37 | # load trained best model from output_dir
38 | m.load_model()
39 | predict_label, predict_proba = m.predict(['some new sentence'])
40 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
41 |
--------------------------------------------------------------------------------
/examples/bert_multilabel_classification_zh_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 | import pandas as pd
8 |
9 | sys.path.append('..')
10 | from pytextclassifier import BertClassifier
11 |
12 |
13 | def load_jd_data(file_path):
14 | """
15 | Load jd data from file.
16 | @param file_path:
17 | format: content,其他,互联互通,产品功耗,滑轮提手,声音,APP操控性,呼吸灯,外观,底座,制热范围,遥控器电池,味道,制热效果,衣物烘干,体积大小
18 | @return:
19 | """
20 | data = []
21 | with open(file_path, 'r', encoding='utf-8') as f:
22 | for line in f:
23 | line = line.strip()
24 | if line.startswith('#'):
25 | continue
26 | if not line:
27 | continue
28 | terms = line.split(',')
29 | if len(terms) != 16:
30 | continue
31 | val = [int(i) for i in terms[1:]]
32 | data.append([terms[0], val])
33 | return data
34 |
35 |
36 | if __name__ == '__main__':
37 | # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
38 | # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
39 | m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15,
40 | model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)
41 | # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.
42 | train_data = [
43 | ["一个小时房间仍然没暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
44 | ["耗电情况:这个没有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
45 | ]
46 | data = load_jd_data('multilabel_jd_comments.csv')
47 | train_data.extend(data)
48 | print(train_data[:5])
49 | train_df = pd.DataFrame(train_data, columns=["text", "labels"])
50 |
51 | print(train_df.head())
52 | m.train(train_df)
53 | print(m)
54 | # Evaluate the model
55 | acc_score = m.evaluate_model(train_df[:20])
56 | print(f'acc_score: {acc_score}')
57 |
58 | # load trained best model from output_dir
59 | m.load_model()
60 | predict_label, predict_proba = m.predict(['一个小时房间仍然没暖和', '耗电情况:这个没有注意'])
61 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
62 |
--------------------------------------------------------------------------------
/examples/cluster_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier.textcluster import TextCluster
10 |
11 | if __name__ == '__main__':
12 | m = TextCluster(output_dir='models/cluster-toy', n_clusters=2)
13 | print(m)
14 | data = [
15 | 'Student debt to cost Britain billions within decades',
16 | 'Chinese education for TV experiment',
17 | 'Abbott government spends $8 million on higher education',
18 | 'Middle East and Asia boost investment in top level sports',
19 | 'Summit Series look launches HBO Canada sports doc series: Mudhar'
20 | ]
21 | m.train(data)
22 | m.load_model()
23 | r = m.predict(['Abbott government spends $8 million on higher education media blitz',
24 | 'Middle East and Asia boost investment in top level sports'])
25 | print(r)
26 |
27 | ########### load chinese train data from 1w data file
28 | from sklearn.feature_extraction.text import TfidfVectorizer
29 |
30 | tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10)
31 | data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1)
32 | feature, labels = tcluster.train(data[:5000])
33 | tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png')
34 | r = tcluster.predict(data[:30])
35 | print(r)
36 |
--------------------------------------------------------------------------------
/examples/fasttext_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import FastTextClassifier, load_data
10 |
11 | if __name__ == '__main__':
12 | m = FastTextClassifier(output_dir='models/fasttext-toy', enable_ngram=False)
13 | data = [
14 | ('education', '名师指导托福语法技巧:名词的复数形式'),
15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
19 | ('sports', '米兰客场8战不败保持连胜'),
20 | ]
21 | m.train(data, num_epochs=3)
22 | print(m)
23 | # load trained best model
24 | m.load_model()
25 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
26 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
27 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
28 | test_data = [
29 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
30 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
31 | ]
32 | acc_score = m.evaluate_model(test_data)
33 | print(f'acc_score: {acc_score}') # 1.0
34 |
35 | #### train model with 1w data
36 | print('-' * 42)
37 | data_file = 'thucnews_train_1w.txt'
38 | m = FastTextClassifier(output_dir='models/fasttext')
39 | m.train(data_file, names=('labels', 'text'), num_epochs=3)
40 | # load best trained model from output_dir
41 | m.load_model()
42 | predict_label, predict_proba = m.predict(
43 | ['顺义北京苏活88平米起精装房在售',
44 | '美EB-5项目“15日快速移民”将推迟']
45 | )
46 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
47 | x, y, df = load_data(data_file)
48 | test_data = df[:100]
49 | acc_score = m.evaluate_model(test_data)
50 | print(f'acc_score: {acc_score}')
51 |
--------------------------------------------------------------------------------
/examples/lr_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import ClassicClassifier
10 |
11 | if __name__ == '__main__':
12 | m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
13 | # Classic classification methods; supported models: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
14 | data = [
15 | ('education', '名师指导托福语法技巧:名词的复数形式'),
16 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
17 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
18 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
19 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
20 | ('sports', '米兰客场8战不败国米10年连胜'),
21 | ]
22 | m.train(data)
23 | print(m)
24 | # load best model from output_dir
25 | m.load_model()
26 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
29 |
30 | test_data = [
31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
33 | ]
34 | acc_score = m.evaluate_model(test_data)
35 | print(f'acc_score: {acc_score}') # 1.0
36 |
37 | # train model with 1w data
38 | print('-' * 42)
39 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
40 | data_file = 'thucnews_train_1w.txt'
41 | m.train(data_file)
42 | m.load_model()
43 | predict_label, predict_proba = m.predict(
44 | ['顺义北京苏活88平米起精装房在售',
45 | '美EB-5项目“15日快速移民”将推迟'])
46 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
47 |
--------------------------------------------------------------------------------
/examples/lr_en_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import ClassicClassifier
10 | from loguru import logger
11 |
12 | logger.remove() # Remove default log handler
13 | logger.add(sys.stderr, level="INFO")  # set log level to INFO
14 |
15 | if __name__ == '__main__':
16 | m = ClassicClassifier(output_dir='models/lr-english-toy', model_name_or_model='lr')
17 | # Classic classification methods; supported models: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
18 | print(m)
19 | data = [
20 | ('education', 'Student debt to cost Britain billions within decades'),
21 | ('education', 'Chinese education for TV experiment'),
22 | ('sports', 'Middle East and Asia boost investment in top level sports'),
23 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
24 | ]
25 | # train and save best model
26 | m.train(data)
27 | # load best model from output_dir
28 | m.load_model()
29 | predict_label, predict_proba = m.predict([
30 | 'Abbott government spends $8 million on higher education media blitz'])
31 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
32 |
33 | test_data = [
34 | ('education', 'Abbott government spends $8 million on higher education media blitz'),
35 | ('sports', 'Middle East and Asia boost investment in top level sports'),
36 | ]
37 | acc_score = m.evaluate_model(test_data)
38 | print(f'acc_score: {acc_score}')
39 |
--------------------------------------------------------------------------------
/examples/my_vectorizer_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import ClassicClassifier
10 | from sklearn.feature_extraction.text import CountVectorizer
11 |
12 | if __name__ == '__main__':
13 | vec = CountVectorizer(ngram_range=(1, 3))
14 | m = ClassicClassifier(output_dir='models/lr-vec', model_name_or_model='lr', feature_name_or_feature=vec)
15 |
16 | data = [
17 | ('education', '名师指导托福语法技巧:名词的复数形式'),
18 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
19 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
20 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
21 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
22 | ('sports', '米兰客场8战不败国米10年连胜'),
23 | ]
24 | m.train(data)
25 | m.load_model()
26 | predict_label, predict_label_prob = m.predict(['福建春季公务员考试报名18日截止 2月6日考试'])
27 | print(predict_label, predict_label_prob)
28 | print('classes_: ', m.model.classes_) # the classes ordered as prob
29 |
30 | test_data = [
31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
33 | ]
34 | acc_score = m.evaluate_model(test_data)
35 | print(acc_score) # 1.0
36 |
--------------------------------------------------------------------------------
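As a variant of the demo above (a sketch, not something shipped in the repo), any object exposing `fit_transform`/`transform` can be passed as `feature_name_or_feature`. For example, a character n-gram TfidfVectorizer; the output dir name below is illustrative:

from sklearn.feature_extraction.text import TfidfVectorizer
from pytextclassifier import ClassicClassifier

# char_wb n-grams operate on the whitespace-joined tokens produced by the
# classifier's jieba tokenization step, so no extra preprocessing is needed.
char_vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 4))
m = ClassicClassifier(output_dir='models/lr-char-vec', model_name_or_model='lr',
                      feature_name_or_feature=char_vec)
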
/examples/onnx_predict_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 |
7 | import os
8 | import shutil
9 | import sys
10 | import time
11 |
12 | import torch
13 |
14 | sys.path.append('..')
15 | from pytextclassifier import BertClassifier
16 |
17 | if __name__ == '__main__':
18 | m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
19 | model_type='bert', model_name='bert-base-chinese', num_epochs=1)
20 | data = [
21 | ('education', '名师指导托福语法技巧:名词的复数形式'),
22 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
23 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
24 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
25 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
26 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
27 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
28 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
29 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
30 | ('sports', '米兰客场8战不败国米10年连胜1'),
31 | ('sports', '米兰客场8战不败国米10年连胜2'),
32 | ('sports', '米兰客场8战不败国米10年连胜3'),
33 | ('sports', '米兰客场8战不败国米10年连胜4'),
34 | ('sports', '米兰客场8战不败国米10年连胜5'),
35 | ]
36 | # m.train(data * 10)
37 | m.load_model()
38 |
39 | samples = ['名师指导托福语法技巧',
40 | '米兰客场8战不败',
41 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100
42 |
43 | start_time = time.time()
44 | predict_label_bert, predict_proba_bert = m.predict(samples)
45 | print(f'predict_label_bert size: {len(predict_label_bert)}')
46 | end_time = time.time()
47 | elapsed_time_bert = end_time - start_time
48 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')
49 |
50 | # convert to onnx, and load onnx model to predict, speed up 10x
51 | save_onnx_dir = 'models/bert-chinese-v1/onnx'
52 | m.model.convert_to_onnx(save_onnx_dir)
53 | # copy label_vocab.json to save_onnx_dir
54 | if os.path.exists(m.label_vocab_path):
55 | shutil.copy(m.label_vocab_path, save_onnx_dir)
56 |
57 | # Manually delete the model and clear CUDA cache
58 | del m
59 | torch.cuda.empty_cache()
60 |
61 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
62 | args={"onnx": True})
63 | m.load_model()
64 | start_time = time.time()
65 | predict_label_bert, predict_proba_bert = m.predict(samples)
66 | print(f'predict_label_bert size: {len(predict_label_bert)}')
67 | end_time = time.time()
68 | elapsed_time_onnx = end_time - start_time
69 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
70 |
--------------------------------------------------------------------------------
/examples/onnx_xlnet_predict_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 |
7 | import os
8 | import shutil
9 | import sys
10 | import time
11 |
12 | import torch
13 |
14 | sys.path.append('..')
15 | from pytextclassifier import BertClassifier
16 |
17 | if __name__ == '__main__':
18 | m = BertClassifier(output_dir='models/xlnet-chinese-v1', num_classes=2,
19 | model_type='xlnet', model_name='hfl/chinese-xlnet-base', num_epochs=1)
20 | data = [
21 | ('education', '名师指导托福语法技巧:名词的复数形式'),
22 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
23 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
24 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
25 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
26 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
27 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
28 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
29 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
30 | ('sports', '米兰客场8战不败国米10年连胜1'),
31 | ('sports', '米兰客场8战不败国米10年连胜2'),
32 | ('sports', '米兰客场8战不败国米10年连胜3'),
33 | ('sports', '米兰客场8战不败国米10年连胜4'),
34 | ('sports', '米兰客场8战不败国米10年连胜5'),
35 | ]
36 | m.train(data * 1)
37 | m.load_model()
38 |
39 | samples = ['名师指导托福语法技巧',
40 | '米兰客场8战不败',
41 | '恒生AH溢指收平 A股对H股折价1.95%'] * 10
42 |
43 | start_time = time.time()
44 | predict_label_bert, predict_proba_bert = m.predict(samples)
45 | print(f'predict_label_bert size: {len(predict_label_bert)}')
46 | end_time = time.time()
47 | elapsed_time_bert = end_time - start_time
48 | print(f'Standard xlnet model prediction time: {elapsed_time_bert} seconds')
49 |
50 | # convert to onnx, and load onnx model to predict, speed up 10x
51 | save_onnx_dir = 'models/xlnet-chinese-v1/onnx'
52 | m.model.convert_to_onnx(save_onnx_dir)
53 | # copy label_vocab.json to save_onnx_dir
54 | if os.path.exists(m.label_vocab_path):
55 | shutil.copy(m.label_vocab_path, save_onnx_dir)
56 |
57 | # Manually delete the model and clear CUDA cache
58 | del m
59 | torch.cuda.empty_cache()
60 |
61 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='xlnet', model_name=save_onnx_dir,
62 | args={"onnx": True})
63 | m.load_model()
64 | start_time = time.time()
65 | predict_label_bert, predict_proba_bert = m.predict(samples)
66 | print(f'predict_label_bert size: {len(predict_label_bert)}')
67 | end_time = time.time()
68 | elapsed_time_onnx = end_time - start_time
69 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
70 |
--------------------------------------------------------------------------------
/examples/random_forest_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import ClassicClassifier
10 |
11 | if __name__ == '__main__':
12 | m = ClassicClassifier(output_dir='models/random_forest-toy', model_name_or_model='random_forest')
13 | # Classic classification methods; supported models: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
14 | print(m)
15 | data = [
16 | ('education', 'Student debt to cost Britain billions within decades'),
17 | ('education', 'Chinese education for TV experiment'),
18 | ('sports', 'Middle East and Asia boost investment in top level sports'),
19 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
20 | ]
21 | m.train(data)
22 | m.load_model()
23 | predict_label, predict_proba = m.predict([
24 | 'Abbott government spends $8 million on higher education media blitz'])
25 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
26 |
27 | test_data = [
28 | ('education', 'Abbott government spends $8 million on higher education media blitz'),
29 | ('sports', 'Middle East and Asia boost investment in top level sports'),
30 | ]
31 | acc_score = m.evaluate_model(test_data)
32 | print(f'acc_score: {acc_score}') # 1.0
33 |
--------------------------------------------------------------------------------
/examples/textcnn_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import TextCNNClassifier
10 |
11 | if __name__ == '__main__':
12 | m = TextCNNClassifier(output_dir='models/textcnn-toy')
13 | data = [
14 | ('education', '名师指导托福语法技巧:名词的复数形式'),
15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
19 | ('sports', '米兰客场8战不败国米10年连胜')
20 | ]
21 | # train and save best model
22 | m.train(data, num_epochs=3, evaluate_during_training_steps=1)
23 | print(m)
24 | # load best model from output_dir
25 | m.load_model()
26 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
29 |
30 | test_data = [
31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
33 | ]
34 | acc_score = m.evaluate_model(test_data)
35 | print(f'acc_score: {acc_score}') # 1.0
36 |
--------------------------------------------------------------------------------
/examples/textrnn_classification_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import sys
7 |
8 | sys.path.append('..')
9 | from pytextclassifier import TextRNNClassifier
10 |
11 | if __name__ == '__main__':
12 | m = TextRNNClassifier(output_dir='models/textrnn-toy')
13 | data = [
14 | ('education', '名师指导托福语法技巧:名词的复数形式'),
15 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
16 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
17 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
18 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
19 | ('sports', '米兰客场8战不败国米10年连胜')
20 | ]
21 | # train and save best model
22 | m.train(data, num_epochs=3, evaluate_during_training_steps=1)
23 | print(m)
24 | # load best model from output_dir
25 | m.load_model()
26 | predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
27 | '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
28 | print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
29 |
30 | test_data = [
31 | ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
32 | ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
33 | ]
34 | acc_score = m.evaluate_model(test_data)
35 | print(f'acc_score: {acc_score}')
36 |
--------------------------------------------------------------------------------
/pytextclassifier/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | __version__ = '1.4.0'
7 |
8 | from pytextclassifier.classic_classifier import ClassicClassifier
9 | from pytextclassifier.fasttext_classifier import FastTextClassifier
10 | from pytextclassifier.textcnn_classifier import TextCNNClassifier
11 | from pytextclassifier.textrnn_classifier import TextRNNClassifier
12 | from pytextclassifier.bert_classifier import BertClassifier
13 | from pytextclassifier.base_classifier import load_data
14 | from pytextclassifier.textcluster import TextCluster
15 | from pytextclassifier.bert_classification_model import BertClassificationModel, BertClassificationArgs
16 |
--------------------------------------------------------------------------------
/pytextclassifier/base_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os
7 |
8 | import pandas as pd
9 | from loguru import logger
10 |
11 |
12 | def load_data(data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t',
13 | labels_sep=',', is_train=False):
14 | """
15 | Encoding data_list text
16 | @param data_list_or_path: list of (label, text), eg: [(label, text), (label, text) ...]
17 | @param header: read_csv header
18 | @param names: read_csv names
19 | @param delimiter: read_csv sep
20 | @param labels_sep: multi label split
21 | @param is_train: is train data
22 | @return: X, y, data_df
23 | """
24 | if isinstance(data_list_or_path, list):
25 | data_df = pd.DataFrame(data_list_or_path, columns=names)
26 | elif isinstance(data_list_or_path, str) and os.path.exists(data_list_or_path):
27 | data_df = pd.read_csv(data_list_or_path, header=header, delimiter=delimiter, names=names)
28 | elif isinstance(data_list_or_path, pd.DataFrame):
29 | data_df = data_list_or_path
30 | else:
31 | raise TypeError('should be list or file path, eg: [(label, text), ... ]')
32 | X, y = data_df['text'], data_df['labels']
33 | labels = set()
34 | if y.size:
35 | for label in y.tolist():
36 | if isinstance(label, str):
37 | labels.update(label.split(labels_sep))
38 | elif isinstance(label, list):
39 | labels.update(range(len(label)))
40 | else:
41 | labels.add(label)
42 | num_classes = len(labels)
43 | labels = sorted(list(labels))
44 | logger.debug(f'loaded data list, X size: {len(X)}, y size: {len(y)}')
45 | if is_train:
46 | logger.debug('num_classes: %d, labels: %s' % (num_classes, labels))
47 | assert len(X) == len(y)
48 |
49 | return X, y, data_df
50 |
51 |
52 | class ClassifierABC:
53 | """
54 | Abstract class for classifier
55 | """
56 |
57 | def train(self, data_list_or_path, model_dir: str, **kwargs):
58 | raise NotImplementedError('train method not implemented.')
59 |
60 | def predict(self, sentences: list):
61 | raise NotImplementedError('predict method not implemented.')
62 |
63 | def evaluate_model(self, **kwargs):
64 | raise NotImplementedError('evaluate_model method not implemented.')
65 |
66 | def evaluate(self, **kwargs):
67 | raise NotImplementedError('evaluate method not implemented.')
68 |
69 | def load_model(self):
70 | raise NotImplementedError('load method not implemented.')
71 |
72 | def save_model(self):
73 | raise NotImplementedError('save method not implemented.')
74 |
--------------------------------------------------------------------------------
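A minimal sketch (not part of the repo) of what `load_data` above returns. It accepts a list of (label, text) pairs, a delimiter-separated file path, or a DataFrame with 'labels'/'text' columns; the sample texts are reused from the repo's English demos:

from pytextclassifier import load_data

data = [('education', 'Student debt to cost Britain billions within decades'),
        ('sports', 'Middle East and Asia boost investment in top level sports')]
X, y, df = load_data(data, is_train=True)
print(len(X), len(y))           # 2 2
print(df.columns.tolist())      # ['labels', 'text']
print(y.tolist())               # ['education', 'sports']
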
/pytextclassifier/bert_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: BERT Classifier, support 'bert', 'albert', 'roberta', 'xlnet' model
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import sys
10 |
11 | import numpy as np
12 | import torch
13 | from loguru import logger
14 | from sklearn.model_selection import train_test_split
15 |
16 | sys.path.append('..')
17 | from pytextclassifier.base_classifier import ClassifierABC, load_data
18 | from pytextclassifier.data_helper import set_seed
19 | from pytextclassifier.bert_classification_model import BertClassificationModel, BertClassificationArgs
20 |
21 | pwd_path = os.path.abspath(os.path.dirname(__file__))
22 | device = 'cuda' if torch.cuda.is_available() else (
23 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')
24 | default_use_cuda = torch.cuda.is_available()
25 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
26 |
27 |
28 | class BertClassifier(ClassifierABC):
29 | def __init__(
30 | self,
31 | num_classes,
32 | output_dir="outputs",
33 | model_type='bert',
34 | model_name='bert-base-chinese',
35 | num_epochs=3,
36 | batch_size=8,
37 | max_seq_length=256,
38 | multi_label=False,
39 | labels_sep=',',
40 | use_cuda=None,
41 | args=None,
42 | ):
43 |
44 | """
45 | Init classification model
46 | @param output_dir: output model dir
47 | @param model_type: support 'bert', 'albert', 'roberta', 'xlnet'
48 | @param model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
49 | @param num_classes: number of label classes
50 | @param num_epochs: train epochs
51 | @param batch_size: train batch size
52 | @param max_seq_length: max seq length, trim longer sentence.
53 | @param multi_label: bool, multi label or single label
54 | @param labels_sep: label separator, default is ','
55 | @param use_cuda: bool, use cuda or not
56 | @param args: dict, train args
57 | """
58 | default_args = {
59 | "output_dir": output_dir,
60 | "max_seq_length": max_seq_length,
61 | "num_train_epochs": num_epochs,
62 | "train_batch_size": batch_size,
63 | "best_model_dir": os.path.join(output_dir, 'best_model'),
64 | "labels_sep": labels_sep,
65 | }
66 | train_args = BertClassificationArgs()
67 | if args and isinstance(args, dict):
68 | train_args.update_from_dict(args)
69 | train_args.update_from_dict(default_args)
70 | if use_cuda is None:
71 | use_cuda = default_use_cuda
72 |
73 | self.model = BertClassificationModel(
74 | model_type=model_type,
75 | model_name=model_name,
76 | num_labels=num_classes,
77 | multi_label=multi_label,
78 | args=train_args,
79 | use_cuda=use_cuda,
80 | )
81 | self.output_dir = output_dir
82 | self.model_type = model_type
83 | self.num_classes = num_classes
84 | self.batch_size = batch_size
85 | self.max_seq_length = max_seq_length
86 | self.num_epochs = num_epochs
87 | self.train_args = train_args
88 | self.use_cuda = use_cuda
89 | self.multi_label = multi_label
90 | self.labels_sep = labels_sep
91 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json')
92 | self.is_trained = False
93 |
94 | def __str__(self):
95 | return f'BertClassifier instance ({self.model})'
96 |
97 | def train(
98 | self,
99 | data_list_or_path,
100 | dev_data_list_or_path=None,
101 | header=None,
102 | names=('labels', 'text'),
103 | delimiter='\t',
104 | test_size=0.1,
105 | ):
106 | """
107 | Train model with data_list_or_path and save model to output_dir
108 | @param data_list_or_path:
109 | @param dev_data_list_or_path:
110 | @param header:
111 | @param names:
112 | @param delimiter:
113 | @param test_size:
114 | @return:
115 | """
116 | logger.debug('train model ...')
117 | SEED = 1
118 | set_seed(SEED)
119 | # load data
120 | X, y, data_df = load_data(
121 | data_list_or_path,
122 | header=header,
123 | names=names,
124 | delimiter=delimiter,
125 | labels_sep=self.labels_sep,
126 | is_train=True
127 | )
128 | if self.output_dir:
129 | os.makedirs(self.output_dir, exist_ok=True)
130 | labels_map = self.build_labels_map(y, self.label_vocab_path, self.multi_label, self.labels_sep)
131 | labels_list = sorted(list(labels_map.keys()))
132 | if dev_data_list_or_path is not None:
133 | dev_X, dev_y, dev_df = load_data(
134 | dev_data_list_or_path,
135 | header=header, names=names,
136 | delimiter=delimiter,
137 | labels_sep=self.labels_sep,
138 | is_train=False
139 | )
140 | train_data = data_df
141 | dev_data = dev_df
142 | else:
143 | if test_size > 0:
144 | train_data, dev_data = train_test_split(data_df, test_size=test_size, random_state=SEED)
145 | else:
146 | train_data = data_df
147 | dev_data = None
148 | logger.debug(f"train_data size: {len(train_data)}")
149 | logger.debug(f'train_data sample:\n{train_data[:3]}')
150 | if dev_data is not None and dev_data.size:
151 | logger.debug(f"dev_data size: {len(dev_data)}")
152 | logger.debug(f'dev_data sample:\n{dev_data[:3]}')
153 | # train model
154 | if self.train_args.lazy_loading:
155 | train_data = data_list_or_path
156 | dev_data = dev_data_list_or_path
157 | if dev_data is not None:
158 | self.model.train_model(
159 | train_data, eval_df=dev_data,
160 | args={'labels_map': labels_map, 'labels_list': labels_list}
161 | )
162 | else:
163 | self.model.train_model(
164 | train_data,
165 | args={'labels_map': labels_map, 'labels_list': labels_list}
166 | )
167 | self.is_trained = True
168 | logger.debug('train model done')
169 |
170 | def predict(self, sentences: list):
171 | """
172 | Predict labels and label probability for sentences.
173 | @param sentences: list, input text list, eg: [text1, text2, ...]
174 | @return: predict_labels, predict_probs
175 | """
176 | if not self.is_trained:
177 | raise ValueError('model not trained.')
178 | # predict
179 | predictions, raw_outputs = self.model.predict(sentences)
180 | if self.multi_label:
181 | return predictions, raw_outputs
182 | else:
183 | # predict probability
184 | predict_probs = [1 - np.exp(-np.max(raw_output)) for raw_output, prediction in
185 | zip(raw_outputs, predictions)]
186 | return predictions, predict_probs
187 |
188 | def evaluate_model(
189 | self,
190 | data_list_or_path,
191 | header=None,
192 | names=('labels', 'text'),
193 | delimiter='\t',
194 | **kwargs
195 | ):
196 | """
197 | Evaluate model with data_list_or_path
198 | @param data_list_or_path:
199 | @param header:
200 | @param names:
201 | @param delimiter:
202 | @param kwargs:
203 | @return:
204 | """
205 | if self.train_args.lazy_loading:
206 | eval_df = data_list_or_path
207 | else:
208 | X_test, y_test, eval_df = load_data(
209 | data_list_or_path,
210 | header=header,
211 | names=names,
212 | delimiter=delimiter,
213 | labels_sep=self.labels_sep
214 | )
215 | if not self.is_trained:
216 | self.load_model()
217 | result, model_outputs, wrong_predictions = self.model.eval_model(
218 | eval_df,
219 | output_dir=self.output_dir,
220 | **kwargs,
221 | )
222 | return result
223 |
224 | def load_model(self):
225 | """
226 | Load model from output_dir
227 | @return:
228 | """
229 | model_config_file = os.path.join(self.output_dir, 'config.json')
230 | if os.path.exists(model_config_file):
231 | labels_map = json.load(open(self.label_vocab_path, 'r', encoding='utf-8'))
232 | labels_list = sorted(list(labels_map.keys()))
233 | num_classes = len(labels_map)
234 | assert num_classes == self.num_classes, f'num_classes not match, {num_classes} != {self.num_classes}'
235 | self.train_args.update_from_dict({'labels_map': labels_map, 'labels_list': labels_list})
236 | self.model = BertClassificationModel(
237 | model_type=self.model_type,
238 | model_name=self.output_dir,
239 | num_labels=self.num_classes,
240 | multi_label=self.multi_label,
241 | args=self.train_args,
242 | use_cuda=self.use_cuda,
243 | )
244 | self.is_trained = True
245 | else:
246 | logger.error(f'{model_config_file} not exists.')
247 | self.is_trained = False
248 | return self.is_trained
249 |
250 | @staticmethod
251 | def build_labels_map(y, label_vocab_path, multi_label=False, labels_sep=','):
252 | """
253 | Build labels map
254 | @param y:
255 | @param label_vocab_path:
256 | @param multi_label:
257 | @param labels_sep:
258 | @return:
259 | """
260 | if multi_label:
261 | labels = set()
262 | for label in y.tolist():
263 | if isinstance(label, str):
264 | labels.update(label.split(labels_sep))
265 | elif isinstance(label, list):
266 | labels.update(range(len(label)))
267 | else:
268 | labels.add(label)
269 | else:
270 | labels = set(y.tolist())
271 | labels = sorted(list(labels))
272 | id_label_map = {id: v for id, v in enumerate(labels)}
273 | label_id_map = {v: k for k, v in id_label_map.items()}
274 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
275 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}")
276 | return label_id_map
277 |
278 |
279 | if __name__ == '__main__':
280 | parser = argparse.ArgumentParser(description='Bert Text Classification')
281 | parser.add_argument('--pretrain_model_type', default='bert', type=str,
282 | help='pretrained huggingface model type')
283 | parser.add_argument('--pretrain_model_name', default='bert-base-chinese', type=str,
284 | help='pretrained huggingface model name')
285 | parser.add_argument('--output_dir', default='models/bert', type=str, help='save model dir')
286 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'),
287 | type=str, help='sample data file path')
288 | parser.add_argument('--num_classes', default=10, type=int, help='number of label classes')
289 | parser.add_argument('--num_epochs', default=3, type=int, help='train epochs')
290 | parser.add_argument('--batch_size', default=64, type=int, help='train batch size')
291 | parser.add_argument('--max_seq_length', default=128, type=int, help='max seq length, trim longer sentence.')
292 | args = parser.parse_args()
293 | print(args)
294 | # create model
295 | m = BertClassifier(
296 | num_classes=args.num_classes,
297 | output_dir=args.output_dir,
298 | model_type=args.pretrain_model_type,
299 | model_name=args.pretrain_model_name,
300 | num_epochs=args.num_epochs,
301 | batch_size=args.batch_size,
302 | max_seq_length=args.max_seq_length,
303 | multi_label=False,
304 | )
305 | # train model
306 | m.train(data_list_or_path=args.data_path)
307 | # load trained model and predict
308 | m.load_model()
309 | print('best model loaded from file, and predict')
310 | X, y, _ = load_data(args.data_path)
311 | X = X[:5]
312 | y = y[:5]
313 | predict_labels, predict_probs = m.predict(X)
314 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y):
315 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth)
316 |
--------------------------------------------------------------------------------
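A small sketch (not part of the repo) of the single-label confidence heuristic used in `BertClassifier.predict` above: the maximum raw logit of each sample is squashed with 1 - exp(-x), so larger logits map to scores closer to 1. The logit values below are made up for illustration:

import numpy as np

raw_outputs = np.array([[2.3, -1.1, 0.4],    # logits for sample 1
                        [0.1, 3.0, -0.5]])   # logits for sample 2
predict_probs = [1 - np.exp(-np.max(row)) for row in raw_outputs]
print([round(p, 2) for p in predict_probs])  # [0.9, 0.95]
# Note: if the max logit is negative the score goes below 0, so this is a
# relative confidence score rather than a calibrated probability.
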
/pytextclassifier/bert_multi_label_classification_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import torch
7 | from torch import nn
8 | from torch.nn import BCEWithLogitsLoss
9 | from transformers import (
10 | BertModel,
11 | BertPreTrainedModel,
12 | FlaubertModel,
13 | LongformerModel,
14 | RemBertModel,
15 | RemBertPreTrainedModel,
16 | XLMModel,
17 | XLMPreTrainedModel,
18 | XLNetModel,
19 | XLNetPreTrainedModel,
20 | )
21 | from transformers.modeling_utils import SequenceSummary
22 | from transformers.models.albert.modeling_albert import (
23 | AlbertModel,
24 | AlbertPreTrainedModel,
25 | )
26 | from transformers.models.longformer.modeling_longformer import (
27 | LongformerClassificationHead,
28 | LongformerPreTrainedModel,
29 | )
30 |
31 | try:
32 | import wandb
33 |
34 | wandb_available = True
35 | except ImportError:
36 | wandb_available = False
37 |
38 |
39 | class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
40 | """
41 | Bert model adapted for multi-label sequence classification
42 | """
43 |
44 | def __init__(self, config, pos_weight=None):
45 | super(BertForMultiLabelSequenceClassification, self).__init__(config)
46 | self.num_labels = config.num_labels
47 | self.bert = BertModel(config)
48 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
49 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
50 | self.pos_weight = pos_weight
51 |
52 | self.init_weights()
53 |
54 | def forward(
55 | self,
56 | input_ids,
57 | attention_mask=None,
58 | token_type_ids=None,
59 | position_ids=None,
60 | head_mask=None,
61 | labels=None,
62 | ):
63 | outputs = self.bert(
64 | input_ids,
65 | attention_mask=attention_mask,
66 | token_type_ids=token_type_ids,
67 | position_ids=position_ids,
68 | head_mask=head_mask,
69 | )
70 |
71 | pooled_output = outputs[1]
72 |
73 | pooled_output = self.dropout(pooled_output)
74 | logits = self.classifier(pooled_output)
75 |
76 | outputs = (logits,) + outputs[
77 | 2:
78 | ] # add hidden states and attention if they are here
79 |
80 | if labels is not None:
81 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
82 | labels = labels.float()
83 | loss = loss_fct(
84 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
85 | )
86 | outputs = (loss,) + outputs
87 |
88 | return outputs # (loss), logits, (hidden_states), (attentions)
89 |
90 |
91 | class BertForHierarchicalMultiLabelSequenceClassification(BertPreTrainedModel):
92 | def __init__(self, config, pos_weight=None):
93 | super(BertForHierarchicalMultiLabelSequenceClassification, self).__init__(config)
94 | config.update({"mlp_size": 1024})
95 | self.num_labels = config.num_labels
96 | self.bert = BertModel(config)
97 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
98 | self.mlp = nn.Sequential(
99 | nn.Linear(config.hidden_size + config.num_labels, config.mlp_size),
100 | nn.ReLU(),
101 | nn.Linear(config.mlp_size, config.mlp_size),
102 | nn.ReLU(),
103 | )
104 | self.classifier = nn.Linear(config.mlp_size, config.num_labels)
105 | self.pos_weight = pos_weight
106 |
107 | self.init_weights()
108 |
109 | def forward(
110 | self,
111 | input_ids,
112 | attention_mask=None,
113 | token_type_ids=None,
114 | position_ids=None,
115 | head_mask=None,
116 | labels=None,
117 | parent_labels=None,
118 | ):
119 | outputs = self.bert(
120 | input_ids,
121 | attention_mask=attention_mask,
122 | token_type_ids=token_type_ids,
123 | position_ids=position_ids,
124 | head_mask=head_mask,
125 | )
126 | pooled_output = outputs[1]
127 | pooled_output = self.dropout(pooled_output)
128 | concat_output = torch.cat((pooled_output, parent_labels), dim=1)
129 | mlp_output = self.mlp(concat_output)
130 | logits = self.classifier(mlp_output)
131 |
132 | outputs = (logits,) + outputs[2:]
133 |
134 | if labels is not None:
135 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
136 | labels = labels.float()
137 | loss = loss_fct(
138 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
139 | )
140 | outputs = (loss,) + outputs
141 |
142 | return outputs # (loss), logits, (hidden_states), (attentions)
143 |
144 |
145 | class RemBertForMultiLabelSequenceClassification(RemBertPreTrainedModel):
146 | """
147 | RemBert model adapted for multi-label sequence classification
148 | """
149 |
150 | def __init__(self, config, pos_weight=None):
151 | super(RemBertForMultiLabelSequenceClassification, self).__init__(config)
152 | self.num_labels = config.num_labels
153 | self.rembert = RemBertModel(config)
154 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
155 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
156 | self.pos_weight = pos_weight
157 |
158 | self.init_weights()
159 |
160 | def forward(
161 | self,
162 | input_ids,
163 | attention_mask=None,
164 | token_type_ids=None,
165 | position_ids=None,
166 | head_mask=None,
167 | labels=None,
168 | ):
169 | outputs = self.rembert(
170 | input_ids,
171 | attention_mask=attention_mask,
172 | token_type_ids=token_type_ids,
173 | position_ids=position_ids,
174 | head_mask=head_mask,
175 | )
176 |
177 | pooled_output = outputs[1]
178 |
179 | pooled_output = self.dropout(pooled_output)
180 | logits = self.classifier(pooled_output)
181 |
182 | outputs = (logits,) + outputs[
183 | 2:
184 | ] # add hidden states and attention if they are here
185 |
186 | if labels is not None:
187 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
188 | labels = labels.float()
189 | loss = loss_fct(
190 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
191 | )
192 | outputs = (loss,) + outputs
193 |
194 | return outputs # (loss), logits, (hidden_states), (attentions)
195 |
196 |
197 | class XLNetForMultiLabelSequenceClassification(XLNetPreTrainedModel):
198 | """
199 | XLNet model adapted for multi-label sequence classification
200 | """
201 |
202 | def __init__(self, config, pos_weight=None):
203 | super(XLNetForMultiLabelSequenceClassification, self).__init__(config)
204 | self.num_labels = config.num_labels
205 | self.pos_weight = pos_weight
206 |
207 | self.transformer = XLNetModel(config)
208 | self.sequence_summary = SequenceSummary(config)
209 | self.logits_proj = nn.Linear(config.d_model, config.num_labels)
210 |
211 | self.init_weights()
212 |
213 | def forward(
214 | self,
215 | input_ids=None,
216 | attention_mask=None,
217 | mems=None,
218 | perm_mask=None,
219 | target_mapping=None,
220 | token_type_ids=None,
221 | input_mask=None,
222 | head_mask=None,
223 | inputs_embeds=None,
224 | labels=None,
225 | ):
226 | transformer_outputs = self.transformer(
227 | input_ids,
228 | attention_mask=attention_mask,
229 | mems=mems,
230 | perm_mask=perm_mask,
231 | target_mapping=target_mapping,
232 | token_type_ids=token_type_ids,
233 | input_mask=input_mask,
234 | head_mask=head_mask,
235 | )
236 | output = transformer_outputs[0]
237 |
238 | output = self.sequence_summary(output)
239 | logits = self.logits_proj(output)
240 |
241 | outputs = (logits,) + transformer_outputs[
242 | 1:
243 | ] # Keep mems, hidden states, attentions if they are present
244 |
245 | if labels is not None:
246 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
247 | labels = labels.float()
248 | loss = loss_fct(
249 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
250 | )
251 | outputs = (loss,) + outputs
252 |
253 | return outputs
254 |
255 |
256 | class XLMForMultiLabelSequenceClassification(XLMPreTrainedModel):
257 | """
258 | XLM model adapted for multi-label sequence classification
259 | """
260 |
261 | def __init__(self, config, pos_weight=None):
262 | super(XLMForMultiLabelSequenceClassification, self).__init__(config)
263 | self.num_labels = config.num_labels
264 | self.pos_weight = pos_weight
265 |
266 | self.transformer = XLMModel(config)
267 | self.sequence_summary = SequenceSummary(config)
268 |
269 | self.init_weights()
270 |
271 | def forward(
272 | self,
273 | input_ids=None,
274 | attention_mask=None,
275 | langs=None,
276 | token_type_ids=None,
277 | position_ids=None,
278 | lengths=None,
279 | cache=None,
280 | head_mask=None,
281 | inputs_embeds=None,
282 | labels=None,
283 | ):
284 | transformer_outputs = self.transformer(
285 | input_ids,
286 | attention_mask=attention_mask,
287 | langs=langs,
288 | token_type_ids=token_type_ids,
289 | position_ids=position_ids,
290 | lengths=lengths,
291 | cache=cache,
292 | head_mask=head_mask,
293 | )
294 |
295 | output = transformer_outputs[0]
296 | logits = self.sequence_summary(output)
297 |
298 | outputs = (logits,) + transformer_outputs[
299 | 1:
300 | ] # Keep new_mems and attention/hidden states if they are here
301 |
302 | if labels is not None:
303 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
304 | labels = labels.float()
305 | loss = loss_fct(
306 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
307 | )
308 | outputs = (loss,) + outputs
309 |
310 | return outputs
311 |
312 |
313 | class AlbertForMultiLabelSequenceClassification(AlbertPreTrainedModel):
314 | """
315 | Albert model adapted for multi-label sequence classification
316 | """
317 |
318 | def __init__(self, config, pos_weight=None):
319 | super(AlbertForMultiLabelSequenceClassification, self).__init__(config)
320 |
321 | self.num_labels = config.num_labels
322 | self.pos_weight = pos_weight
323 |
324 | self.albert = AlbertModel(config)
325 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
326 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
327 |
328 | self.init_weights()
329 |
330 | def forward(
331 | self,
332 | input_ids=None,
333 | attention_mask=None,
334 | token_type_ids=None,
335 | position_ids=None,
336 | head_mask=None,
337 | inputs_embeds=None,
338 | labels=None,
339 | ):
340 | outputs = self.albert(
341 | input_ids=input_ids,
342 | attention_mask=attention_mask,
343 | token_type_ids=token_type_ids,
344 | position_ids=position_ids,
345 | head_mask=head_mask,
346 | inputs_embeds=inputs_embeds,
347 | )
348 |
349 | pooled_output = outputs[1]
350 |
351 | pooled_output = self.dropout(pooled_output)
352 | logits = self.classifier(pooled_output)
353 |
354 | outputs = (logits,) + outputs[
355 | 2:
356 | ] # add hidden states and attention if they are here
357 |
358 | if labels is not None:
359 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
360 | labels = labels.float()
361 | loss = loss_fct(
362 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
363 | )
364 | outputs = (loss,) + outputs
365 |
366 | return outputs # (loss), logits, (hidden_states), (attentions)
367 |
368 |
369 | class FlaubertForMultiLabelSequenceClassification(FlaubertModel):
370 | """
371 | Flaubert model adapted for multi-label sequence classification
372 | """
373 |
374 | def __init__(self, config, pos_weight=None):
375 | super(FlaubertForMultiLabelSequenceClassification, self).__init__(config)
376 | self.num_labels = config.num_labels
377 | self.pos_weight = pos_weight
378 |
379 | self.transformer = FlaubertModel(config)
380 | self.sequence_summary = SequenceSummary(config)
381 |
382 | self.init_weights()
383 |
384 | def forward(
385 | self,
386 | input_ids=None,
387 | attention_mask=None,
388 | langs=None,
389 | token_type_ids=None,
390 | position_ids=None,
391 | lengths=None,
392 | cache=None,
393 | head_mask=None,
394 | inputs_embeds=None,
395 | labels=None,
396 | ):
397 | transformer_outputs = self.transformer(
398 | input_ids,
399 | attention_mask=attention_mask,
400 | langs=langs,
401 | token_type_ids=token_type_ids,
402 | position_ids=position_ids,
403 | lengths=lengths,
404 | cache=cache,
405 | head_mask=head_mask,
406 | )
407 |
408 | output = transformer_outputs[0]
409 | logits = self.sequence_summary(output)
410 |
411 | outputs = (logits,) + transformer_outputs[
412 | 1:
413 | ] # Keep new_mems and attention/hidden states if they are here
414 |
415 | if labels is not None:
416 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
417 | labels = labels.float()
418 | loss = loss_fct(
419 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
420 | )
421 | outputs = (loss,) + outputs
422 |
423 | return outputs
424 |
425 |
426 | class LongformerForMultiLabelSequenceClassification(LongformerPreTrainedModel):
427 | """
428 | Longformer model adapted for multilabel sequence classification.
429 | """
430 |
431 | def __init__(self, config, pos_weight=None):
432 | super(LongformerForMultiLabelSequenceClassification, self).__init__(config)
433 | self.num_labels = config.num_labels
434 | self.pos_weight = pos_weight
435 |
436 | self.longformer = LongformerModel(config)
437 | self.classifier = LongformerClassificationHead(config)
438 |
439 | self.init_weights()
440 |
441 | def forward(
442 | self,
443 | input_ids=None,
444 | attention_mask=None,
445 | global_attention_mask=None,
446 | token_type_ids=None,
447 | position_ids=None,
448 | inputs_embeds=None,
449 | labels=None,
450 | ):
451 | if global_attention_mask is None:
452 | global_attention_mask = torch.zeros_like(input_ids)
453 | # global attention on cls token
454 | global_attention_mask[:, 0] = 1
455 |
456 | outputs = self.longformer(
457 | input_ids,
458 | attention_mask=attention_mask,
459 | global_attention_mask=global_attention_mask,
460 | token_type_ids=token_type_ids,
461 | position_ids=position_ids,
462 | )
463 | sequence_output = outputs[0]
464 | logits = self.classifier(sequence_output)
465 |
466 | outputs = (logits,) + outputs[2:]
467 | if labels is not None:
468 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
469 | labels = labels.float()
470 | loss = loss_fct(
471 | logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
472 | )
473 | outputs = (loss,) + outputs
474 |
475 | return outputs
476 |
--------------------------------------------------------------------------------
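These heads are normally constructed through `BertClassificationModel` with `multi_label=True`, but a minimal standalone sketch (not part of the repo, using a tiny randomly initialised config) shows the expected tensor shapes and the BCEWithLogitsLoss path:

import torch
from transformers import BertConfig
from pytextclassifier.bert_multi_label_classification_model import (
    BertForMultiLabelSequenceClassification,
)

# Tiny config chosen only to keep the example fast; real usage loads pretrained weights.
config = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
                    intermediate_size=64, num_labels=4)
model = BertForMultiLabelSequenceClassification(config)
input_ids = torch.randint(0, config.vocab_size, (2, 8))   # batch of 2, seq len 8
labels = torch.randint(0, 2, (2, 4)).float()               # multi-hot targets
loss, logits = model(input_ids, labels=labels)[:2]
print(loss.item(), logits.shape)   # scalar BCE loss, torch.Size([2, 4])
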
/pytextclassifier/classic_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: Classic Classifier, support Naive Bayes, Logistic Regression, Random Forest, SVM, XGBoost
5 | and so on sklearn classification model
6 | """
7 | import argparse
8 | import os
9 | import sys
10 | import pickle
11 | import numpy as np
12 | from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
13 | from sklearn import metrics
14 | from sklearn.ensemble import RandomForestClassifier
15 | from sklearn.linear_model import LogisticRegression
16 | from sklearn.naive_bayes import MultinomialNB
17 | from sklearn.neighbors import KNeighborsClassifier
18 | from sklearn.svm import SVC
19 | from sklearn.tree import DecisionTreeClassifier
20 | from sklearn.model_selection import train_test_split
21 | from loguru import logger
22 |
23 | sys.path.append('..')
24 | from pytextclassifier.base_classifier import ClassifierABC, load_data
25 | from pytextclassifier.tokenizer import Tokenizer
26 |
27 | pwd_path = os.path.abspath(os.path.dirname(__file__))
28 | default_stopwords_path = os.path.join(pwd_path, 'stopwords.txt')
29 |
30 |
31 | class ClassicClassifier(ClassifierABC):
32 | def __init__(
33 | self,
34 | output_dir="outputs",
35 | model_name_or_model='lr',
36 | feature_name_or_feature='tfidf',
37 | stopwords_path=default_stopwords_path,
38 | tokenizer=None
39 | ):
40 | """
41 | Classic machine learning classification model; supports lr, random_forest, decision_tree, knn, bayes, svm, xgboost
42 | @param output_dir: model output dir
43 | @param model_name_or_model:
44 | @param feature_name_or_feature:
45 | @param stopwords_path:
46 | @param tokenizer: word tokenizer, defaults to jieba segmentation
47 | """
48 | self.output_dir = output_dir
49 | if isinstance(model_name_or_model, str):
50 | model_name = model_name_or_model.lower()
51 | if model_name not in ['lr', 'random_forest', 'decision_tree', 'knn', 'bayes', 'xgboost', 'svm']:
52 | raise ValueError('model_name not found.')
53 | logger.debug(f'model_name: {model_name}')
54 | self.model = self.get_model(model_name)
55 | elif hasattr(model_name_or_model, 'fit'):
56 | self.model = model_name_or_model
57 | else:
58 | raise ValueError('model_name_or_model set error.')
59 | if isinstance(feature_name_or_feature, str):
60 | feature_name = feature_name_or_feature.lower()
61 | if feature_name not in ['tfidf', 'count']:
62 | raise ValueError('feature_name not found.')
63 | logger.debug(f'feature_name: {feature_name}')
64 | if feature_name == 'tfidf':
65 | self.feature = TfidfVectorizer(ngram_range=(1, 2))
66 | else:
67 | self.feature = CountVectorizer(ngram_range=(1, 2))
68 | elif hasattr(feature_name_or_feature, 'fit_transform'):
69 | self.feature = feature_name_or_feature
70 | else:
71 | raise ValueError('feature_name_or_feature set error.')
72 | self.is_trained = False
73 | self.stopwords = set(self.load_list(stopwords_path)) if stopwords_path and os.path.exists(
74 | stopwords_path) else set()
75 | self.tokenizer = tokenizer if tokenizer else Tokenizer()
76 |
77 | def __str__(self):
78 | return f'ClassicClassifier instance ({self.model}, stopwords size: {len(self.stopwords)})'
79 |
80 | @staticmethod
81 | def get_model(model_type):
82 | if model_type in ["lr", "logistic_regression"]:
83 | model = LogisticRegression(solver='lbfgs', fit_intercept=False) # fast, average accuracy. val mean acc: 0.91
84 | elif model_type == "random_forest":
85 | model = RandomForestClassifier(n_estimators=300) # decent speed, average accuracy. val mean acc: 0.93125
86 | elif model_type == "decision_tree":
87 | model = DecisionTreeClassifier() # fast, low accuracy. val mean acc: 0.62
88 | elif model_type == "knn":
89 | model = KNeighborsClassifier() # average speed, low accuracy. val mean acc: 0.675
90 | elif model_type == "bayes":
91 | model = MultinomialNB(alpha=0.1, fit_prior=False) # fast, low accuracy. val mean acc: 0.62
92 | elif model_type == "xgboost":
93 | try:
94 | from xgboost import XGBClassifier
95 | except ImportError:
96 | raise ImportError('xgboost not installed, please install it with "pip install xgboost"')
97 | model = XGBClassifier() # slow, high accuracy. val mean acc: 0.95
98 | elif model_type == "svm":
99 | model = SVC(kernel='linear', probability=True) # slow, high accuracy. val mean acc: 0.945
100 | else:
101 | raise ValueError('model type set error.')
102 | return model
103 |
104 | @staticmethod
105 | def load_list(path):
106 | return [word for word in open(path, 'r', encoding='utf-8').read().split()]
107 |
108 | def tokenize_sentences(self, sentences):
109 | """
110 | Tokenize input text
111 | :param sentences: list of text, eg: [text1, text2, ...]
112 | :return: X_tokens
113 | """
114 | X_tokens = [' '.join([w for w in self.tokenizer.tokenize(line) if w not in self.stopwords]) for line in
115 | sentences]
116 | return X_tokens
117 |
118 | def load_pkl(self, pkl_path):
119 | """
120 | Load a pickled object from file
121 | :param pkl_path:
122 | :return:
123 | """
124 | with open(pkl_path, 'rb') as f:
125 | result = pickle.load(f)
126 | return result
127 |
128 | def save_pkl(self, vocab, pkl_path, overwrite=True):
129 | """
130 | Save an object to a pickle file
131 | :param pkl_path:
132 | :param overwrite:
133 | :return:
134 | """
135 | if pkl_path and os.path.exists(pkl_path) and not overwrite:
136 | return
137 | if pkl_path:
138 | with open(pkl_path, 'wb') as f:
139 | pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL) # python3
140 |
141 | def train(self, data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1):
142 | """
143 | Train model with data_list_or_path and save model to output_dir
144 | @param data_list_or_path:
145 | @param header:
146 | @param names:
147 | @param delimiter:
148 | @param test_size:
149 | @return:
150 | """
151 | # load data
152 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
153 | # split validation set
154 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=1)
155 | # train model
156 | logger.debug(f"X_train size: {len(X_train)}, X_test size: {len(X_test)}")
157 | assert len(X_train) == len(y_train)
158 | logger.debug(f'X_train sample:\n{X_train[:3]}\ny_train sample:\n{y_train[:3]}')
159 | logger.debug(f'num_classes:{len(set(y))}')
160 | # tokenize text
161 | X_train_tokens = self.tokenize_sentences(X_train)
162 | logger.debug(f'X_train_tokens sample:\n{X_train_tokens[:3]}')
163 | X_train_feat = self.feature.fit_transform(X_train_tokens)
164 | # fit
165 | self.model.fit(X_train_feat, y_train)
166 | self.is_trained = True
167 | # evaluate
168 | test_acc = self.evaluate(X_test, y_test)
169 | logger.debug(f'evaluate, X size: {len(X_test)}, y size: {len(y_test)}, acc: {test_acc}')
170 | # save model
171 | self.save_model()
172 | return test_acc
173 |
174 | def predict(self, sentences: list):
175 | """
176 | Predict labels and label probability for sentences.
177 | @param sentences: list, input text list, eg: [text1, text2, ...]
178 | @return: predict_label, predict_prob
179 | """
180 | if not self.is_trained:
181 | raise ValueError('model not trained.')
182 | # tokenize text
183 | X_tokens = self.tokenize_sentences(sentences)
184 | # transform
185 | X_feat = self.feature.transform(X_tokens)
186 | predict_labels = self.model.predict(X_feat)
187 | probs = self.model.predict_proba(X_feat)
188 | predict_probs = [prob[np.where(self.model.classes_ == label)][0] for label, prob in zip(predict_labels, probs)]
189 | return predict_labels, predict_probs
190 |
191 | def evaluate_model(self, data_list_or_path, header=None, names=('labels', 'text'), delimiter='\t'):
192 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter)
193 | return self.evaluate(X_test, y_test)
194 |
195 | def evaluate(self, X_test, y_test):
196 | """
197 | Evaluate model.
198 | @param X_test:
199 | @param y_test:
200 | @return: accuracy score
201 | """
202 | if not self.is_trained:
203 | raise ValueError('model not trained.')
204 | # evaluate the model
205 | y_pred, _ = self.predict(X_test)
206 | acc_score = metrics.accuracy_score(y_test, y_pred)
207 | return acc_score
208 |
209 | def load_model(self):
210 | """
211 | Load model from output_dir
212 | @return:
213 | """
214 | model_path = os.path.join(self.output_dir, 'classifier_model.pkl')
215 | if os.path.exists(model_path):
216 | self.model = self.load_pkl(model_path)
217 | feature_path = os.path.join(self.output_dir, 'classifier_feature.pkl')
218 | self.feature = self.load_pkl(feature_path)
219 | logger.info(f'Loaded model: {model_path}.')
220 | self.is_trained = True
221 | else:
222 | logger.error(f'{model_path} not exists.')
223 | self.is_trained = False
224 | return self.is_trained
225 |
226 | def save_model(self):
227 | """
228 | Save model to output_dir
229 | @return:
230 | """
231 | if self.output_dir:
232 | os.makedirs(self.output_dir, exist_ok=True)
233 | if self.is_trained:
234 | feature_path = os.path.join(self.output_dir, 'classifier_feature.pkl')
235 | self.save_pkl(self.feature, feature_path)
236 | model_path = os.path.join(self.output_dir, 'classifier_model.pkl')
237 | self.save_pkl(self.model, model_path)
238 | logger.info(f'Saved model: {model_path}, feature_path: {feature_path}')
239 | else:
240 | logger.error('model is not trained, please train model first')
241 | return self.model, self.feature
242 |
243 |
244 | if __name__ == '__main__':
245 | parser = argparse.ArgumentParser(description='Text Classification')
246 | parser.add_argument('--model_name', default='lr', type=str, help='model name')
247 | parser.add_argument('--output_dir', default='models/lr', type=str, help='saved model dir')
248 | parser.add_argument('--feature_name', default='tfidf', type=str, help='feature name')
249 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'),
250 | type=str, help='sample data file path')
251 | args = parser.parse_args()
252 | print(args)
253 | # create model
254 | m = ClassicClassifier(output_dir=args.output_dir, model_name_or_model=args.model_name,
255 | feature_name_or_feature=args.feature_name)
256 | # train model
257 | m.train(args.data_path)
258 | # load best trained model and predict
259 | m.load_model()
260 | X, y, _ = load_data(args.data_path)
261 | X = X[:5]
262 | y = y[:5]
263 | predict_labels, predict_probs = m.predict(X)
264 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y):
265 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth)
266 |
--------------------------------------------------------------------------------
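Besides the built-in model names, `ClassicClassifier` above accepts any estimator with a `fit` method (its `predict` also relies on `predict` and `predict_proba`). A hedged sketch with an sklearn estimator the repo does not ship; the output dir name is illustrative:

from sklearn.ensemble import ExtraTreesClassifier
from pytextclassifier import ClassicClassifier

# ExtraTreesClassifier is an illustrative choice; any estimator exposing
# fit / predict / predict_proba should work the same way.
m = ClassicClassifier(output_dir='models/extratrees-custom',
                      model_name_or_model=ExtraTreesClassifier(n_estimators=200),
                      feature_name_or_feature='tfidf')
# train / load_model / predict then work exactly as in lr_classification_demo.py
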
/pytextclassifier/data_helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import json
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 | from loguru import logger
12 | from tqdm import tqdm
13 |
14 |
15 | def set_seed(seed):
16 | """
17 | Set seed for random number generators.
18 | """
19 | logger.info(f"Set seed for random, numpy and torch: {seed}")
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | if torch.cuda.is_available():
24 | torch.cuda.manual_seed_all(seed)
25 |
26 |
27 | def build_vocab(contents, tokenizer, max_size, min_freq, unk_token, pad_token):
28 | vocab_dic = {}
29 | for line in tqdm(contents):
30 | line = line.strip()
31 | if not line:
32 | continue
33 | content = line.split('\t')[0]
34 | for word in tokenizer(content):
35 | vocab_dic[word] = vocab_dic.get(word, 0) + 1
36 | vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[
37 | :max_size]
38 | vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
39 | vocab_dic.update({unk_token: len(vocab_dic), pad_token: len(vocab_dic) + 1})
40 | return vocab_dic
41 |
42 |
43 | def load_vocab(vocab_path):
44 | """
45 |     :param vocab_path: path to a JSON vocab file saved by the classifiers
46 |     :return: dict, token (or label) -> id mapping
47 | """
48 | with open(vocab_path, 'r', encoding='utf-8') as fr:
49 | vocab = json.load(fr)
50 | return vocab
51 |
--------------------------------------------------------------------------------
/pytextclassifier/fasttext_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: Fasttext Classifier
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import sys
10 | import time
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from loguru import logger
17 | from sklearn import metrics
18 | from sklearn.model_selection import train_test_split
19 |
20 | sys.path.append('..')
21 | from pytextclassifier.base_classifier import ClassifierABC, load_data
22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab
23 | from pytextclassifier.time_util import get_time_spend
24 |
25 | pwd_path = os.path.abspath(os.path.dirname(__file__))
26 | device = 'cuda' if torch.cuda.is_available() else (
27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')
28 |
29 |
30 | def build_dataset(
31 | tokenizer, X, y, word_vocab_path, label_vocab_path, max_vocab_size=10000,
32 | max_seq_length=128, unk_token='[UNK]', pad_token='[PAD]'
33 | ):
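    """
    Build or load the word/label vocabularies (cached as JSON next to the model) and convert
    (X, y) into (word_ids, label_id, seq_len, None, None) tuples padded/truncated to max_seq_length.
    """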
34 | if os.path.exists(word_vocab_path):
35 | word_id_map = load_vocab(word_vocab_path)
36 | else:
37 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1,
38 | unk_token=unk_token, pad_token=pad_token)
39 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
40 | logger.debug('save vocab_path: {}'.format(word_vocab_path))
41 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}")
42 |
43 | if os.path.exists(label_vocab_path):
44 | label_id_map = load_vocab(label_vocab_path)
45 | else:
46 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))}
47 | label_id_map = {v: k for k, v in id_label_map.items()}
48 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
49 | logger.debug('save label_vocab_path: {}'.format(label_vocab_path))
50 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}")
51 |
52 | def load_dataset(X, y, max_seq_length=128):
53 | contents = []
54 | for content, label in zip(X, y):
55 | words_line = []
56 | token = tokenizer(content)
57 | seq_len = len(token)
58 | if max_seq_length:
59 | if len(token) < max_seq_length:
60 | token.extend([pad_token] * (max_seq_length - len(token)))
61 | else:
62 | token = token[:max_seq_length]
63 | seq_len = max_seq_length
64 | # word to id
65 | for word in token:
66 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token)))
67 | label_id = label_id_map.get(label)
68 | contents.append((words_line, label_id, seq_len, None, None))
69 | return contents
70 |
71 | dataset = load_dataset(X, y, max_seq_length)
72 | return dataset, word_id_map, label_id_map
73 |
74 |
75 | class DatasetIterater:
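    """Batch iterator yielding ((token_ids, seq_len, bigram, trigram), labels) tensors on `device`."""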
76 | def __init__(self, dataset, device, batch_size=32, enable_ngram=True, n_gram_vocab=250499, max_seq_length=128):
77 | self.batch_size = batch_size
78 | self.dataset = dataset
79 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1
80 |         self.residue = False  # whether there is a final partial (residue) batch
81 |         if len(dataset) > batch_size and len(dataset) % batch_size != 0:
82 |             self.residue = True
83 | self.index = 0
84 | self.device = device
85 | self.enable_ngram = enable_ngram
86 | self.n_gram_vocab = n_gram_vocab
87 | self.max_seq_length = max_seq_length
88 |
89 | def _to_tensor(self, datas):
90 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
91 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
92 |
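        # Hash the preceding 1-2 token ids into a fixed-size bucket space (fastText-style
        # hashing trick), so bigram/trigram features need no explicit n-gram vocabulary.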
93 | def biGramHash(sequence, t, buckets):
94 | t1 = sequence[t - 1] if t - 1 >= 0 else 0
95 | return (t1 * 14918087) % buckets
96 |
97 | def triGramHash(sequence, t, buckets):
98 | t1 = sequence[t - 1] if t - 1 >= 0 else 0
99 | t2 = sequence[t - 2] if t - 2 >= 0 else 0
100 | return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
101 |
102 |             # calculate bigram and trigram here for memory efficiency
103 | bigram = []
104 | trigram = []
105 | for _ in datas:
106 | words_line = _[0]
107 | bi = []
108 | tri = []
109 | if self.enable_ngram:
110 | buckets = self.n_gram_vocab
111 | # ------ngram------
112 | for i in range(self.max_seq_length):
113 | bi.append(biGramHash(words_line, i, buckets))
114 | tri.append(triGramHash(words_line, i, buckets))
115 | # -----------------
116 | else:
117 | bi = [0] * self.max_seq_length
118 | tri = [0] * self.max_seq_length
119 | bigram.append(bi)
120 | trigram.append(tri)
121 | bigram = torch.LongTensor(bigram).to(self.device)
122 | trigram = torch.LongTensor(trigram).to(self.device)
123 |
124 |         # sequence length before padding (capped at max_seq_length)
125 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
126 | return (x, seq_len, bigram, trigram), y
127 |
128 | def __next__(self):
129 | if self.residue and self.index == self.n_batches:
130 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)]
131 | self.index += 1
132 | batches = self._to_tensor(batches)
133 | return batches
134 |
135 | elif self.index >= self.n_batches:
136 | self.index = 0
137 | raise StopIteration
138 | else:
139 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size]
140 | self.index += 1
141 | batches = self._to_tensor(batches)
142 | return batches
143 |
144 | def __iter__(self):
145 | return self
146 |
147 | def __len__(self):
148 | if self.residue:
149 | return self.n_batches + 1
150 | else:
151 | return self.n_batches
152 |
153 |
154 | def build_iterator(
155 | dataset,
156 | device,
157 | batch_size=32,
158 | enable_ngram=True,
159 | n_gram_vocab=250499,
160 | max_seq_length=128
161 | ):
162 | return DatasetIterater(dataset, device, batch_size, enable_ngram, n_gram_vocab, max_seq_length)
163 |
164 |
165 | class FastTextModel(nn.Module):
166 | """Bag of Tricks for Efficient Text Classification"""
167 |
168 | def __init__(
169 | self,
170 | vocab_size,
171 | num_classes,
172 | embed_size=200,
173 | n_gram_vocab=250499,
174 | hidden_size=256,
175 | dropout_rate=0.5
176 | ):
177 | super().__init__()
178 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1)
179 | self.embedding_ngram2 = nn.Embedding(n_gram_vocab, embed_size)
180 | self.embedding_ngram3 = nn.Embedding(n_gram_vocab, embed_size)
181 | self.dropout = nn.Dropout(dropout_rate)
182 | self.fc1 = nn.Linear(embed_size * 3, hidden_size)
183 | self.fc2 = nn.Linear(hidden_size, num_classes)
184 |
185 | def forward(self, x):
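        # x = (token_ids, seq_len, bigram_ids, trigram_ids); word, bigram and trigram
        # embeddings are concatenated, mean-pooled over the sequence, then classified.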
186 | out_word = self.embedding(x[0])
187 | out_bigram = self.embedding_ngram2(x[2])
188 | out_trigram = self.embedding_ngram3(x[3])
189 | out = torch.cat((out_word, out_bigram, out_trigram), -1)
190 |
191 | out = out.mean(dim=1)
192 | out = self.dropout(out)
193 | out = self.fc1(out)
194 | out = F.relu(out)
195 | out = self.fc2(out)
196 | return out
197 |
198 |
199 | class FastTextClassifier(ClassifierABC):
200 | def __init__(
201 | self,
202 | output_dir="outputs",
203 | dropout_rate=0.5, batch_size=64, max_seq_length=128,
204 | embed_size=200, hidden_size=256, n_gram_vocab=250499,
205 | max_vocab_size=10000, unk_token='[UNK]', pad_token='[PAD]',
206 | tokenizer=None,
207 | enable_ngram=True,
208 | ):
209 | """
210 |         Init the FastTextClassifier
211 |         @param output_dir: dir to save the trained model
212 |         @param dropout_rate: dropout rate
213 |         @param max_seq_length: sequence length; shorter texts are padded, longer ones truncated
214 |         @param batch_size: mini-batch size
215 |         @param embed_size: embedding dimension
216 |         @param hidden_size: hidden layer size
217 |         @param n_gram_vocab: size of the ngram hash bucket space
218 |         @param max_vocab_size: word vocab size limit
219 |         @param unk_token: unknown token
220 |         @param pad_token: padding token
221 |         @param tokenizer: tokenizer, defaults to char-level split
222 |         @param enable_ngram: whether to use bigram/trigram features
223 | """
224 | self.output_dir = output_dir
225 | self.is_trained = False
226 | self.model = None
227 | logger.debug(f'Device: {device}')
228 | self.dropout_rate = dropout_rate
229 | self.batch_size = batch_size
230 | self.max_seq_length = max_seq_length
231 | self.embed_size = embed_size
232 | self.hidden_size = hidden_size
233 | self.n_gram_vocab = n_gram_vocab
234 | self.max_vocab_size = max_vocab_size
235 | self.unk_token = unk_token
236 | self.pad_token = pad_token
237 | self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x] # char-level
238 | self.enable_ngram = enable_ngram
239 |
240 | def __str__(self):
241 |         return f'FastTextClassifier instance ({self.model})'
242 |
243 | def train(
244 | self,
245 | data_list_or_path,
246 | header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1,
247 | num_epochs=20, learning_rate=1e-3,
248 | require_improvement=1000, evaluate_during_training_steps=100
249 | ):
250 | """
251 | Train model with data_list_or_path and save model to output_dir
252 | @param data_list_or_path:
253 | @param header:
254 | @param names:
255 | @param delimiter:
256 | @param test_size:
257 |         @param num_epochs: number of training epochs
258 |         @param learning_rate: learning rate
259 |         @param require_improvement: stop early if the dev loss has not improved after this many batches
260 |         @param evaluate_during_training_steps: evaluate the model every this many training steps
261 | @return:
262 | """
263 | logger.debug('train model...')
264 | SEED = 1
265 | set_seed(SEED)
266 | # load data
267 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
268 | del data_df
269 | output_dir = self.output_dir
270 | if output_dir:
271 | os.makedirs(output_dir, exist_ok=True)
272 | word_vocab_path = os.path.join(output_dir, 'word_vocab.json')
273 | label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
274 | save_model_path = os.path.join(output_dir, 'model.pth')
275 |
276 | dataset, self.word_id_map, self.label_id_map = build_dataset(
277 | self.tokenizer, X, y, word_vocab_path,
278 | label_vocab_path,
279 | max_vocab_size=self.max_vocab_size,
280 | max_seq_length=self.max_seq_length,
281 | unk_token=self.unk_token,
282 | pad_token=self.pad_token
283 | )
284 | train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
285 | logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
286 | logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
287 | train_iter = build_iterator(train_data, device, self.batch_size, self.enable_ngram,
288 | self.n_gram_vocab, self.max_seq_length)
289 | dev_iter = build_iterator(dev_data, device, self.batch_size, self.enable_ngram,
290 | self.n_gram_vocab, self.max_seq_length)
291 | # create model
292 | vocab_size = len(self.word_id_map)
293 | num_classes = len(self.label_id_map)
294 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
295 | self.model = FastTextModel(
296 | vocab_size, num_classes, self.embed_size, self.n_gram_vocab, self.hidden_size,
297 | self.dropout_rate
298 | )
299 | self.model.to(device)
300 | # init_network(self.model)
301 | logger.info(self.model.parameters)
302 | # train model
303 | history = self.train_model_from_data_iterator(
304 | save_model_path, train_iter, dev_iter, num_epochs, learning_rate,
305 | require_improvement, evaluate_during_training_steps
306 | )
307 | self.is_trained = True
308 | logger.debug('train model done')
309 | return history
310 |
311 | def train_model_from_data_iterator(
312 | self, save_model_path, train_iter, dev_iter,
313 | num_epochs=10, learning_rate=1e-3,
314 | require_improvement=1000, evaluate_during_training_steps=100
315 | ):
316 | history = []
317 | # train
318 | start_time = time.time()
319 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
320 |
321 |         total_batch = 0  # number of batches processed so far
322 |         dev_best_loss = 1e10
323 |         last_improve = 0  # batch index of the last dev-loss improvement
324 |         flag = False  # whether early stopping was triggered
325 | for epoch in range(num_epochs):
326 | logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs))
327 | for i, (trains, labels) in enumerate(train_iter):
328 | self.model.train()
329 | outputs = self.model(trains)
330 | loss = F.cross_entropy(outputs, labels)
331 | # compute gradient and do SGD step
332 | optimizer.zero_grad()
333 | loss.backward()
334 | optimizer.step()
335 | if total_batch % evaluate_during_training_steps == 0:
336 |                 # report metrics on the train set and the dev set
337 | y_true = labels.cpu()
338 | y_pred = torch.max(outputs, 1)[1].cpu()
339 | train_acc = metrics.accuracy_score(y_true, y_pred)
340 | if dev_iter is not None:
341 | dev_acc, dev_loss = self.evaluate(dev_iter)
342 | if dev_loss < dev_best_loss:
343 | dev_best_loss = dev_loss
344 | torch.save(self.model.state_dict(), save_model_path)
345 | logger.debug(f'Saved model: {save_model_path}')
346 | improve = '*'
347 | last_improve = total_batch
348 | else:
349 | improve = ''
350 | time_dif = get_time_spend(start_time)
351 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},' \
352 | 'Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format(
353 | total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)
354 | else:
355 | time_dif = get_time_spend(start_time)
356 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format(
357 | total_batch, loss.item(), train_acc, time_dif)
358 | logger.debug(msg)
359 | history.append(msg)
360 | self.model.train()
361 | total_batch += 1
362 | if total_batch - last_improve > require_improvement:
363 |                     # dev loss has not improved within require_improvement batches, stop training early
364 | logger.debug("No optimization for a long time, auto-stopping...")
365 | flag = True
366 | break
367 | if flag:
368 | break
369 | return history
370 |
371 | def predict(self, sentences: list):
372 | """
373 | Predict labels and label probability for sentences.
374 | @param sentences: list, input text list, eg: [text1, text2, ...]
375 | @return: predict_label, predict_prob
376 | """
377 | if not self.is_trained:
378 | raise ValueError('model not trained.')
379 | self.model.eval()
380 |
381 | def biGramHash(sequence, t, buckets):
382 | t1 = sequence[t - 1] if t - 1 >= 0 else 0
383 | return (t1 * 14918087) % buckets
384 |
385 | def triGramHash(sequence, t, buckets):
386 | t1 = sequence[t - 1] if t - 1 >= 0 else 0
387 | t2 = sequence[t - 2] if t - 2 >= 0 else 0
388 | return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
389 |
390 | def load_dataset(X, max_seq_length=128):
391 | contents = []
392 | for content in X:
393 | words_line = []
394 | token = self.tokenizer(content)
395 | seq_len = len(token)
396 | if max_seq_length:
397 | if len(token) < max_seq_length:
398 | token.extend([self.pad_token] * (max_seq_length - len(token)))
399 | else:
400 | token = token[:max_seq_length]
401 | seq_len = max_seq_length
402 | # word to id
403 | for word in token:
404 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token)))
405 | # fasttext ngram
406 | bigram = []
407 | trigram = []
408 | if self.enable_ngram:
409 | buckets = self.n_gram_vocab
410 | # ------ngram------
411 | for i in range(max_seq_length):
412 | bigram.append(biGramHash(words_line, i, buckets))
413 | trigram.append(triGramHash(words_line, i, buckets))
414 | # -----------------
415 | else:
416 | bigram = [0] * max_seq_length
417 | trigram = [0] * max_seq_length
418 | contents.append((words_line, 0, seq_len, bigram, trigram))
419 | return contents
420 |
421 | data = load_dataset(sentences, self.max_seq_length)
422 |         data_iter = build_iterator(data, device, self.batch_size, self.enable_ngram, self.n_gram_vocab, self.max_seq_length)
423 | # predict probs
424 | predict_all = np.array([], dtype=int)
425 | proba_all = np.array([], dtype=float)
426 | with torch.no_grad():
427 | for texts, _ in data_iter:
428 | outputs = self.model(texts)
429 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy()
430 | pred = np.argmax(logit, axis=1)
431 | proba = np.max(logit, axis=1)
432 |
433 | predict_all = np.append(predict_all, pred)
434 | proba_all = np.append(proba_all, proba)
435 | id_label_map = {v: k for k, v in self.label_id_map.items()}
436 | predict_labels = [id_label_map.get(i) for i in predict_all]
437 | predict_probs = proba_all.tolist()
438 | return predict_labels, predict_probs
439 |
440 | def evaluate_model(self, data_list_or_path, header=None,
441 | names=('labels', 'text'), delimiter='\t'):
442 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter)
443 | self.load_model()
444 | data, word_id_map, label_id_map = build_dataset(
445 | self.tokenizer, X_test, y_test,
446 | self.word_vocab_path,
447 | self.label_vocab_path,
448 | max_vocab_size=self.max_vocab_size,
449 | max_seq_length=self.max_seq_length,
450 | unk_token=self.unk_token,
451 | pad_token=self.pad_token,
452 | )
453 |         data_iter = build_iterator(data, device, self.batch_size, self.enable_ngram, self.n_gram_vocab, self.max_seq_length)
454 | return self.evaluate(data_iter)[0]
455 |
456 | def evaluate(self, data_iter):
457 | """
458 | Evaluate model.
459 | @param data_iter:
460 | @return: accuracy score, loss
461 | """
462 | if not self.model:
463 | raise ValueError('model not trained.')
464 | self.model.eval()
465 | loss_total = 0.0
466 | predict_all = np.array([], dtype=int)
467 | labels_all = np.array([], dtype=int)
468 | with torch.no_grad():
469 | for texts, labels in data_iter:
470 | outputs = self.model(texts)
471 | loss = F.cross_entropy(outputs, labels)
472 | loss_total += loss
473 | labels = labels.cpu().numpy()
474 | predic = torch.max(outputs, 1)[1].cpu().numpy()
475 | labels_all = np.append(labels_all, labels)
476 | predict_all = np.append(predict_all, predic)
477 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}")
478 | acc = metrics.accuracy_score(labels_all, predict_all)
479 | return acc, loss_total / len(data_iter)
480 |
481 | def load_model(self):
482 | """
483 | Load model from output_dir
484 | @return:
485 | """
486 | model_path = os.path.join(self.output_dir, 'model.pth')
487 | if os.path.exists(model_path):
488 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json')
489 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json')
490 | self.word_id_map = load_vocab(self.word_vocab_path)
491 | self.label_id_map = load_vocab(self.label_vocab_path)
492 | vocab_size = len(self.word_id_map)
493 | num_classes = len(self.label_id_map)
494 | self.model = FastTextModel(
495 | vocab_size, num_classes, self.embed_size, self.n_gram_vocab, self.hidden_size,
496 | self.dropout_rate
497 | )
498 | self.model.load_state_dict(torch.load(model_path, map_location=device))
499 | self.model.to(device)
500 | self.is_trained = True
501 | else:
502 | logger.error(f'{model_path} not exists.')
503 | self.is_trained = False
504 | return self.is_trained
505 |
506 |
507 | if __name__ == '__main__':
508 | parser = argparse.ArgumentParser(description='Text Classification')
509 | parser.add_argument('--output_dir', default='models/fasttext', type=str, help='save model dir')
510 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'),
511 | type=str, help='sample data file path')
512 | args = parser.parse_args()
513 | print(args)
514 | # create model
515 | m = FastTextClassifier(args.output_dir)
516 | # train model
517 | m.train(data_list_or_path=args.data_path, num_epochs=3)
518 | # load trained model and predict
519 | m.load_model()
520 | print('best model loaded from file, and predict')
521 | X, y, _ = load_data(args.data_path)
522 | X = X[:5]
523 | y = y[:5]
524 | predict_labels, predict_probs = m.predict(X)
525 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y):
526 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth)
527 |
--------------------------------------------------------------------------------
/pytextclassifier/textcluster.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os
7 | import sys
8 | from codecs import open
9 | import pickle
10 | from sklearn.cluster import MiniBatchKMeans
11 | from sklearn.feature_extraction.text import TfidfVectorizer
12 | from loguru import logger
13 |
14 | sys.path.append('..')
15 | from pytextclassifier.tokenizer import Tokenizer
16 |
17 | pwd_path = os.path.abspath(os.path.dirname(__file__))
18 | default_stopwords_path = os.path.join(pwd_path, 'stopwords.txt')
19 |
20 |
21 | class TextCluster(object):
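    """Text clustering: tokenize, build tf-idf features, then cluster (MiniBatchKMeans by default)."""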
22 | def __init__(
23 | self,
24 | output_dir="outputs",
25 | model=None, tokenizer=None, feature=None,
26 | stopwords_path=default_stopwords_path,
27 | n_clusters=3, n_init=10, ngram_range=(1, 2), **kwargs
28 | ):
29 | self.output_dir = output_dir
30 | self.model = model if model else MiniBatchKMeans(n_clusters=n_clusters, n_init=n_init)
31 | self.tokenizer = tokenizer if tokenizer else Tokenizer()
32 | self.feature = feature if feature else TfidfVectorizer(ngram_range=ngram_range, **kwargs)
33 | self.stopwords = set(self.load_list(stopwords_path)) if stopwords_path and os.path.exists(
34 | stopwords_path) else set()
35 | self.is_trained = False
36 |
37 | def __str__(self):
38 | return 'TextCluster instance ({}, {}, {})'.format(self.model, self.tokenizer, self.feature)
39 |
40 | @staticmethod
41 | def load_file_data(file_path, sep='\t', use_col=1):
42 | """
43 | Load text file, format(txt): text
44 | :param file_path: str
45 | :param sep: \t
46 | :param use_col: int or None
47 | :return: list, text list
48 | """
49 | contents = []
50 | if not os.path.exists(file_path):
51 | raise ValueError('file not found. path: {}'.format(file_path))
52 | with open(file_path, "r", encoding='utf-8') as f:
53 | for line in f:
54 | line = line.strip()
55 | if use_col:
56 | contents.append(line.split(sep)[use_col])
57 | else:
58 | contents.append(line)
59 | logger.info('load file done. path: {}, size: {}'.format(file_path, len(contents)))
60 | return contents
61 |
62 | @staticmethod
63 | def load_list(path):
64 | """
65 |         Load stopwords list from file
66 |         :param path: stopwords file path, whitespace-separated words
67 | :return: list
68 | """
69 | return [word for word in open(path, 'r', encoding='utf-8').read().split()]
70 |
71 | @staticmethod
72 | def load_pkl(pkl_path):
73 | """
74 |         Load a pickled object from file
75 |         :param pkl_path: pickle file path
76 |         :return: the unpickled object
77 | """
78 | with open(pkl_path, 'rb') as f:
79 | result = pickle.load(f)
80 | return result
81 |
82 | @staticmethod
83 | def save_pkl(vocab, pkl_path, overwrite=True):
84 | """
85 |         Save an object to a pickle file
86 |         :param pkl_path: pickle file path
87 |         :param overwrite: whether to overwrite an existing file
88 |         :return:
89 | """
90 | if pkl_path and os.path.exists(pkl_path) and not overwrite:
91 | return
92 | if pkl_path:
93 | with open(pkl_path, 'wb') as f:
94 | pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
95 |                 # pickle.dump(vocab, f, protocol=2)  # compatible with both python2 and python3
96 | print("save %s ok." % pkl_path)
97 |
98 | @staticmethod
99 | def show_plt(feature_matrix, labels, image_file='cluster.png'):
100 | """
101 | Show cluster plt
102 | :param feature_matrix:
103 | :param labels:
104 | :param image_file:
105 | :return:
106 | """
107 | from sklearn.decomposition import TruncatedSVD
108 | import matplotlib.pyplot as plt
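        # TruncatedSVD (2 components by default) projects the sparse tf-idf matrix to 2-D for plotting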
109 | svd = TruncatedSVD()
110 | plot_columns = svd.fit_transform(feature_matrix)
111 | plt.scatter(x=plot_columns[:, 0], y=plot_columns[:, 1], c=labels)
112 | if image_file:
113 | plt.savefig(image_file)
114 | plt.show()
115 |
116 | def tokenize_sentences(self, sentences):
117 | """
118 | Encoding input text
119 | :param sentences: list of text, eg: [text1, text2, ...]
120 | :return: X_tokens
121 | """
122 | X_tokens = [' '.join([w for w in self.tokenizer.tokenize(line) if w not in self.stopwords]) for line in
123 | sentences]
124 | return X_tokens
125 |
126 | def train(self, sentences):
127 | """
128 | Train model and save model
129 | :param sentences: list of text, eg: [text1, text2, ...]
130 | :return: model
131 | """
132 | logger.debug('train model')
133 | X_tokens = self.tokenize_sentences(sentences)
134 | logger.debug('data tokens top 1: {}'.format(X_tokens[:1]))
135 | feature = self.feature.fit_transform(X_tokens)
136 | # fit cluster
137 | self.model.fit(feature)
138 | labels = self.model.labels_
139 | logger.debug('cluster labels:{}'.format(labels))
140 | output_dir = self.output_dir
141 | if output_dir:
142 | os.makedirs(output_dir, exist_ok=True)
143 | feature_path = os.path.join(output_dir, 'cluster_feature.pkl')
144 | self.save_pkl(self.feature, feature_path)
145 | model_path = os.path.join(output_dir, 'cluster_model.pkl')
146 | self.save_pkl(self.model, model_path)
147 | logger.info('save done. feature path: {}, model path: {}'.format(feature_path, model_path))
148 |
149 | self.is_trained = True
150 | return feature, labels
151 |
152 | def predict(self, X):
153 | """
154 | Predict label
155 | :param X: list, input text list, eg: [text1, text2, ...]
156 | :return: list, label name
157 | """
158 | if not self.is_trained:
159 | raise ValueError('model is None, run train first.')
160 | # tokenize text
161 | X_tokens = self.tokenize_sentences(X)
162 | # transform
163 | feat = self.feature.transform(X_tokens)
164 | return self.model.predict(feat)
165 |
166 | def show_clusters(self, feature_matrix, labels, image_file='cluster.png'):
167 | """
168 | Show cluster plt image
169 | :param feature_matrix:
170 | :param labels:
171 | :param image_file:
172 | :return:
173 | """
174 | if not self.is_trained:
175 | raise ValueError('model is None, run train first.')
176 | self.show_plt(feature_matrix, labels, image_file)
177 |
178 | def load_model(self):
179 | """
180 | Load model from output_dir
181 |         :return: True if the model was loaded
182 |         :raises ValueError: if no saved model exists in output_dir
183 | """
184 | model_path = os.path.join(self.output_dir, 'cluster_model.pkl')
185 | if not os.path.exists(model_path):
186 | raise ValueError("model is not found. please train and save model first.")
187 | self.model = self.load_pkl(model_path)
188 | feature_path = os.path.join(self.output_dir, 'cluster_feature.pkl')
189 | self.feature = self.load_pkl(feature_path)
190 | self.is_trained = True
191 | logger.info('model loaded {}'.format(self.output_dir))
192 | return self.is_trained
193 |
194 |
195 | if __name__ == '__main__':
196 | m = TextCluster(output_dir='models/cluster', n_clusters=2)
197 | print(m)
198 | data = [
199 | 'Student debt to cost Britain billions within decades',
200 | 'Chinese education for TV experiment',
201 | 'Abbott government spends $8 million on higher education',
202 | 'Middle East and Asia boost investment in top level sports',
203 | 'Summit Series look launches HBO Canada sports doc series: Mudhar'
204 | ]
205 | feat, labels = m.train(data)
206 | m.show_clusters(feat, labels, image_file='models/cluster/cluster.png')
207 | m.load_model()
208 | r = m.predict(['Abbott government spends $8 million on higher education media blitz',
209 | 'Middle East and Asia boost investment in top level sports'])
210 | print(r)
211 |
--------------------------------------------------------------------------------
/pytextclassifier/textcnn_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: TextCNN Classifier
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import sys
10 | import time
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from loguru import logger
17 | from sklearn import metrics
18 | from sklearn.model_selection import train_test_split
19 |
20 | sys.path.append('..')
21 | from pytextclassifier.base_classifier import ClassifierABC, load_data
22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab
23 | from pytextclassifier.time_util import get_time_spend
24 |
25 | pwd_path = os.path.abspath(os.path.dirname(__file__))
26 | device = 'cuda' if torch.cuda.is_available() else (
27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')
28 |
29 |
30 | def build_dataset(tokenizer, X, y, word_vocab_path, label_vocab_path, max_seq_length=128,
31 | unk_token='[UNK]', pad_token='[PAD]', max_vocab_size=10000):
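    """
    Build or load the word/label vocabularies (cached as JSON next to the model) and convert
    (X, y) into (word_ids, label_id, seq_len) tuples padded/truncated to max_seq_length.
    """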
32 | if os.path.exists(word_vocab_path):
33 | word_id_map = json.load(open(word_vocab_path, 'r', encoding='utf-8'))
34 | else:
35 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1,
36 | unk_token=unk_token, pad_token=pad_token)
37 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
38 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}")
39 |
40 | if os.path.exists(label_vocab_path):
41 | label_id_map = json.load(open(label_vocab_path, 'r', encoding='utf-8'))
42 | else:
43 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))}
44 | label_id_map = {v: k for k, v in id_label_map.items()}
45 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
46 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}")
47 |
48 | def load_dataset(X, y, max_seq_length=128):
49 | contents = []
50 | for content, label in zip(X, y):
51 | words_line = []
52 | token = tokenizer(content)
53 | seq_len = len(token)
54 | if max_seq_length:
55 | if len(token) < max_seq_length:
56 | token.extend([pad_token] * (max_seq_length - len(token)))
57 | else:
58 | token = token[:max_seq_length]
59 | seq_len = max_seq_length
60 | # word to id
61 | for word in token:
62 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token)))
63 | label_id = label_id_map.get(label)
64 | contents.append((words_line, label_id, seq_len))
65 | return contents
66 |
67 | dataset = load_dataset(X, y, max_seq_length)
68 | return dataset, word_id_map, label_id_map
69 |
70 |
71 | class DatasetIterater:
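    """Batch iterator yielding ((token_ids, seq_len), labels) tensors on `device`."""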
72 | def __init__(self, dataset, device, batch_size=32):
73 | self.batch_size = batch_size
74 | self.dataset = dataset
75 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1
76 |         self.residue = False  # whether there is a final partial (residue) batch
77 |         if len(dataset) > batch_size and len(dataset) % batch_size != 0:
78 |             self.residue = True
79 | self.index = 0
80 | self.device = device
81 |
82 | def _to_tensor(self, datas):
83 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
84 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
85 |
86 |         # sequence length before padding (capped at max_seq_length)
87 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
88 | return (x, seq_len), y
89 |
90 | def __next__(self):
91 | if self.residue and self.index == self.n_batches:
92 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)]
93 | self.index += 1
94 | batches = self._to_tensor(batches)
95 | return batches
96 |
97 | elif self.index >= self.n_batches:
98 | self.index = 0
99 | raise StopIteration
100 | else:
101 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size]
102 | self.index += 1
103 | batches = self._to_tensor(batches)
104 | return batches
105 |
106 | def __iter__(self):
107 | return self
108 |
109 | def __len__(self):
110 | if self.residue:
111 | return self.n_batches + 1
112 | else:
113 | return self.n_batches
114 |
115 |
116 | def build_iterator(dataset, device, batch_size=32):
117 | return DatasetIterater(dataset, device, batch_size)
118 |
119 |
120 | class TextCNNModel(nn.Module):
121 | """Convolutional Neural Networks for Sentence Classification"""
122 |
123 | def __init__(
124 | self,
125 | vocab_size,
126 | num_classes,
127 | embed_size=200, filter_sizes=(2, 3, 4), num_filters=256, dropout_rate=0.5
128 | ):
129 | """
130 | Init the TextCNNModel
131 | @param vocab_size:
132 | @param num_classes:
133 | @param embed_size:
134 |         @param filter_sizes: convolution kernel sizes
135 |         @param num_filters: number of kernels per size (output channels)
136 | @param dropout_rate:
137 | """
138 | super().__init__()
139 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1)
140 | self.convs = nn.ModuleList(
141 | [nn.Conv2d(1, num_filters, (k, embed_size)) for k in filter_sizes])
142 | self.dropout = nn.Dropout(dropout_rate)
143 | self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)
144 |
145 | def conv_and_pool(self, x, conv):
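        # convolve over the full embedding width, then max-pool over the time (sequence) dimension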
146 | x = F.relu(conv(x)).squeeze(3)
147 | x = F.max_pool1d(x, x.size(2)).squeeze(2)
148 | return x
149 |
150 | def forward(self, x):
151 | out = self.embedding(x[0])
152 | out = out.unsqueeze(1)
153 | out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
154 | out = self.dropout(out)
155 | out = self.fc(out)
156 | return out
157 |
158 |
159 | class TextCNNClassifier(ClassifierABC):
160 | def __init__(
161 | self,
162 | output_dir="outputs",
163 | filter_sizes=(2, 3, 4), num_filters=256,
164 | dropout_rate=0.5, batch_size=64, max_seq_length=128,
165 | embed_size=200, max_vocab_size=10000,
166 | unk_token='[UNK]', pad_token='[PAD]', tokenizer=None,
167 | ):
168 | """
169 | Init the TextCNNClassifier
170 |         @param output_dir: dir to save the trained model
171 |         @param filter_sizes: convolution kernel sizes
172 |         @param num_filters: number of kernels per size (output channels)
173 | @param dropout_rate:
174 | @param batch_size:
175 | @param max_seq_length:
176 | @param embed_size:
177 | @param max_vocab_size:
178 | @param unk_token:
179 | @param pad_token:
180 |         @param tokenizer: tokenizer, defaults to char-level split
181 | """
182 | self.output_dir = output_dir
183 | self.is_trained = False
184 | self.model = None
185 | logger.debug(f'device: {device}')
186 | self.filter_sizes = filter_sizes
187 | self.num_filters = num_filters
188 | self.dropout_rate = dropout_rate
189 | self.batch_size = batch_size
190 | self.max_seq_length = max_seq_length
191 | self.embed_size = embed_size
192 | self.max_vocab_size = max_vocab_size
193 | self.unk_token = unk_token
194 | self.pad_token = pad_token
195 | self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x] # char-level
196 |
197 | def __str__(self):
198 | return f'TextCNNClassifier instance ({self.model})'
199 |
200 | def train(
201 | self,
202 | data_list_or_path,
203 | header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1,
204 | num_epochs=20, learning_rate=1e-3,
205 | require_improvement=1000, evaluate_during_training_steps=100
206 | ):
207 | """
208 | Train model with data_list_or_path and save model to output_dir
209 |         @param data_list_or_path: training data, a list of [label, text] pairs or a file path
210 |         @param header: header row of the data file, None means no header
211 |         @param names: column names of the data file
212 |         @param delimiter: column delimiter of the data file
213 |         @param test_size: dev split ratio
214 |         @param num_epochs: number of training epochs
215 |         @param learning_rate: learning rate
216 |         @param require_improvement: stop early if the dev loss has not improved
217 |                 after this many batches
218 |         @param evaluate_during_training_steps: evaluate the model every this many training steps
219 | @return:
220 | """
221 | logger.debug('train model...')
222 | SEED = 1
223 | set_seed(SEED)
224 | # load data
225 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
226 | output_dir = self.output_dir
227 | if output_dir:
228 | os.makedirs(output_dir, exist_ok=True)
229 | word_vocab_path = os.path.join(output_dir, 'word_vocab.json')
230 | label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
231 | save_model_path = os.path.join(output_dir, 'model.pth')
232 |
233 | dataset, self.word_id_map, self.label_id_map = build_dataset(
234 | self.tokenizer, X, y,
235 | word_vocab_path,
236 | label_vocab_path,
237 | max_vocab_size=self.max_vocab_size,
238 | max_seq_length=self.max_seq_length,
239 | unk_token=self.unk_token, pad_token=self.pad_token
240 | )
241 | train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
242 | logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
243 | logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
244 | train_iter = build_iterator(train_data, device, self.batch_size)
245 | dev_iter = build_iterator(dev_data, device, self.batch_size)
246 | # create model
247 | vocab_size = len(self.word_id_map)
248 | num_classes = len(self.label_id_map)
249 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
250 | self.model = TextCNNModel(
251 | vocab_size, num_classes,
252 | embed_size=self.embed_size,
253 | filter_sizes=self.filter_sizes,
254 | num_filters=self.num_filters,
255 | dropout_rate=self.dropout_rate
256 | )
257 | self.model.to(device)
258 | # init_network(self.model)
259 | logger.info(self.model.parameters)
260 | # train model
261 | history = self.train_model_from_data_iterator(save_model_path, train_iter, dev_iter, num_epochs, learning_rate,
262 | require_improvement, evaluate_during_training_steps)
263 | self.is_trained = True
264 | logger.debug('train model done')
265 | return history
266 |
267 | def train_model_from_data_iterator(self, save_model_path, train_iter, dev_iter,
268 | num_epochs=10, learning_rate=1e-3,
269 | require_improvement=1000, evaluate_during_training_steps=100):
270 | history = []
271 | # train
272 | start_time = time.time()
273 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
274 |
275 |         total_batch = 0  # number of batches processed so far
276 |         dev_best_loss = 1e10
277 |         last_improve = 0  # batch index of the last dev-loss improvement
278 |         flag = False  # whether early stopping was triggered
279 | for epoch in range(num_epochs):
280 | logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs))
281 | for i, (trains, labels) in enumerate(train_iter):
282 | self.model.train()
283 | outputs = self.model(trains)
284 | loss = F.cross_entropy(outputs, labels)
285 | # compute gradient and do SGD step
286 | optimizer.zero_grad()
287 | loss.backward()
288 | optimizer.step()
289 | if total_batch % evaluate_during_training_steps == 0:
290 |                 # report metrics on the train set and the dev set
291 | y_true = labels.cpu()
292 | y_pred = torch.max(outputs, 1)[1].cpu()
293 | train_acc = metrics.accuracy_score(y_true, y_pred)
294 | if dev_iter is not None:
295 | dev_acc, dev_loss = self.evaluate(dev_iter)
296 | if dev_loss < dev_best_loss:
297 | dev_best_loss = dev_loss
298 | torch.save(self.model.state_dict(), save_model_path)
299 | logger.debug(f'Saved model: {save_model_path}')
300 | improve = '*'
301 | last_improve = total_batch
302 | else:
303 | improve = ''
304 | time_dif = get_time_spend(start_time)
305 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format(
306 | total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)
307 | else:
308 | time_dif = get_time_spend(start_time)
309 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format(
310 | total_batch, loss.item(), train_acc, time_dif)
311 | logger.debug(msg)
312 | history.append(msg)
313 | self.model.train()
314 | total_batch += 1
315 | if total_batch - last_improve > require_improvement:
316 |                     # dev loss has not improved within require_improvement batches, stop training early
317 | logger.debug("No optimization for a long time, auto-stopping...")
318 | flag = True
319 | break
320 | if flag:
321 | break
322 | return history
323 |
324 | def predict(self, sentences: list):
325 | """
326 | Predict labels and label probability for sentences.
327 | @param sentences: list, input text list, eg: [text1, text2, ...]
328 | @return: predict_label, predict_prob
329 | """
330 | if not self.is_trained:
331 | raise ValueError('model not trained.')
332 | self.model.eval()
333 |
334 | def load_dataset(X, max_seq_length=128):
335 | contents = []
336 | for content in X:
337 | words_line = []
338 | token = self.tokenizer(content)
339 | seq_len = len(token)
340 | if max_seq_length:
341 | if len(token) < max_seq_length:
342 | token.extend([self.pad_token] * (max_seq_length - len(token)))
343 | else:
344 | token = token[:max_seq_length]
345 | seq_len = max_seq_length
346 | # word to id
347 | for word in token:
348 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token)))
349 | contents.append((words_line, 0, seq_len))
350 | return contents
351 |
352 | data = load_dataset(sentences, self.max_seq_length)
353 | data_iter = build_iterator(data, device, self.batch_size)
354 | # predict prob
355 | predict_all = np.array([], dtype=int)
356 | proba_all = np.array([], dtype=float)
357 | with torch.no_grad():
358 | for texts, _ in data_iter:
359 | outputs = self.model(texts)
360 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy()
361 | pred = np.argmax(logit, axis=1)
362 | proba = np.max(logit, axis=1)
363 |
364 | predict_all = np.append(predict_all, pred)
365 | proba_all = np.append(proba_all, proba)
366 | id_label_map = {v: k for k, v in self.label_id_map.items()}
367 | predict_labels = [id_label_map.get(i) for i in predict_all]
368 | predict_probs = proba_all.tolist()
369 | return predict_labels, predict_probs
370 |
371 | def evaluate_model(self, data_list_or_path, header=None,
372 | names=('labels', 'text'), delimiter='\t'):
373 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter)
374 | self.load_model()
375 | data, word_id_map, label_id_map = build_dataset(
376 | self.tokenizer, X_test, y_test,
377 | self.word_vocab_path,
378 | self.label_vocab_path,
379 | max_vocab_size=self.max_vocab_size,
380 | max_seq_length=self.max_seq_length,
381 | unk_token=self.unk_token,
382 | pad_token=self.pad_token,
383 | )
384 | data_iter = build_iterator(data, device, self.batch_size)
385 | return self.evaluate(data_iter)[0]
386 |
387 | def evaluate(self, data_iter):
388 | """
389 | Evaluate model.
390 | @param data_iter:
391 | @return: accuracy score, loss
392 | """
393 | if not self.model:
394 | raise ValueError('model not trained.')
395 | self.model.eval()
396 | loss_total = 0.0
397 | predict_all = np.array([], dtype=int)
398 | labels_all = np.array([], dtype=int)
399 | with torch.no_grad():
400 | for texts, labels in data_iter:
401 | outputs = self.model(texts)
402 | loss = F.cross_entropy(outputs, labels)
403 | loss_total += loss
404 | labels = labels.cpu().numpy()
405 | predic = torch.max(outputs, 1)[1].cpu().numpy()
406 | labels_all = np.append(labels_all, labels)
407 | predict_all = np.append(predict_all, predic)
408 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}")
409 | acc = metrics.accuracy_score(labels_all, predict_all)
410 | return acc, loss_total / len(data_iter)
411 |
412 | def load_model(self):
413 | """
414 | Load model from output_dir
415 | @return:
416 | """
417 | model_path = os.path.join(self.output_dir, 'model.pth')
418 | if os.path.exists(model_path):
419 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json')
420 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json')
421 | self.word_id_map = load_vocab(self.word_vocab_path)
422 | self.label_id_map = load_vocab(self.label_vocab_path)
423 | vocab_size = len(self.word_id_map)
424 | num_classes = len(self.label_id_map)
425 | self.model = TextCNNModel(
426 | vocab_size, num_classes,
427 | embed_size=self.embed_size,
428 | filter_sizes=self.filter_sizes,
429 | num_filters=self.num_filters,
430 | dropout_rate=self.dropout_rate
431 | )
432 | self.model.load_state_dict(torch.load(model_path, map_location=device))
433 | self.model.to(device)
434 | self.is_trained = True
435 | else:
436 | logger.error(f'{model_path} not exists.')
437 | self.is_trained = False
438 | return self.is_trained
439 |
440 |
441 | if __name__ == '__main__':
442 | parser = argparse.ArgumentParser(description='Text Classification')
443 | parser.add_argument('--output_dir', default='models/textcnn', type=str, help='save model dir')
444 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'),
445 | type=str, help='sample data file path')
446 | args = parser.parse_args()
447 | print(args)
448 | # create model
449 | m = TextCNNClassifier(output_dir=args.output_dir)
450 | # train model
451 | m.train(data_list_or_path=args.data_path, num_epochs=3)
452 | # load trained model and predict
453 | m.load_model()
454 | print('best model loaded from file, and predict')
455 | X, y, _ = load_data(args.data_path)
456 | X = X[:5]
457 | y = y[:5]
458 | predict_labels, predict_probs = m.predict(X)
459 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y):
460 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth)
461 |
--------------------------------------------------------------------------------
/pytextclassifier/textrnn_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: TextRNN Attention Classifier
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import sys
10 | import time
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from loguru import logger
17 | from sklearn import metrics
18 | from sklearn.model_selection import train_test_split
19 |
20 | sys.path.append('..')
21 | from pytextclassifier.base_classifier import ClassifierABC, load_data
22 | from pytextclassifier.data_helper import set_seed, build_vocab, load_vocab
23 | from pytextclassifier.time_util import get_time_spend
24 |
25 | pwd_path = os.path.abspath(os.path.dirname(__file__))
26 | device = 'cuda' if torch.cuda.is_available() else (
27 | 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')
28 |
29 |
30 | def build_dataset(tokenizer, X, y, word_vocab_path, label_vocab_path, max_seq_length=128,
31 | unk_token='[UNK]', pad_token='[PAD]', max_vocab_size=10000):
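    """
    Build or load the word/label vocabularies (cached as JSON next to the model) and convert
    (X, y) into (word_ids, label_id, seq_len) tuples padded/truncated to max_seq_length.
    """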
32 | if os.path.exists(word_vocab_path):
33 | word_id_map = json.load(open(word_vocab_path, 'r', encoding='utf-8'))
34 | else:
35 | word_id_map = build_vocab(X, tokenizer=tokenizer, max_size=max_vocab_size, min_freq=1,
36 | unk_token=unk_token, pad_token=pad_token)
37 | json.dump(word_id_map, open(word_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
38 | logger.debug(f"word vocab size: {len(word_id_map)}, word_vocab_path: {word_vocab_path}")
39 |
40 | if os.path.exists(label_vocab_path):
41 | label_id_map = json.load(open(label_vocab_path, 'r', encoding='utf-8'))
42 | else:
43 | id_label_map = {id: v for id, v in enumerate(set(y.tolist()))}
44 | label_id_map = {v: k for k, v in id_label_map.items()}
45 | json.dump(label_id_map, open(label_vocab_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
46 | logger.debug(f"label vocab size: {len(label_id_map)}, label_vocab_path: {label_vocab_path}")
47 |
48 | def load_dataset(X, y, max_seq_length=128):
49 | contents = []
50 | for content, label in zip(X, y):
51 | words_line = []
52 | token = tokenizer(content)
53 | seq_len = len(token)
54 | if max_seq_length:
55 | if len(token) < max_seq_length:
56 | token.extend([pad_token] * (max_seq_length - len(token)))
57 | else:
58 | token = token[:max_seq_length]
59 | seq_len = max_seq_length
60 | # word to id
61 | for word in token:
62 | words_line.append(word_id_map.get(word, word_id_map.get(unk_token)))
63 | label_id = label_id_map.get(label)
64 | contents.append((words_line, label_id, seq_len))
65 | return contents
66 |
67 | dataset = load_dataset(X, y, max_seq_length)
68 | return dataset, word_id_map, label_id_map
69 |
70 |
71 | class DatasetIterater:
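    """Batch iterator yielding ((token_ids, seq_len), labels) tensors on `device`."""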
72 | def __init__(self, dataset, device, batch_size=32):
73 | self.batch_size = batch_size
74 | self.dataset = dataset
75 | self.n_batches = len(dataset) // batch_size if len(dataset) > batch_size else 1
76 |         self.residue = False  # whether there is a final partial (residue) batch
77 |         if len(dataset) > batch_size and len(dataset) % batch_size != 0:
78 |             self.residue = True
79 | self.index = 0
80 | self.device = device
81 |
82 | def _to_tensor(self, datas):
83 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
84 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
85 |
86 |         # sequence length before padding (capped at max_seq_length)
87 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
88 | return (x, seq_len), y
89 |
90 | def __next__(self):
91 | if self.residue and self.index == self.n_batches:
92 | batches = self.dataset[self.index * self.batch_size: len(self.dataset)]
93 | self.index += 1
94 | batches = self._to_tensor(batches)
95 | return batches
96 |
97 | elif self.index >= self.n_batches:
98 | self.index = 0
99 | raise StopIteration
100 | else:
101 | batches = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size]
102 | self.index += 1
103 | batches = self._to_tensor(batches)
104 | return batches
105 |
106 | def __iter__(self):
107 | return self
108 |
109 | def __len__(self):
110 | if self.residue:
111 | return self.n_batches + 1
112 | else:
113 | return self.n_batches
114 |
115 |
116 | def build_iterator(dataset, device, batch_size=32):
117 | return DatasetIterater(dataset, device, batch_size)
118 |
119 |
120 | class TextRNNAttModel(nn.Module):
121 | """Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification"""
122 |
123 | def __init__(
124 | self,
125 | vocab_size,
126 | num_classes,
127 | embed_size=200,
128 | hidden_size=128,
129 | num_layers=2,
130 | dropout_rate=0.5,
131 | ):
132 | """
133 | TextRNN-Att model
134 | @param vocab_size:
135 | @param num_classes:
136 |         @param embed_size: embedding dimension
137 |         @param hidden_size: lstm hidden size
138 |         @param num_layers: number of lstm layers
139 | @param dropout_rate:
140 | """
141 | super().__init__()
142 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1)
143 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
144 | bidirectional=True, batch_first=True, dropout=dropout_rate)
145 | self.tanh1 = nn.Tanh()
146 | # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2))
147 | self.w = nn.Parameter(torch.zeros(hidden_size * 2))
148 | self.tanh2 = nn.Tanh()
149 | self.fc1 = nn.Linear(hidden_size * 2, int(hidden_size / 2))
150 | self.fc2 = nn.Linear(int(hidden_size / 2), num_classes)
151 |
152 | def forward(self, x):
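        # x = (token_ids, seq_len); a bidirectional LSTM encodes the sequence and a learned
        # attention vector (self.w) weights the hidden states before classification.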
153 |         emb = self.embedding(x[0])  # [batch_size, seq_len, embed_size]
154 |         H, _ = self.lstm(emb)  # [batch_size, seq_len, hidden_size * 2] (bidirectional)
155 |
156 |         M = self.tanh1(H)  # [batch_size, seq_len, hidden_size * 2]
157 |         # M = torch.tanh(torch.matmul(H, self.u))
158 |         alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1)  # attention weights [batch_size, seq_len, 1]
159 |         out = H * alpha  # [batch_size, seq_len, hidden_size * 2]
160 |         out = torch.sum(out, 1)  # attention-weighted sum over time: [batch_size, hidden_size * 2]
161 |         out = F.relu(out)
162 |         out = self.fc1(out)
163 |         out = self.fc2(out)  # [batch_size, num_classes]
164 |         return out
165 |
166 |
167 | class TextRNNClassifier(ClassifierABC):
168 | def __init__(
169 | self,
170 | output_dir="outputs",
171 | hidden_size=128,
172 | num_layers=2,
173 | dropout_rate=0.5, batch_size=64, max_seq_length=128,
174 | embed_size=200, max_vocab_size=10000, unk_token='[UNK]',
175 | pad_token='[PAD]', tokenizer=None,
176 | ):
177 | """
178 | Init the TextRNNClassifier
179 |         @param output_dir: dir to save the trained model
180 |         @param hidden_size: lstm hidden size
181 |         @param num_layers: number of lstm layers
182 | @param dropout_rate:
183 | @param batch_size:
184 | @param max_seq_length:
185 | @param embed_size:
186 | @param max_vocab_size:
187 | @param unk_token:
188 | @param pad_token:
189 |         @param tokenizer: tokenizer, defaults to char-level split
190 | """
191 | self.output_dir = output_dir
192 | self.is_trained = False
193 | self.model = None
194 | logger.debug(f'device: {device}')
195 | self.hidden_size = hidden_size
196 | self.num_layers = num_layers
197 | self.dropout_rate = dropout_rate
198 | self.batch_size = batch_size
199 | self.max_seq_length = max_seq_length
200 | self.embed_size = embed_size
201 | self.max_vocab_size = max_vocab_size
202 | self.unk_token = unk_token
203 | self.pad_token = pad_token
204 | self.tokenizer = tokenizer if tokenizer else lambda x: [y for y in x] # char-level
205 |
206 | def __str__(self):
207 | return f'TextRNNClassifier instance ({self.model})'
208 |
209 | def train(
210 | self,
211 | data_list_or_path,
212 | header=None, names=('labels', 'text'), delimiter='\t', test_size=0.1,
213 | num_epochs=20, learning_rate=1e-3,
214 | require_improvement=1000, evaluate_during_training_steps=100
215 | ):
216 | """
217 | Train model with data_list_or_path and save model to output_dir
218 | @param data_list_or_path:
219 | @param header:
220 | @param names:
221 | @param delimiter:
222 | @param test_size:
223 |         @param num_epochs: number of training epochs
224 |         @param learning_rate: learning rate
225 |         @param require_improvement: stop early if the dev loss has not improved after this many batches
226 |         @param evaluate_during_training_steps: evaluate the model every this many training steps
227 | @return:
228 | """
229 | logger.debug('train model...')
230 | SEED = 1
231 | set_seed(SEED)
232 | # load data
233 | X, y, data_df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter, is_train=True)
234 | output_dir = self.output_dir
235 | if output_dir:
236 | os.makedirs(output_dir, exist_ok=True)
237 | word_vocab_path = os.path.join(output_dir, 'word_vocab.json')
238 | label_vocab_path = os.path.join(output_dir, 'label_vocab.json')
239 | save_model_path = os.path.join(output_dir, 'model.pth')
240 |
241 | dataset, self.word_id_map, self.label_id_map = build_dataset(
242 | self.tokenizer, X, y,
243 | word_vocab_path,
244 | label_vocab_path,
245 | max_vocab_size=self.max_vocab_size,
246 | max_seq_length=self.max_seq_length,
247 | unk_token=self.unk_token, pad_token=self.pad_token
248 | )
249 | train_data, dev_data = train_test_split(dataset, test_size=test_size, random_state=SEED)
250 | logger.debug(f"train_data size: {len(train_data)}, dev_data size: {len(dev_data)}")
251 | logger.debug(f'train_data sample:\n{train_data[:3]}\ndev_data sample:\n{dev_data[:3]}')
252 | train_iter = build_iterator(train_data, device, self.batch_size)
253 | dev_iter = build_iterator(dev_data, device, self.batch_size)
254 | # create model
255 | vocab_size = len(self.word_id_map)
256 | num_classes = len(self.label_id_map)
257 |         logger.debug(f'vocab_size: {vocab_size}, num_classes: {num_classes}')
258 | self.model = TextRNNAttModel(
259 | vocab_size, num_classes,
260 | embed_size=self.embed_size,
261 | hidden_size=self.hidden_size,
262 | num_layers=self.num_layers,
263 | dropout_rate=self.dropout_rate
264 | )
265 | self.model.to(device)
266 | # init_network(self.model)
267 | logger.info(self.model.parameters)
268 | # train model
269 | history = self.train_model_from_data_iterator(save_model_path, train_iter, dev_iter, num_epochs, learning_rate,
270 | require_improvement, evaluate_during_training_steps)
271 | self.is_trained = True
272 | logger.debug('train model done')
273 | return history
274 |
275 | def train_model_from_data_iterator(self, save_model_path, train_iter, dev_iter,
276 | num_epochs=10, learning_rate=1e-3,
277 | require_improvement=1000, evaluate_during_training_steps=100):
278 | history = []
279 | # train
280 | start_time = time.time()
281 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
282 |
283 |         total_batch = 0  # number of batches processed so far
284 |         dev_best_loss = 1e10
285 |         last_improve = 0  # batch index of the last dev-loss improvement
286 |         flag = False  # whether early stopping was triggered
287 | for epoch in range(num_epochs):
288 | logger.debug('Epoch [{}/{}]'.format(epoch + 1, num_epochs))
289 | for i, (trains, labels) in enumerate(train_iter):
290 | self.model.train()
291 | outputs = self.model(trains)
292 | loss = F.cross_entropy(outputs, labels)
293 | # compute gradient and do SGD step
294 | optimizer.zero_grad()
295 | loss.backward()
296 | optimizer.step()
297 | if total_batch % evaluate_during_training_steps == 0:
298 |                 # report metrics on the train set and the dev set
299 | y_true = labels.cpu()
300 | y_pred = torch.max(outputs, 1)[1].cpu()
301 | train_acc = metrics.accuracy_score(y_true, y_pred)
302 | if dev_iter is not None:
303 | dev_acc, dev_loss = self.evaluate(dev_iter)
304 | if dev_loss < dev_best_loss:
305 | dev_best_loss = dev_loss
306 | torch.save(self.model.state_dict(), save_model_path)
307 | logger.debug(f'Saved model: {save_model_path}')
308 | improve = '*'
309 | last_improve = total_batch
310 | else:
311 | improve = ''
312 | time_dif = get_time_spend(start_time)
313 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Val Loss:{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'.format(
314 | total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)
315 | else:
316 | time_dif = get_time_spend(start_time)
317 | msg = 'Iter:{0:>6},Train Loss:{1:>5.2},Train Acc:{2:>6.2%},Time:{3}'.format(
318 | total_batch, loss.item(), train_acc, time_dif)
319 | logger.debug(msg)
320 | history.append(msg)
321 | self.model.train()
322 | total_batch += 1
323 | if total_batch - last_improve > require_improvement:
324 | # dev loss has not improved for require_improvement batches, stop training
325 | logger.debug("No optimization for a long time, auto-stopping...")
326 | flag = True
327 | break
328 | if flag:
329 | break
330 | return history
331 |
332 | def predict(self, sentences: list):
333 | """
334 | Predict labels and label probability for sentences.
335 | @param sentences: list, input text list, eg: [text1, text2, ...]
336 | @return: predict_label, predict_prob
337 | """
338 | if not self.is_trained:
339 | raise ValueError('model not trained.')
340 | self.model.eval()
341 |
342 | def load_dataset(X, max_seq_length=128):
343 | contents = []
344 | for content in X:
345 | words_line = []
346 | token = self.tokenizer(content)
347 | seq_len = len(token)
348 | if max_seq_length:
349 | if len(token) < max_seq_length:
350 | token.extend([self.pad_token] * (max_seq_length - len(token)))
351 | else:
352 | token = token[:max_seq_length]
353 | seq_len = max_seq_length
354 | # word to id
355 | for word in token:
356 | words_line.append(self.word_id_map.get(word, self.word_id_map.get(self.unk_token)))
357 | contents.append((words_line, 0, seq_len))  # (token ids, placeholder label, seq length)
358 | return contents
359 |
360 | data = load_dataset(sentences, self.max_seq_length)
361 | data_iter = build_iterator(data, device, self.batch_size)
362 | # predict prob
363 | predict_all = np.array([], dtype=int)
364 | proba_all = np.array([], dtype=float)
365 | with torch.no_grad():
366 | for texts, _ in data_iter:
367 | outputs = self.model(texts)
368 | logit = F.softmax(outputs, dim=1).detach().cpu().numpy()
369 | pred = np.argmax(logit, axis=1)
370 | proba = np.max(logit, axis=1)
371 |
372 | predict_all = np.append(predict_all, pred)
373 | proba_all = np.append(proba_all, proba)
374 | id_label_map = {v: k for k, v in self.label_id_map.items()}
375 | predict_labels = [id_label_map.get(i) for i in predict_all]
376 | predict_probs = proba_all.tolist()
377 | return predict_labels, predict_probs
378 |
379 | def evaluate_model(self, data_list_or_path, header=None,
380 | names=('labels', 'text'), delimiter='\t'):
381 | """
382 | Evaluate model.
383 | @param data_list_or_path: list of (label, text) pairs or a delimited text file path
384 | @param header: header setting passed to the data loader
385 | @param names: column names, default ('labels', 'text')
386 | @param delimiter: column delimiter, default '\t'
387 | @return: accuracy score
388 | """
389 | X_test, y_test, df = load_data(data_list_or_path, header=header, names=names, delimiter=delimiter)
390 | self.load_model()
391 | data, word_id_map, label_id_map = build_dataset(
392 | self.tokenizer, X_test, y_test,
393 | self.word_vocab_path,
394 | self.label_vocab_path,
395 | max_vocab_size=self.max_vocab_size,
396 | max_seq_length=self.max_seq_length,
397 | unk_token=self.unk_token,
398 | pad_token=self.pad_token,
399 | )
400 | data_iter = build_iterator(data, device, self.batch_size)
401 | return self.evaluate(data_iter)[0]
402 |
403 | def evaluate(self, data_iter):
404 | """
405 | Evaluate model.
406 | @param data_iter: batch iterator yielding (texts, labels)
407 | @return: accuracy score, loss
408 | """
409 | if not self.model:
410 | raise ValueError('model not trained.')
411 | self.model.eval()
412 | loss_total = 0.0
413 | predict_all = np.array([], dtype=int)
414 | labels_all = np.array([], dtype=int)
415 | with torch.no_grad():
416 | for texts, labels in data_iter:
417 | outputs = self.model(texts)
418 | loss = F.cross_entropy(outputs, labels)
419 | loss_total += loss
420 | labels = labels.cpu().numpy()
421 | predic = torch.max(outputs, 1)[1].cpu().numpy()
422 | labels_all = np.append(labels_all, labels)
423 | predict_all = np.append(predict_all, predic)
424 | logger.debug(f"evaluate, last batch, y_true: {labels}, y_pred: {predic}")
425 | acc = metrics.accuracy_score(labels_all, predict_all)
426 | return acc, loss_total / len(data_iter)
427 |
428 | def load_model(self):
429 | """
430 | Load model from output_dir
431 | @return: bool, True if the model was loaded successfully
432 | """
433 | model_path = os.path.join(self.output_dir, 'model.pth')
434 | if os.path.exists(model_path):
435 | self.word_vocab_path = os.path.join(self.output_dir, 'word_vocab.json')
436 | self.label_vocab_path = os.path.join(self.output_dir, 'label_vocab.json')
437 | self.word_id_map = load_vocab(self.word_vocab_path)
438 | self.label_id_map = load_vocab(self.label_vocab_path)
439 | vocab_size = len(self.word_id_map)
440 | num_classes = len(self.label_id_map)
441 | self.model = TextRNNAttModel(
442 | vocab_size, num_classes,
443 | embed_size=self.embed_size,
444 | hidden_size=self.hidden_size,
445 | num_layers=self.num_layers,
446 | dropout_rate=self.dropout_rate
447 | )
448 | self.model.load_state_dict(torch.load(model_path, map_location=device))
449 | self.model.to(device)
450 | self.is_trained = True
451 | else:
452 | logger.error(f'{model_path} not exists.')
453 | self.is_trained = False
454 | return self.is_trained
455 |
456 |
457 | if __name__ == '__main__':
458 | parser = argparse.ArgumentParser(description='Text Classification')
459 | parser.add_argument('--output_dir', default='models/textrnn', type=str, help='save model dir')
460 | parser.add_argument('--data_path', default=os.path.join(pwd_path, '../examples/thucnews_train_1w.txt'),
461 | type=str, help='sample data file path')
462 | args = parser.parse_args()
463 | print(args)
464 | # create model
465 | m = TextRNNClassifier(args.output_dir)
466 | # train model
467 | m.train(data_list_or_path=args.data_path, num_epochs=3)
468 | # load the best trained model and predict
469 | m.load_model()
470 | print('best model loaded from file, predicting')
471 | X, y, _ = load_data(args.data_path)
472 | X = X[:5]
473 | y = y[:5]
474 | predict_labels, predict_probs = m.predict(X)
475 | for text, pred_label, pred_prob, y_truth in zip(X, predict_labels, predict_probs, y):
476 | print(text, 'pred:', pred_label, pred_prob, ' truth:', y_truth)
477 |
--------------------------------------------------------------------------------
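
A minimal sketch of evaluating a trained TextRNNClassifier on a labelled file, assuming a model has already been trained into models/textrnn (as in the __main__ demo above) and using the sample file shipped in examples/ (paths adjusted as needed):

    from pytextclassifier.textrnn_classifier import TextRNNClassifier

    m = TextRNNClassifier('models/textrnn')
    m.load_model()
    # evaluate_model reads a "label<TAB>text" file and returns accuracy
    acc = m.evaluate_model('examples/thucnews_train_1w.txt')
    print('accuracy:', acc)
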
/pytextclassifier/time_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import time
7 | from datetime import timedelta
8 |
9 |
10 | def get_time_spend(start_time):
11 | """获取已使用时间"""
12 | end_time = time.time()
13 | time_dif = end_time - start_time
14 | return timedelta(seconds=int(round(time_dif)))
15 |
16 |
17 | def init_network(model, method='xavier', exclude='embedding'):
18 | """权重初始化,默认xavier"""
19 | import torch.nn as nn
20 | for name, w in model.named_parameters():
21 | if exclude not in name:
22 | if 'weight' in name:
23 | if method == 'xavier':
24 | nn.init.xavier_normal_(w)
25 | elif method == 'kaiming':
26 | nn.init.kaiming_normal_(w)
27 | else:
28 | nn.init.normal_(w)
29 | if 'bias' in name:
30 | nn.init.constant_(w, 0)
31 |
--------------------------------------------------------------------------------
/pytextclassifier/tokenizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description: Tokenization
5 | """
6 | import re
7 |
8 | import jieba
9 |
10 |
11 | def tokenize_words(text):
12 | """Word segmentation"""
13 | output = []
14 | sentences = split_2_short_text(text, include_symbol=True)
15 | for sentence, idx in sentences:
16 | if is_any_chinese_string(sentence):
17 | output.extend(jieba.lcut(sentence))
18 | else:
19 | output.extend(whitespace_tokenize(sentence))
20 | return output
21 |
22 |
23 | class Tokenizer(object):
24 | """Given Full tokenization."""
25 |
26 | def __init__(self, lower=True):
27 | self.lower = lower
28 |
29 | def tokenize(self, text):
30 | """Tokenizes a piece of text."""
31 | res = []
32 | if len(text) == 0:
33 | return res
34 |
35 | if self.lower:
36 | text = text.lower()
37 | # for the multilingual and Chinese
38 | res = tokenize_words(text)
39 | return res
40 |
41 |
42 | def split_2_short_text(text, include_symbol=True):
43 | """
44 | 长句切分为短句
45 | :param text: str
46 | :param include_symbol: bool
47 | :return: (sentence, idx)
48 | """
49 | re_han = re.compile("([\u4E00-\u9Fa5a-zA-Z0-9+#&]+)", re.U)
50 | result = []
51 | blocks = re_han.split(text)
52 | start_idx = 0
53 | for blk in blocks:
54 | if not blk:
55 | continue
56 | if include_symbol:
57 | result.append((blk, start_idx))
58 | else:
59 | if re_han.match(blk):
60 | result.append((blk, start_idx))
61 | start_idx += len(blk)
62 | return result
63 |
64 |
65 | def is_chinese(uchar):
66 | """判断一个unicode是否是汉字"""
67 | return '\u4e00' <= uchar <= '\u9fa5'
68 |
69 |
70 | def is_all_chinese_string(string):
71 | """判断是否全为汉字"""
72 | return all(is_chinese(c) for c in string)
73 |
74 |
75 | def is_any_chinese_string(string):
76 | """判断是否有中文汉字"""
77 | return any(is_chinese(c) for c in string)
78 |
79 |
80 | def whitespace_tokenize(text):
81 | """Runs basic whitespace cleaning and splitting on a piece of text."""
82 | text = text.strip()
83 | if not text:
84 | return []
85 | tokens = text.split()
86 | return tokens
87 |
88 |
89 | if __name__ == '__main__':
90 | a = '晚上一个人好孤单,想:找附近的人陪陪我.'
91 | b = "unlabeled example, aug_copy_num is the index of the generated augmented. you don't know."
92 |
93 | t = Tokenizer()
94 | word_list_a = t.tokenize(a + b)
95 | print('VocabTokenizer-word:', word_list_a)
96 |
--------------------------------------------------------------------------------
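
Roughly what split_2_short_text returns for a mixed Chinese/English string (the outputs in the comments are approximate):

    from pytextclassifier.tokenizer import split_2_short_text

    print(split_2_short_text('A股大涨, nice!', include_symbol=True))
    # -> [('A股大涨', 0), (', ', 4), ('nice', 6), ('!', 10)]
    print(split_2_short_text('A股大涨, nice!', include_symbol=False))
    # -> [('A股大涨', 0), ('nice', 6)]  # punctuation blocks dropped
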
/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru
2 | jieba
3 | scikit-learn
4 | pandas
5 | numpy
6 | transformers
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | python_functions=test_
3 |
4 | codestyle_max_line_length = 119
5 |
6 | log_cli = true
7 | log_cli_level = WARNING
8 |
9 | [metadata]
10 | description-file = README.md
11 | license_file = LICENSE
12 |
13 | [pycodestyle]
14 | max-line-length = 119
15 |
16 | [flake8]
17 | max-line-length = 119
18 | ignore = E203, W503, F401
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | from setuptools import setup, find_packages
7 |
8 | __version__ = '1.4.0'
9 |
10 | with open('README.md', 'r', encoding='utf-8') as f:
11 | readme = f.read()
12 |
13 | setup(
14 | name='pytextclassifier',
15 | version=__version__,
16 | description='Text Classifier, Text Classification',
17 | long_description=readme,
18 | long_description_content_type='text/markdown',
19 | author='XuMing',
20 | author_email='xuming624@qq.com',
21 | url='https://github.com/shibing624/pytextclassifier',
22 | license='Apache 2.0',
23 | classifiers=[
24 | 'Intended Audience :: Science/Research',
25 | 'Operating System :: OS Independent',
26 | 'License :: OSI Approved :: Apache Software License',
27 | 'Programming Language :: Python',
28 | 'Programming Language :: Python :: 2.7',
29 | 'Programming Language :: Python :: 3',
30 | 'Topic :: Text Processing :: Linguistic',
31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
32 | ],
33 | keywords='pytextclassifier,textclassifier,classifier,textclassification',
34 | install_requires=[
35 | "loguru",
36 | "jieba",
37 | "scikit-learn",
38 | "pandas",
39 | "numpy",
40 | "transformers",
41 | ],
42 | packages=find_packages(exclude=['tests']),
43 | package_dir={'pytextclassifier': 'pytextclassifier'},
44 | package_data={
45 | 'pytextclassifier': ['*.*', '*.txt', '../examples/thucnews_train_1w.txt'],
46 | },
47 | )
48 |
--------------------------------------------------------------------------------
/tests/test_bert_onnx_bs_qps.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 |
7 | import os
8 | import shutil
9 | import sys
10 | import time
11 | import unittest
12 |
13 | import numpy as np
14 | import torch
15 | from loguru import logger
16 |
17 | sys.path.append('..')
18 | from pytextclassifier import BertClassifier
19 |
20 |
21 | class ModelSpeedTestCase(unittest.TestCase):
22 |
23 | def test_classifier_diff_batch_size(self):
24 | # Helper function to calculate QPS and average per-request latency
25 | def calculate_metrics(times):
26 | completion_times = np.array(times)
27 | total_requests = 300
28 |
29 | # average queries per second (QPS): total requests divided by total elapsed time
30 | average_qps = total_requests / np.sum(completion_times)
31 | latency = np.sum(completion_times) / total_requests
32 |
33 | # return all computed metrics
34 | return {
35 | 'total_requests': total_requests,  # total number of requests
36 | 'latency': latency,  # average completion time per request (seconds)
37 | 'average_qps': average_qps,  # average queries per second
38 | }
39 |
40 | # Train the model once
41 | def train_model(output_dir='models/bert-chinese-v1'):
42 | m = BertClassifier(output_dir=output_dir, num_classes=2,
43 | model_type='bert', model_name='bert-base-chinese', num_epochs=1)
44 | data = [
45 | ('education', '名师指导托福语法技巧:名词的复数形式'),
46 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
47 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
48 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
49 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
50 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
51 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
52 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
53 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
54 | ('sports', '米兰客场8战不败国米10年连胜1'),
55 | ('sports', '米兰客场8战不败国米10年连胜2'),
56 | ('sports', '米兰客场8战不败国米10年连胜3'),
57 | ('sports', '米兰客场8战不败国米10年连胜4'),
58 | ('sports', '米兰客场8战不败国米10年连胜5'),
59 | ]
60 | m.train(data * 10)
61 | m.load_model()
62 | return m
63 |
64 | # Evaluate performance for a given batch size
65 | def evaluate_performance(m, batch_size):
66 | samples = ['名师指导托福语法技巧',
67 | '米兰客场8战不败',
68 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100
69 |
70 | batch_times = []
71 |
72 | for i in range(0, len(samples), batch_size):
73 | batch_samples = samples[i:i + batch_size]
74 | start_time = time.time()
75 | m.predict(batch_samples)
76 | end_time = time.time()
77 |
78 | batch_times.append(end_time - start_time)
79 |
80 | metrics = calculate_metrics(batch_times)
81 | return metrics
82 |
83 | # Convert the model to ONNX format
84 | def convert_model_to_onnx(m, model_dir='models/bert-chinese-v1'):
85 | save_onnx_dir = os.path.join(model_dir, 'onnx')
86 | if os.path.exists(save_onnx_dir):
87 | shutil.rmtree(save_onnx_dir)
88 | m.model.convert_to_onnx(save_onnx_dir)
89 | shutil.copy(m.label_vocab_path, save_onnx_dir)
90 | return save_onnx_dir
91 |
92 | # Main function
93 | batch_sizes = [1, 8, 16, 32, 64, 128] # Modify these values as appropriate
94 | model_dir = 'models/bert-chinese-v1'
95 | # Train the model once
96 | model = train_model(model_dir)
97 | # Convert to ONNX
98 | onnx_model_path = convert_model_to_onnx(model)
99 | del model
100 | torch.cuda.empty_cache()
101 |
102 | # Evaluate Standard BERT model performance
103 | for batch_size in batch_sizes:
104 | model = BertClassifier(output_dir=model_dir, num_classes=2, model_type='bert',
105 | model_name=model_dir,
106 | args={"eval_batch_size": batch_size, "onnx": False})
107 | model.load_model()
108 | metrics = evaluate_performance(model, batch_size)
109 | logger.info(
110 | f'Standard BERT model - Batch size: {batch_size}, total_requests: {metrics["total_requests"]}, '
111 | f'Average QPS: {metrics["average_qps"]:.2f}, Average Latency: {metrics["latency"]:.4f}')
112 | del model
113 | torch.cuda.empty_cache()
114 |
115 | # Load and evaluate ONNX model performance
116 | for batch_size in batch_sizes:
117 | onnx_model = BertClassifier(output_dir=onnx_model_path, num_classes=2, model_type='bert',
118 | model_name=onnx_model_path,
119 | args={"eval_batch_size": batch_size, "onnx": True})
120 | onnx_model.load_model()
121 | metrics = evaluate_performance(onnx_model, batch_size)
122 | logger.info(
123 | f'ONNX model - Batch size: {batch_size}, total_requests: {metrics["total_requests"]}, '
124 | f'Average QPS: {metrics["average_qps"]:.2f}, Average Latency: {metrics["latency"]:.4f}')
125 | del onnx_model
126 | torch.cuda.empty_cache()
127 |
128 | # Clean up
129 | shutil.rmtree('models')
130 |
131 |
132 | if __name__ == '__main__':
133 | unittest.main()
134 |
--------------------------------------------------------------------------------
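
The metrics in calculate_metrics above boil down to simple arithmetic; with illustrative numbers:

    # if 300 requests complete in a total of 60 seconds of measured time:
    total_requests = 300
    total_time = 60.0
    average_qps = total_requests / total_time   # 5.0 requests per second
    latency = total_time / total_requests       # 0.2 seconds per request on average
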
/tests/test_bert_onnx_speed.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os.path
7 | import shutil
8 | import sys
9 | import time
10 | import unittest
11 |
12 | import torch
13 |
14 | sys.path.append('..')
15 | from pytextclassifier import BertClassifier
16 |
17 |
18 | class ModelSpeedTestCase(unittest.TestCase):
19 | def test_classifier(self):
20 | m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
21 | model_type='bert', model_name='bert-base-chinese', num_epochs=1)
22 | data = [
23 | ('education', '名师指导托福语法技巧:名词的复数形式'),
24 | ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
25 | ('education', '公务员考虑越来越吃香,这是怎么回事?'),
26 | ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
27 | ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
28 | ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
29 | ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
30 | ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
31 | ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
32 | ('sports', '米兰客场8战不败国米10年连胜1'),
33 | ('sports', '米兰客场8战不败国米10年连胜2'),
34 | ('sports', '米兰客场8战不败国米10年连胜3'),
35 | ('sports', '米兰客场8战不败国米10年连胜4'),
36 | ('sports', '米兰客场8战不败国米10年连胜5'),
37 | ]
38 | m.train(data * 10)
39 | m.load_model()
40 |
41 | samples = ['名师指导托福语法技巧',
42 | '米兰客场8战不败',
43 | '恒生AH溢指收平 A股对H股折价1.95%'] * 100
44 |
45 | start_time = time.time()
46 | predict_label_bert, predict_proba_bert = m.predict(samples)
47 | print(f'predict_label_bert size: {len(predict_label_bert)}')
48 | self.assertEqual(len(predict_label_bert), 300)
49 | end_time = time.time()
50 | elapsed_time_bert = end_time - start_time
51 | print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')
52 |
53 | # convert to ONNX, then load the ONNX model to predict (about 10x faster)
54 | save_onnx_dir = 'models/bert-chinese-v1/onnx'
55 | m.model.convert_to_onnx(save_onnx_dir)
56 | # copy label_vocab.json to save_onnx_dir
57 | if os.path.exists(m.label_vocab_path):
58 | shutil.copy(m.label_vocab_path, save_onnx_dir)
59 |
60 | # Manually delete the model and clear CUDA cache
61 | del m
62 | torch.cuda.empty_cache()
63 |
64 | m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
65 | args={"onnx": True})
66 | m.load_model()
67 | start_time = time.time()
68 | predict_label_bert, predict_proba_bert = m.predict(samples)
69 | print(f'predict_label_bert size: {len(predict_label_bert)}')
70 | end_time = time.time()
71 | elapsed_time_onnx = end_time - start_time
72 | print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
73 |
74 | self.assertEqual(len(predict_label_bert), 300)
75 | shutil.rmtree('models')
76 |
77 |
78 | if __name__ == '__main__':
79 | unittest.main()
80 |
--------------------------------------------------------------------------------
/tests/test_fasttext.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 | import os
7 | import unittest
8 |
9 | import sys
10 |
11 | sys.path.append('..')
12 | from pytextclassifier import FastTextClassifier
13 | import torch
14 |
15 |
16 | class SaveModelTestCase(unittest.TestCase):
17 | def test_classifier(self):
18 | m = FastTextClassifier(output_dir='models/fasttext')
19 | data = [
20 | ('education', 'Student debt to cost Britain billions within decades'),
21 | ('education', 'Chinese education for TV experiment'),
22 | ('sports', 'Middle East and Asia boost investment in top level sports'),
23 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
24 | ]
25 | m.train(data)
26 | m.load_model()
27 | samples = ['Abbott government spends $8 million on higher education media blitz',
28 | 'Middle East and Asia boost investment in top level sports']
29 | r, p = m.predict(samples)
30 | print(r, p)
31 | print('-' * 20)
32 | torch.save(m.model, 'models/model.pkl')
33 | model = torch.load('models/model.pkl')
34 | m.model = model
35 | r1, p1 = m.predict(samples)
36 | print(r1, p1)
37 | self.assertEqual(r, r1)
38 | self.assertEqual(p, p1)
39 |
40 | os.remove('models/model.pkl')
41 | import shutil
42 | shutil.rmtree('models')
43 |
44 |
45 | if __name__ == '__main__':
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/tests/test_lr_classification.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 |
7 | import unittest
8 |
9 | import sys
10 |
11 | sys.path.append('..')
12 | from pytextclassifier import ClassicClassifier
13 |
14 |
15 | class BaseTestCase(unittest.TestCase):
16 | def test_classifier(self):
17 | m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
18 | data = [
19 | ('education', 'Student debt to cost Britain billions within decades'),
20 | ('education', 'Chinese education for TV experiment'),
21 | ('sports', 'Middle East and Asia boost investment in top level sports'),
22 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
23 | ]
24 | m.train(data)
25 | m.load_model()
26 | r, _ = m.predict(['Abbott government spends $8 million on higher education media blitz',
27 | 'Middle East and Asia boost investment in top level sports'])
28 | print(r)
29 | self.assertEqual(r[0], 'education')
30 | self.assertEqual(r[1], 'sports')
31 | test_data = [
32 | ('education', 'Abbott government spends $8 million on higher education media blitz'),
33 | ('sports', 'Middle East and Asia boost investment in top level sports'),
34 | ]
35 | acc_score = m.evaluate_model(test_data)
36 | print(acc_score) # 1.0
37 | self.assertEqual(acc_score, 1.0)
38 | import shutil
39 | shutil.rmtree('models')
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
44 |
--------------------------------------------------------------------------------
/tests/test_lr_vec.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author:XuMing(xuming624@qq.com)
4 | @description:
5 | """
6 |
7 | import unittest
8 |
9 | import sys
10 |
11 | sys.path.append('..')
12 | from pytextclassifier import ClassicClassifier
13 |
14 |
15 | class VecTestCase(unittest.TestCase):
16 | def setUp(self):
17 | m = ClassicClassifier('models/lr')
18 | data = [
19 | ('education', 'Student debt to cost Britain billions within decades'),
20 | ('education', 'Chinese education for TV experiment'),
21 | ('sports', 'Middle East and Asia boost investment in top level sports'),
22 | ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
23 | ]
24 | m.train(data)
25 | print('model trained:', m)
26 |
27 | def test_classifier(self):
28 | new_m = ClassicClassifier('models/lr')
29 | new_m.load_model()
30 | r, _ = new_m.predict(['Abbott government spends $8 million on higher education media blitz',
31 | 'Middle East and Asia boost investment in top level sports'])
32 | print(r)
33 | self.assertTrue(r[0] == 'education')
34 |
35 | def test_vec(self):
36 | new_m = ClassicClassifier('models/lr')
37 | new_m.load_model()
38 | print(new_m.feature.get_feature_names())
39 | print('feature name size:', len(new_m.feature.get_feature_names()))
40 | self.assertTrue(len(new_m.feature.get_feature_names()) > 0)
41 |
42 | def test_stopwords(self):
43 | new_m = ClassicClassifier('models/lr')
44 | new_m.load_model()
45 | stopwords = new_m.stopwords
46 | print(len(stopwords))
47 | self.assertTrue(len(stopwords) > 0)
48 |
49 | @classmethod
50 | def tearDownClass(cls):
51 | import shutil
52 |
53 | shutil.rmtree('models')
54 | print('remove dir: models')
55 |
56 |
57 | if __name__ == '__main__':
58 | unittest.main()
59 |
--------------------------------------------------------------------------------