├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── README_CH.md ├── data ├── map.zip ├── ner.pb ├── pos.pb └── seg.pb ├── fool ├── __init__.py ├── __main__.py ├── dictionary.py ├── lexical.py ├── model.py ├── predictor.py └── trie.py ├── requirements.txt ├── setup.py ├── test ├── __init__.py ├── dictonary.py ├── loadmodel.py └── test_dict.txt └── train ├── README.md ├── __init__.py ├── bert_predict.py ├── bi_lstm.py ├── create_map_file.py ├── data_utils.py ├── datasets └── demo │ ├── dev.txt │ ├── test.txt │ └── train.txt ├── decode.py ├── export_model.py ├── load_model.py ├── main.sh ├── norm_train_recoard.py ├── prepare_vec.py ├── text_to_tfrecords.py ├── tf_metrics.py ├── third_party └── word2vec │ ├── .gitignore │ ├── LICENSE │ ├── README.txt │ ├── compute-accuracy.c │ ├── demo-analogy.sh │ ├── demo-classes.sh │ ├── demo-phrase-accuracy.sh │ ├── demo-phrases.sh │ ├── demo-train-big-model-v1.sh │ ├── demo-word-accuracy.sh │ ├── demo-word.sh │ ├── distance.c │ ├── makefile │ ├── word-analogy.c │ ├── word2phrase.c │ └── word2vec.c ├── train_bert.sh ├── train_bert_ner.py └── word2vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | .idea/ 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | train/results/ 106 | train/datasets/demo/vec 107 | train/datasets/demo/*.pkl 108 | train/datasets/demo/.json 109 | *.tfrecord 110 | *.json -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "bert"] 2 | path = bert 3 | url = git@github.com:google-research/bert.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FoolNLTK 2 | A Chinese text processing toolkit 3 | 4 | [Chinese document](./README_CH.md) 5 | ## Features 6 | * Although not the fastest, FoolNLTK is probably the most accurate open-source Chinese word segmenter available 7 | * Trained on a [BiLSTM model](http://www.aclweb.org/anthology/N16-1030) 8 | * High accuracy in word segmentation, part-of-speech tagging, and entity recognition 9 | * User-defined dictionaries 10 | * Ability to train your own models 11 | * Allows for batch processing 12 | 13 | 14 | ## Getting Started 15 | 16 | *** 2020/2/16 *** update: added BERT-based training and model export for deployment; see the [Chinese training documentation](./train/README.md) 17 | 18 | 19 | 20 | To download the source and train your own models, run: 21 | 22 | ```bash 23 | git clone https://github.com/rockyzhengwu/FoolNLTK.git 24 | cd FoolNLTK/train 25 | 26 | ``` 27 | For detailed instructions, see the [training documentation](./train/README.md). 28 | 29 | * Only tested on Linux with Python 3. 30 | 31 | 32 | ### Installation 33 | ```bash 34 | pip install foolnltk 35 | ``` 36 | 37 | 38 | ## Usage Instructions 39 | 40 | ##### Word segmentation 41 | 42 | 43 | 44 | ``` 45 | import fool 46 | 47 | text = "一个傻子在北京" 48 | print(fool.cut(text)) 49 | # ['一个', '傻子', '在', '北京'] 50 | ``` 51 | 52 | For command-line segmentation, the ```-b``` parameter sets how many lines are segmented per batch, which speeds up processing. 53 | 54 | ```bash 55 | python -m fool [filename] 56 | ``` 57 | 58 | ###### User-defined dictionary 59 | The dictionary format is as follows: the higher a word's weight and the longer the word, 60 | the more likely it is to appear in the segmentation result.
Word weight value should be greater than 1。 61 | 62 | ``` 63 | 难受香菇 10 64 | 什么鬼 10 65 | 分词工具 10 66 | 北京 10 67 | 北京天安门 10 68 | ``` 69 | To load the dictionary: 70 | 71 | ```python 72 | import fool 73 | fool.load_userdict(path) 74 | text = ["我在北京天安门看你难受香菇", "我在北京晒太阳你在非洲看雪"] 75 | print(fool.cut(text)) 76 | #[['我', '在', '北京', '天安门', '看', '你', '难受', '香菇'], 77 | # ['我', '在', '北京', '晒太阳', '你', '在', '非洲', '看', '雪']] 78 | ``` 79 | 80 | To delete the dictionary 81 | ```python 82 | fool.delete_userdict(); 83 | ``` 84 | 85 | 86 | 87 | ##### POS tagging 88 | 89 | ``` 90 | import fool 91 | 92 | text = ["一个傻子在北京"] 93 | print(fool.pos_cut(text)) 94 | #[[('一个', 'm'), ('傻子', 'n'), ('在', 'p'), ('北京', 'ns')]] 95 | ``` 96 | 97 | 98 | ##### Entity Recognition 99 | ``` 100 | import fool 101 | 102 | text = ["一个傻子在北京","你好啊"] 103 | words, ners = fool.analysis(text) 104 | print(ners) 105 | #[[(5, 8, 'location', '北京')]] 106 | ``` 107 | 108 | ### Versions in Other languages 109 | * [Java](https://github.com/rockyzhengwu/JFoolNLTK) 110 | 111 | #### Note 112 | * For any missing model files, try looking in ```sys.prefix```, under ```/usr/local/``` 113 | -------------------------------------------------------------------------------- /README_CH.md: -------------------------------------------------------------------------------- 1 | # FoolNLTK 2 | 中文处理工具包 3 | 4 | ## 特点 5 | * 可能不是最快的开源中文分词,但很可能是最准的开源中文分词 6 | * 基于[BiLSTM模型](http://www.aclweb.org/anthology/N16-1030 )训练而成 7 | * 包含分词,词性标注,实体识别, 都有比较高的准确率 8 | * 用户自定义词典 9 | * 可训练自己的模型 10 | * 批量处理 11 | 12 | 13 | ## 定制自己的模型 14 | 15 | *** 2020/2/16 *** 更新: 添加 bert 训练的代码, 详情看[文档](./train/README.md) 16 | 17 | 18 | 19 | ```bash 20 | get clone https://github.com/rockyzhengwu/FoolNLTK.git 21 | cd FoolNLTK/train 22 | 23 | ``` 24 | 详细训练步骤可参考[文档](./train/README.md) 25 | 26 | 仅在linux Python3 环境测试通过 27 | 28 | 29 | ## Install 30 | ```bash 31 | pip install foolnltk 32 | ``` 33 | 34 | 35 | ## 使用说明 36 | 37 | ##### 分词 38 | 39 | 40 | 41 | ``` 42 | import fool 43 | 44 | text = "一个傻子在北京" 45 | print(fool.cut(text)) 46 | # ['一个', '傻子', '在', '北京'] 47 | ``` 48 | 命令行分词, 可指定```-b```参数,每次切割的行数能加快分词速度 49 | 50 | ```bash 51 | python -m fool [filename] 52 | ``` 53 | 54 | ###### 用户自定义词典 55 | 词典格式格式如下,词的权重越高,词的长度越长就越越可能出现, 权重值请大于1 56 | ``` 57 | 难受香菇 10 58 | 什么鬼 10 59 | 分词工具 10 60 | 北京 10 61 | 北京天安门 10 62 | ``` 63 | 加载词典 64 | 65 | ```python 66 | import fool 67 | fool.load_userdict(path) 68 | text = ["我在北京天安门看你难受香菇", "我在北京晒太阳你在非洲看雪"] 69 | print(fool.cut(text)) 70 | #[['我', '在', '北京', '天安门', '看', '你', '难受', '香菇'], 71 | # ['我', '在', '北京', '晒太阳', '你', '在', '非洲', '看', '雪']] 72 | ``` 73 | 74 | 删除词典 75 | ```python 76 | fool.delete_userdict(); 77 | ``` 78 | 79 | 80 | 81 | ##### 词性标注 82 | 83 | ``` 84 | import fool 85 | 86 | text = ["一个傻子在北京"] 87 | print(fool.pos_cut(text)) 88 | #[[('一个', 'm'), ('傻子', 'n'), ('在', 'p'), ('北京', 'ns')]] 89 | ``` 90 | 91 | 92 | ##### 实体识别 93 | ``` 94 | import fool 95 | 96 | text = ["一个傻子在北京","你好啊"] 97 | words, ners = fool.analysis(text) 98 | print(ners) 99 | #[[(5, 8, 'location', '北京')]] 100 | ``` 101 | 102 | ### 其他语言版本 103 | [Java版](https://github.com/rockyzhengwu/JFoolNLTK) 104 | 105 | #### 注意 106 | * 有找不到模型文件的, 可以看下```sys.prefix```,一般默认为```/usr/local/``` 107 | 108 | -------------------------------------------------------------------------------- /data/map.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockyzhengwu/FoolNLTK/5e2cb40c98e98b7397684193ca232e33289a2354/data/map.zip 
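The packaged model files under data/ (map.zip, seg.pb, pos.pb, ner.pb) are installed by setup.py into a `fool` directory under `sys.prefix`, which is where `fool/lexical.py` loads them from; this is what the README note about checking ```sys.prefix``` refers to. A minimal sketch (not part of the package) for verifying that a pip install put the files where the loader expects them:

```python
import os
import sys

# fool/lexical.py resolves its data path as os.path.join(sys.prefix, "fool")
data_path = os.path.join(sys.prefix, "fool")
for name in ("map.zip", "seg.pb", "pos.pb", "ner.pb"):
    path = os.path.join(data_path, name)
    print(path, "found" if os.path.exists(path) else "MISSING")
```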
-------------------------------------------------------------------------------- /data/ner.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockyzhengwu/FoolNLTK/5e2cb40c98e98b7397684193ca232e33289a2354/data/ner.pb -------------------------------------------------------------------------------- /data/pos.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockyzhengwu/FoolNLTK/5e2cb40c98e98b7397684193ca232e33289a2354/data/pos.pb -------------------------------------------------------------------------------- /data/seg.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rockyzhengwu/FoolNLTK/5e2cb40c98e98b7397684193ca232e33289a2354/data/seg.pb -------------------------------------------------------------------------------- /fool/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | 5 | 6 | import sys 7 | import logging 8 | from collections import defaultdict 9 | 10 | from fool import lexical 11 | from fool import dictionary 12 | from fool import model 13 | 14 | LEXICAL_ANALYSER = lexical.LexicalAnalyzer() 15 | _DICTIONARY = dictionary.Dictionary() 16 | 17 | __log_console = logging.StreamHandler(sys.stderr) 18 | DEFAULT_LOGGER = logging.getLogger(__name__) 19 | DEFAULT_LOGGER.setLevel(logging.DEBUG) 20 | DEFAULT_LOGGER.addHandler(__log_console) 21 | 22 | __all__= ["load_model", "cut", "pos_cut", "ner", "analysis", "load_userdict", "delete_userdict"] 23 | 24 | def load_model(map_file, model_file): 25 | m = model.Model(map_file=map_file, model_file=model_file) 26 | return m 27 | 28 | def _check_input(text, ignore=False): 29 | if not text: 30 | return [] 31 | 32 | if not isinstance(text, list): 33 | text = [text] 34 | 35 | null_index = [i for i, t in enumerate(text) if not t] 36 | if null_index and not ignore: 37 | raise Exception("null text in input ") 38 | 39 | return text 40 | 41 | def ner(text, ignore=False): 42 | text = _check_input(text, ignore) 43 | if not text: 44 | return [[]] 45 | res = LEXICAL_ANALYSER.ner(text) 46 | return res 47 | 48 | 49 | def analysis(text, ignore=False): 50 | text = _check_input(text, ignore) 51 | if not text: 52 | return [[]], [[]] 53 | res = LEXICAL_ANALYSER.analysis(text) 54 | return res 55 | 56 | 57 | def cut(text, ignore=False): 58 | 59 | text = _check_input(text, ignore) 60 | 61 | if not text: 62 | return [[]] 63 | 64 | text = [t for t in text if t] 65 | all_words = LEXICAL_ANALYSER.cut(text) 66 | new_words = [] 67 | if _DICTIONARY.sizes != 0: 68 | for sent, words in zip(text, all_words): 69 | words = _mearge_user_words(sent, words) 70 | new_words.append(words) 71 | else: 72 | new_words = all_words 73 | return new_words 74 | 75 | 76 | def pos_cut(text): 77 | words = cut(text) 78 | pos_labels = LEXICAL_ANALYSER.pos(words) 79 | word_inf = [list(zip(ws, ps)) for ws, ps in zip(words, pos_labels)] 80 | return word_inf 81 | 82 | 83 | def load_userdict(path): 84 | _DICTIONARY.add_dict(path) 85 | 86 | 87 | def delete_userdict(): 88 | _DICTIONARY.delete_dict() 89 | 90 | 91 | def _mearge_user_words(text, seg_results): 92 | if not _DICTIONARY: 93 | return seg_results 94 | 95 | matchs = _DICTIONARY.parse_words(text) 96 | graph = defaultdict(dict) 97 | text_len = len(text) 98 | 99 | for i in range(text_len): 100 | graph[i][i + 1] = 1.0 101 | 102 | index = 0 103 | 104 | 
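    # The merge below scores paths through a word graph over character positions:
    # every single character already has a default arc of 1.0 (added above), each
    # word from the model segmentation adds an arc scored get_weight(w) + len(w),
    # and each user-dictionary match adds an arc scored get_weight(word) * len(word).
    # The right-to-left loop that fills `route` then picks the maximum-score path.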
for w in seg_results: 105 | w_len = len(w) 106 | graph[index][index + w_len] = _DICTIONARY.get_weight(w) + w_len 107 | index += w_len 108 | 109 | for m in matchs: 110 | graph[m.start][m.end] = _DICTIONARY.get_weight(m.keyword) * len(m.keyword) 111 | 112 | route = {} 113 | route[text_len] = (0, 0) 114 | 115 | for idx in range(text_len - 1, -1, -1): 116 | m = [((graph.get(idx).get(k) + route[k][0]), k) for k in graph.get(idx).keys()] 117 | mm = max(m) 118 | route[idx] = (mm[0], mm[1]) 119 | 120 | index = 0 121 | path = [index] 122 | words = [] 123 | 124 | while index < text_len: 125 | ind_y = route[index][1] 126 | path.append(ind_y) 127 | words.append(text[index:ind_y]) 128 | index = ind_y 129 | 130 | return words 131 | -------------------------------------------------------------------------------- /fool/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | 5 | """Fool command line interface.""" 6 | import sys 7 | import fool 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser(usage="%s -m fool [options] filename" % sys.executable, 11 | description="Fool command line interface.", 12 | epilog="If no filename specified, use STDIN instead.") 13 | 14 | parser.add_argument("-d", "--delimiter", metavar="DELIM", default=' / ', 15 | nargs='?', const=' ', 16 | help="use DELIM instead of ' / ' for word delimiter; or a space if it is used without DELIM") 17 | 18 | parser.add_argument("-p", "--pos", metavar="DELIM", nargs='?', const='_', 19 | help="enable POS tagging; if DELIM is specified, use DELIM instead of '_' for POS delimiter") 20 | 21 | 22 | parser.add_argument("-u", "--user_dict", 23 | help="use USER_DICT together with the default dictionary or DICT (if specified)") 24 | 25 | parser.add_argument("-b", "--batch_size", default=1, type = int ,help="batch size ") 26 | 27 | parser.add_argument("filename", nargs='?', help="input file") 28 | 29 | args = parser.parse_args() 30 | 31 | delim = args.delimiter 32 | plim = args.pos 33 | 34 | batch_zize = args.batch_size 35 | 36 | if args.user_dict: 37 | fool.load_userdict(args.user_dict) 38 | 39 | fp = open(args.filename, 'r') if args.filename else sys.stdin 40 | lines = fp.readlines(batch_zize) 41 | 42 | 43 | while lines: 44 | lines = [ln.strip("\r\n") for ln in lines] 45 | if args.pos: 46 | result_list = fool.pos_cut(lines) 47 | for res in result_list: 48 | out_str = [plim.join(p) for p in res] 49 | print(delim.join(out_str)) 50 | else: 51 | result_list = fool.cut(lines) 52 | for res in result_list: 53 | print(delim.join(res)) 54 | lines = fp.readlines(batch_zize) 55 | 56 | fp.close() 57 | -------------------------------------------------------------------------------- /fool/dictionary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | from fool import trie 5 | 6 | 7 | class Dictionary(): 8 | def __init__(self): 9 | self.trie = trie.Trie() 10 | self.weights = {} 11 | self.sizes = 0 12 | 13 | def delete_dict(self): 14 | self.trie = trie.Trie() 15 | self.weights = {} 16 | self.sizes = 0 17 | 18 | def add_dict(self, path): 19 | words = [] 20 | 21 | with open(path) as f: 22 | for i, line in enumerate(f): 23 | line = line.strip("\n").strip() 24 | if not line: 25 | continue 26 | line = line.split() 27 | word = line[0].strip() 28 | self.trie.add_keyword(word) 29 | if len(line) == 1: 30 | weight = 1.0 31 | else: 32 | weight = float(line[1]) 33 | weight 
= float(weight) 34 | self.weights[word] = weight 35 | words.append(word) 36 | self.sizes += len(self.weights) 37 | 38 | def parse_words(self, text): 39 | matchs = self.trie.parse_text(text) 40 | return matchs 41 | 42 | def get_weight(self, word): 43 | return self.weights.get(word, 0.1) 44 | -------------------------------------------------------------------------------- /fool/lexical.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | import sys 5 | import os 6 | import json 7 | 8 | from fool.predictor import Predictor 9 | 10 | from zipfile import ZipFile 11 | 12 | OOV_STR = "" 13 | 14 | 15 | def _load_map_file(path, char_map_name, id_map_name): 16 | with ZipFile(path) as myzip: 17 | with myzip.open('all_map.json') as myfile: 18 | content = myfile.readline() 19 | content = content.decode() 20 | data = json.loads(content) 21 | return data.get(char_map_name), data.get(id_map_name) 22 | 23 | 24 | class LexicalAnalyzer(object): 25 | def __init__(self): 26 | self.initialized = False 27 | self.map = None 28 | self.seg_model = None 29 | self.pos_model = None 30 | self.ner_model = None 31 | self.data_path = os.path.join(sys.prefix, "fool") 32 | self.map_file_path = os.path.join(self.data_path, "map.zip") 33 | 34 | 35 | def _load_model(self, model_namel, word_map_name, tag_name): 36 | seg_model_path = os.path.join(self.data_path, model_namel) 37 | char_to_id, id_to_seg = _load_map_file(self.map_file_path, word_map_name, tag_name) 38 | return Predictor(seg_model_path, char_to_id, id_to_seg) 39 | 40 | def _load_seg_model(self): 41 | self.seg_model = self._load_model("seg.pb", "char_map", "seg_map") 42 | 43 | def _load_pos_model(self): 44 | self.pos_model = self._load_model("pos.pb", "word_map", "pos_map") 45 | 46 | def _load_ner_model(self): 47 | self.ner_model = self._load_model("ner.pb", "char_map", "ner_map") 48 | 49 | def pos(self, seg_words_list): 50 | if not self.pos_model: 51 | self._load_pos_model() 52 | pos_labels = self.pos_model.predict(seg_words_list) 53 | return pos_labels 54 | 55 | def ner(self, text_list): 56 | if not self.ner_model: 57 | self._load_ner_model() 58 | 59 | ner_labels = self.ner_model.predict(text_list) 60 | all_entitys = [] 61 | 62 | for ti, text in enumerate(text_list): 63 | ens = [] 64 | entity = "" 65 | i = 0 66 | ner_label = ner_labels[ti] 67 | chars = list(text) 68 | 69 | for label, word in zip(ner_label, chars): 70 | i += 1 71 | 72 | if label == "O": 73 | continue 74 | 75 | lt = label.split("_")[1] 76 | lb = label.split("_")[0] 77 | 78 | if lb == "S": 79 | ens.append((i-1, i, lt, word)) 80 | elif lb == "B": 81 | entity = "" 82 | entity += word 83 | elif lb == "M": 84 | entity += word 85 | 86 | elif lb == "E": 87 | entity += word 88 | ens.append((i - len(entity), i, lt, entity)) 89 | entity = "" 90 | 91 | if entity: 92 | ens.append((i - len(entity), i, lt, entity)) 93 | all_entitys.append(ens) 94 | 95 | return all_entitys 96 | 97 | def cut(self, text_list): 98 | 99 | if not self.seg_model: 100 | self._load_seg_model() 101 | 102 | all_labels = self.seg_model.predict(text_list) 103 | sent_words = [] 104 | for ti, text in enumerate(text_list): 105 | words = [] 106 | N = len(text) 107 | seg_labels = all_labels[ti] 108 | tmp_word = "" 109 | for i in range(N): 110 | label = seg_labels[i] 111 | w = text[i] 112 | if label == "B": 113 | tmp_word += w 114 | elif label == "M": 115 | tmp_word += w 116 | elif label == "E": 117 | tmp_word += w 118 | words.append(tmp_word) 119 | tmp_word = 
"" 120 | else: 121 | tmp_word = "" 122 | words.append(w) 123 | if tmp_word: 124 | words.append(tmp_word) 125 | sent_words.append(words) 126 | return sent_words 127 | 128 | 129 | def analysis(self, text_list): 130 | words = self.cut(text_list) 131 | pos_labels = self.pos(words) 132 | ners = self.ner(text_list) 133 | word_inf = [list(zip(ws, ps)) for ws, ps in zip(words, pos_labels)] 134 | return word_inf, ners 135 | 136 | -------------------------------------------------------------------------------- /fool/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | 5 | import tensorflow as tf 6 | import pickle 7 | import numpy as np 8 | from tensorflow.contrib.crf import viterbi_decode 9 | 10 | def decode(logits, trans, sequence_lengths, tag_num): 11 | viterbi_sequences = [] 12 | small = -1000.0 13 | start = np.asarray([[small] * tag_num + [0]]) 14 | for logit, length in zip(logits, sequence_lengths): 15 | score = logit[:length] 16 | pad = small * np.ones([length, 1]) 17 | logits = np.concatenate([score, pad], axis=1) 18 | logits = np.concatenate([start, logits], axis=0) 19 | viterbi_seq, viterbi_score = viterbi_decode(logits, trans) 20 | viterbi_sequences .append(viterbi_seq[1:]) 21 | return viterbi_sequences 22 | 23 | 24 | def load_map(path): 25 | with open(path, 'rb') as f: 26 | char_to_id, tag_to_id, id_to_tag = pickle.load(f) 27 | return char_to_id, id_to_tag 28 | 29 | 30 | def load_graph(path): 31 | with tf.gfile.GFile(path, mode='rb') as f: 32 | graph_def = tf.GraphDef() 33 | graph_def.ParseFromString(f.read()) 34 | with tf.Graph().as_default() as graph: 35 | tf.import_graph_def(graph_def, name="prefix") 36 | return graph 37 | 38 | 39 | 40 | class Model(object): 41 | def __init__(self, map_file, model_file): 42 | self.char_to_id, self.id_to_tag = load_map(map_file) 43 | self.graph = load_graph(model_file) 44 | self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0") 45 | self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0") 46 | self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0") 47 | self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0") 48 | self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0") 49 | 50 | self.sess = tf.Session(graph=self.graph) 51 | self.sess.as_default() 52 | self.num_class = len(self.id_to_tag) 53 | 54 | 55 | def predict(self, sents): 56 | inputs = [] 57 | lengths = [len(text) for text in sents] 58 | max_len = max(lengths) 59 | 60 | for sent in sents: 61 | sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("") for w in sent] 62 | padding = [0] * (max_len - len(sent_ids)) 63 | sent_ids += padding 64 | inputs.append(sent_ids) 65 | 66 | inputs = np.array(inputs, dtype=np.int32) 67 | 68 | feed_dict = { 69 | self.input_x: inputs, 70 | self.lengths: lengths, 71 | self.dropout: 1.0 72 | } 73 | 74 | logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict) 75 | path = decode(logits, trans, lengths, self.num_class) 76 | labels = [[self.id_to_tag.get(l) for l in p] for p in path] 77 | return labels 78 | 79 | 80 | -------------------------------------------------------------------------------- /fool/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | from tensorflow.contrib.crf import viterbi_decode 9 | 
10 | 11 | def decode(logits, trans, sequence_lengths, tag_num): 12 | viterbi_sequences = [] 13 | small = -1000.0 14 | start = np.asarray([[small] * tag_num + [0]]) 15 | for logit, length in zip(logits, sequence_lengths): 16 | score = logit[:length] 17 | pad = small * np.ones([length, 1]) 18 | score = np.concatenate([score, pad], axis=1) 19 | score = np.concatenate([start, score], axis=0) 20 | viterbi_seq, viterbi_score = viterbi_decode(score, trans) 21 | viterbi_sequences.append(viterbi_seq[1:]) 22 | return viterbi_sequences 23 | 24 | 25 | 26 | def list_to_array(data_list, dtype=np.int32): 27 | array = np.array(data_list, dtype).reshape(1, len(data_list)) 28 | return array 29 | 30 | 31 | def load_graph(path): 32 | with tf.gfile.GFile(path, "rb") as f: 33 | graph_def = tf.GraphDef() 34 | graph_def.ParseFromString(f.read()) 35 | with tf.Graph().as_default() as graph: 36 | tf.import_graph_def(graph_def, name="prefix") 37 | return graph 38 | 39 | 40 | class Predictor(object): 41 | def __init__(self, model_file, char_to_id, id_to_tag): 42 | 43 | self.char_to_id = char_to_id 44 | self.id_to_tag = {int(k):v for k,v in id_to_tag.items()} 45 | self.graph = load_graph(model_file) 46 | 47 | self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0") 48 | self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0") 49 | self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0") 50 | self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0") 51 | self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0") 52 | 53 | self.sess = tf.Session(graph=self.graph) 54 | self.sess.as_default() 55 | self.num_class = len(self.id_to_tag) 56 | 57 | def predict(self, sents): 58 | inputs = [] 59 | lengths = [len(text) for text in sents] 60 | max_len = max(lengths) 61 | 62 | for sent in sents: 63 | sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("") for w in sent] 64 | padding = [0] * (max_len - len(sent_ids)) 65 | sent_ids += padding 66 | inputs.append(sent_ids) 67 | inputs = np.array(inputs, dtype=np.int32) 68 | 69 | feed_dict = { 70 | self.input_x: inputs, 71 | self.lengths: lengths, 72 | self.dropout: 1.0 73 | } 74 | 75 | logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict) 76 | path = decode(logits, trans, lengths, self.num_class) 77 | labels = [[self.id_to_tag.get(l) for l in p] for p in path] 78 | return labels 79 | -------------------------------------------------------------------------------- /fool/trie.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | import queue 5 | 6 | 7 | class Match(object): 8 | def __init__(self, start, end, keyword): 9 | self.start = start 10 | self.end = end 11 | self.keyword = keyword 12 | 13 | def __str__(self): 14 | return str(self.start) + ":" + str(self.end) + "=" + self.keyword 15 | 16 | 17 | class State(object): 18 | def __init__(self, word, deepth): 19 | self.success = {} 20 | self.failure = None 21 | self.emits = set() 22 | self.deepth = deepth 23 | 24 | def add_word(self, word): 25 | if word in self.success: 26 | return self.success.get(word) 27 | else: 28 | state = State(word, self.deepth + 1) 29 | self.success[word] = state 30 | return state 31 | 32 | def add_one_emit(self, keyword): 33 | self.emits.add(keyword) 34 | 35 | def add_emits(self, emits): 36 | if not isinstance(emits, set): 37 | raise Exception("keywords need a set") 38 | self.emits = self.emits | emits 39 | 40 | 
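    # `failure` is the Aho-Corasick failure link: it points to the state for the
    # longest proper suffix of the current prefix. Trie.create_failure() fills it
    # in breadth-first and also merges the failure target's emits into this state,
    # so keywords that end at the same position are all reported by parse_text().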
def set_failure(self, state): 41 | self.failure = state 42 | 43 | def get_transitions(self): 44 | return self.success.keys() 45 | 46 | def next_state(self, word): 47 | return self.success.get(word) 48 | 49 | 50 | class Trie(object): 51 | def __init__(self, words=[]): 52 | 53 | self.root = State("", 0) 54 | self.root.set_failure(self.root) 55 | self.is_create_failure = False 56 | if words: 57 | self.create_trie(words) 58 | 59 | def create_trie(self, keyword_list): 60 | 61 | for keyword in keyword_list: 62 | self.add_keyword(keyword) 63 | return self 64 | 65 | def add_keyword(self, keyword): 66 | current_state = self.root 67 | word_list = list(keyword) 68 | 69 | for word in word_list: 70 | current_state = current_state.add_word(word) 71 | current_state.add_one_emit(keyword) 72 | 73 | def create_failure(self): 74 | root = self.root 75 | state_queue = queue.Queue() 76 | 77 | for k, v in self.root.success.items(): 78 | state_queue.put(v) 79 | v.set_failure(root) 80 | 81 | while (not state_queue.empty()): 82 | 83 | current_state = state_queue.get() 84 | transitions = current_state.get_transitions() 85 | 86 | for word in transitions: 87 | target_state = current_state.next_state(word) 88 | 89 | state_queue.put(target_state) 90 | trace_state = current_state.failure 91 | 92 | while (trace_state.next_state(word) is None and trace_state.deepth != 0): 93 | trace_state = trace_state.failure 94 | 95 | if trace_state.next_state(word) is not None: 96 | target_state.set_failure(trace_state.next_state(word)) 97 | target_state.add_emits(trace_state.next_state(word).emits) 98 | else: 99 | target_state.set_failure(trace_state) 100 | self.is_create_failure = True 101 | 102 | def get_state(self, current_state, word): 103 | new_current_state = current_state.next_state(word) 104 | 105 | while (new_current_state is None and current_state.deepth != 0): 106 | current_state = current_state.failure 107 | new_current_state = current_state.next_state(word) 108 | 109 | return new_current_state 110 | 111 | def parse_text(self, text, allow_over_laps=True): 112 | matchs = [] 113 | if not self.is_create_failure: 114 | self.create_failure() 115 | 116 | position = 0 117 | current_state = self.root 118 | for word in list(text): 119 | position += 1 120 | current_state = self.get_state(current_state, word) 121 | if not current_state: 122 | current_state = self.root 123 | continue 124 | for mw in current_state.emits: 125 | m = Match(position - len(mw), position, mw) 126 | matchs.append(m) 127 | # todo remove over laps 128 | return matchs 129 | 130 | 131 | def create_trie(words): 132 | return Trie().create_trie(words) 133 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=1.3.0 2 | numpy>=1.12.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | 5 | 6 | from distutils.core import setup 7 | 8 | setup(name = 'foolnltk', 9 | version = '0.1.7', 10 | description = 'Fool Nature Language toolkit ', 11 | author = 'wu.zheng', 12 | author_email = 'rocky.zheng314@gmail.com', 13 | url = 'https://github.com/rockyzhengwu/FoolNLTK', 14 | install_requires=['tensorflow>=1.3.0', 'numpy>=1.12.1'], 15 | packages = ['fool',], 16 | package_dir = {"fool": "fool"}, 17 | data_files = [("fool", ["data/map.zip", "data/pos.pb", 
"data/seg.pb", "data/ner.pb"])] 18 | ) 19 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- -------------------------------------------------------------------------------- /test/dictonary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | import fool 5 | 6 | text = ["我在北京天安门看你难受香菇,一一千四百二十九", "我在北京晒太阳你在非洲看雪", "千年不变的是什么", "我在北京天安门。"] 7 | 8 | print("no dict:", fool.cut(text, ignore=True)) 9 | fool.load_userdict("./test_dict.txt") 10 | print("use dict: ", fool.cut(text)) 11 | fool.delete_userdict() 12 | print("delete dict:", fool.cut(text)) 13 | 14 | pos_words =fool.pos_cut(text) 15 | print("pos result", pos_words) 16 | 17 | words, ners = fool.analysis(text) 18 | print("ners: ", ners) 19 | 20 | ners = fool.ner(text) 21 | print("ners:", ners) -------------------------------------------------------------------------------- /test/loadmodel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | 5 | 6 | import fool 7 | 8 | map_file = "../datasets/demo/maps.pkl" 9 | checkpoint_ifle = "../results/demo_seg/modle.pb" 10 | 11 | smodel = fool.load_model(map_file=map_file, model_file=checkpoint_ifle) 12 | tags = smodel.predict(["北京欢迎你", "你在哪里"]) 13 | print(tags) 14 | -------------------------------------------------------------------------------- /test/test_dict.txt: -------------------------------------------------------------------------------- 1 | 难受香菇 10 2 | 什么鬼 10 3 | 分词工具 10 4 | 北京 10 5 | 北京天安门 10 6 | 二十 10 7 | 四百 10 8 | 一千 10 -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # FoolNLTK-train 2 | FoolNLTK training process 3 | 4 | 增加了利用 bert 训练实体模型的代码,可以导出 pb 文件,并用 python 加载实现线上部署 5 | 模型没使用 crf ,而是直接用交叉熵作为损失,效果并没有太多损失 6 | 7 | 1. 模型训练 8 | data_dir 存放训练数据格式如 datasets/demo 下。下载与训练的模型,我这里是将下载的模型软链接到 pretrainmodel 下 9 | 10 | ```shell script 11 | 12 | python ./train_bert_ner.py --data_dir=data/bid_train_data \ 13 | --bert_config_file=./pretrainmodel/bert_config.json \ 14 | --init_checkpoint=./pretrainmodel/bert_model.ckpt \ 15 | --vocab_file=./pretrainmodel/vocab.txt \ 16 | --output_dir=./output/all_bid_result_dir/ --do_train 17 | 18 | ``` 19 | 20 | 2. 
模型导出 21 | predict 同时指定 do_export 就能导出 pb 格式的模型,用于部署 22 | ```shell script 23 | python ./train_bert_ner.py --data_dir=data/bid_train_data \ 24 | --bert_config_file=./pretrainmodel/bert_config.json \ 25 | --init_checkpoint=./pretrainmodel/bert_model.ckpt \ 26 | --vocab_file=vocab.txt \ 27 | --output_dir=./output/all_bid_result_dir/ --do_predict --do_export 28 | ``` 29 | 30 | 在 bert_predict.py 中指定下面三个参数就能加载训练好的模型完成预测: 31 | ```python 32 | VOCAB_FILE = './pretrainmodel/vocab.txt' 33 | LABEL_FILE = './output/label2id.pkl' 34 | EXPORT_PATH = './export_models/1581318324' 35 | ``` 36 | 37 | 代码参考: 38 | - [bert-chinese-ner](https://github.com/ProHiryu/bert-chinese-ner) 39 | - [BERT-NER](https://github.com/kyzhouhzau/BERT-NER) 40 | 41 | ## 1.train file 42 | 43 | 训练数据的格式和CRF++的训练数据一致,每一列用`\t`分隔,每一个句子用`\n`分隔 44 | 45 | exmpale: 46 | 47 | ``` 48 | 北 B 49 | 京 E 50 | 欢 B 51 | 迎 E 52 | 你 S 53 | ``` 54 | 55 | 56 | ```bash 57 | cd train 58 | vim main.sh 59 | ``` 60 | 61 | 在```main.sh```中指定以下参数 62 | 63 | ```bash 64 | 65 | # train file 66 | TRAIN_FILE=./datasets/demo/train.txt 67 | # dev file 68 | DEV_FILE=./datasets/demo/dev.txt 69 | # test file 70 | TEST_FILE=./datasets/demo/test.txt 71 | 72 | # data out dir 73 | DATA_OUT_DIR=./datasets/demo 74 | 75 | # model save dir 76 | MODEL_OUT_DIR=./results/demo_seg/ 77 | 78 | # label tag column index 79 | TAG_INDEX=1 80 | 81 | # max length of sentlen 82 | MAX_LENGTH=100 83 | 84 | ``` 85 | 86 | ## 2.embeding 87 | 88 | 编译word2vec 89 | 90 | ```bash 91 | cd third_paty && make 92 | 93 | ``` 94 | 95 | 在```main.sh```中指定 word2vec 路径 96 | 97 | ``` 98 | WORD2VEC=./third_party/word2vec/word2vec 99 | ```` 100 | 101 | 默认使用word2vec 训练字向量 102 | 103 | ```bash 104 | ./main.sh vec 105 | ``` 106 | 107 | ## 3.map file 108 | 这一步产生需要的映射文件 109 | 110 | ```bash 111 | ./main.sh map 112 | ``` 113 | 114 | ## 4.tfrecord 115 | 为了处理好内存,先把训练数据转换成tfrecord格式 116 | 117 | ```bash 118 | ./mainsh data 119 | ``` 120 | 121 | ## 5.train 122 | ```bash 123 | ./main.sh train 124 | ``` 125 | 126 | ## export model 127 | 训练好的模型导出成.pb文件,导出路径见 ```main.sh``` 中```MODEL_PATH``` 128 | 下面这个命令会导出最新的模型文件 129 | 130 | ```bash 131 | ./main.sh export 132 | ``` 133 | 134 | ## load model 135 | 训练好模型,现在可以直接调用 136 | ```python 137 | 138 | import fool 139 | 140 | map_file = "./datasets/demo/maps.pkl" 141 | checkpoint_ifle = "./results/demo_seg/modle.pb" 142 | 143 | smodel = fool.load_model(map_file=map_file, model_file=checkpoint_ifle) 144 | tags = smodel.predict(["北京欢迎你", "你在哪里"]) 145 | print(tags) 146 | 147 | ``` 148 | 149 | ## 注 150 | 151 | 如果需要新增新的特征,要修改很多代码,请看懂后随意修改,**没有东西是完全正确的当然也包括我的代码**。 -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- -------------------------------------------------------------------------------- /train/bert_predict.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | # !/usr/bin/env python 4 | # -*- coding:utf-8 -*- 5 | # author: wu.zheng midday.me 6 | 7 | import collections 8 | import tensorflow as tf 9 | from tensorflow.contrib import predictor 10 | from bert import tokenization 11 | import pickle 12 | 13 | 14 | VOCAB_FILE = './pretrainmodel/vocab.txt' 15 | LABEL_FILE = './output/label2id.pkl' 16 | EXPORT_PATH = './export_models/1581318324' 17 | 18 | 19 | def create_int_feature(values): 20 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 
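    # This int64 feature is combined with input_mask and segment_ids into a
    # serialized tf.train.Example (see Predictor.create_example below), which is
    # what the exported SavedModel is fed under the "inputs" key.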
21 | return f 22 | 23 | 24 | def convert_single_example(words, label_map, max_seq_length, tokenizer, mode): 25 | max_seq_length = len(words) + 4 26 | textlist = words 27 | tokens = [] 28 | for i, word in enumerate(textlist): 29 | token = tokenizer.tokenize(word) 30 | tokens.extend(token) 31 | if len(tokens) >= max_seq_length - 1: 32 | tokens = tokens[0:(max_seq_length - 2)] 33 | ntokens = [] 34 | segment_ids = [] 35 | ntokens.append("[CLS]") 36 | segment_ids.append(0) 37 | for i, token in enumerate(tokens): 38 | ntokens.append(token) 39 | segment_ids.append(0) 40 | ntokens.append("[SEP]") 41 | segment_ids.append(0) 42 | print(ntokens) 43 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 44 | input_mask = [1] * len(input_ids) 45 | while len(input_ids) < max_seq_length: 46 | input_ids.append(0) 47 | input_mask.append(0) 48 | segment_ids.append(0) 49 | ntokens.append("**NULL**") 50 | assert len(input_ids) == max_seq_length 51 | assert len(input_mask) == max_seq_length 52 | assert len(segment_ids) == max_seq_length 53 | print("input_ids: ", input_ids) 54 | features = collections.OrderedDict() 55 | features["input_ids"] = create_int_feature(input_ids) 56 | features["input_mask"] = create_int_feature(input_mask) 57 | features["segment_ids"] = create_int_feature(segment_ids) 58 | return features 59 | 60 | 61 | class Predictor(object): 62 | # LABELS= ['O', "B-COMAPNY", "I-COMAPNY", "B-REAL", "I-REAL", "B-AMOUT", "I-AMOUT", "[CLS]", "[SEP]"] 63 | def __init__(self, export_model_path, vocab_file): 64 | self.export_model_path = export_model_path 65 | self.vocab_file = vocab_file 66 | self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) 67 | self.predict_fn = predictor.from_saved_model(self.export_model_path) 68 | self.label_map = pickle.load(open(LABEL_FILE, 'rb')) 69 | self.id_to_label = {v: k for k, v in self.label_map.items()} 70 | 71 | def create_example(self, content): 72 | words = list(content) 73 | words = [tokenization.convert_to_unicode(w) for w in words] 74 | features = convert_single_example(words, self.label_map, 256, self.tokenizer, "predict") 75 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 76 | return tf_example.SerializeToString() 77 | 78 | def predict(self, conent): 79 | example_str = self.create_example(content) 80 | output = self.predict_fn({"inputs": example_str})['output'] 81 | pred_labels = output[0][1:-1] 82 | predict_labels = [l for l in pred_labels if l != 0] 83 | print(self.id_to_label) 84 | print(pred_labels) 85 | pred_labels = [self.id_to_label[i] for i in predict_labels] 86 | return pred_labels 87 | 88 | 89 | if __name__ == "__main__": 90 | predict_model = Predictor(EXPORT_PATH, VOCAB_FILE) 91 | content = "河北远大工程咨询有限公司受石家庄市藁城区环境卫生服务站的委托,对石家庄市藁城区2019年新建公厕及垃圾转运站工程进行,并于2019-12-1009:30:00开标、评标,开评标会结束后根据有关法律、法规要求,现对中标候选人进行公员会评审,确定中标候选人为:第一中标候选人:石家庄市藁城区盛安建筑有限公司投标报价:7409295.86第二中标候选人:河北一方建设工程有限公司投标报价:7251181.03第三中标候选人:石家庄卓晟建筑工程有限公司投标报价:709" 92 | labels = predict_model.predict(content) 93 | print(labels) 94 | -------------------------------------------------------------------------------- /train/bi_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | 5 | import tensorflow as tf 6 | from tensorflow.contrib import rnn 7 | from tensorflow.contrib.crf import crf_log_likelihood 8 | from tensorflow.contrib.layers.python.layers import initializers 9 | 10 | 11 | class BiLSTM(object): 12 | def __init__(self, config, 
embeddings): 13 | 14 | self.config = config 15 | 16 | self.lstm_dim = config["lstm_dim"] 17 | self.num_chars = config["num_chars"] 18 | self.num_tags = config["num_tags"] 19 | self.char_dim = config["char_dim"] 20 | self.lr = config["lr"] 21 | 22 | 23 | self.char_embeding = tf.get_variable(name="char_embeding", initializer=embeddings) 24 | 25 | self.global_step = tf.Variable(0, trainable=False) 26 | self.initializer = initializers.xavier_initializer() 27 | 28 | self.char_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="char_inputs") 29 | self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="targets") 30 | self.dropout = tf.placeholder(dtype=tf.float32, name="dropout") 31 | self.lengths = tf.placeholder(dtype=tf.int32, shape=[None, ], name="lengths") 32 | 33 | 34 | # self.middle_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="middle_dropout_keep_prob") 35 | # self.hidden_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="hidden_dropout_keep_prob") 36 | 37 | self.input_dropout_keep_prob = tf.placeholder_with_default(config["input_dropout_keep"], [], name="input_dropout_keep_prob") 38 | 39 | self.batch_size = tf.shape(self.char_inputs)[0] 40 | self.num_steps = tf.shape(self.char_inputs)[-1] 41 | 42 | # forward 43 | embedding = self.embedding_layer(self.char_inputs) 44 | lstm_inputs = tf.nn.dropout(embedding, self.input_dropout_keep_prob) 45 | 46 | ## bi-directional lstm layer 47 | lstm_outputs = self.bilstm_layer(lstm_inputs) 48 | ## logits for tags 49 | self.project_layer(lstm_outputs) 50 | ## loss of the model 51 | self.loss = self.loss_layer(self.logits, self.lengths) 52 | 53 | 54 | with tf.variable_scope("optimizer"): 55 | optimizer = self.config["optimizer"] 56 | if optimizer == "sgd": 57 | self.opt = tf.train.GradientDescentOptimizer(self.lr) 58 | elif optimizer == "adam": 59 | self.opt = tf.train.AdamOptimizer(self.lr) 60 | elif optimizer == "adgrad": 61 | self.opt = tf.train.AdagradOptimizer(self.lr) 62 | else: 63 | raise KeyError 64 | grads_vars = self.opt.compute_gradients(self.loss) 65 | capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v] for g, v in grads_vars] 66 | self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step) 67 | 68 | 69 | def embedding_layer(self, char_inputs): 70 | with tf.variable_scope("char_embedding"), tf.device('/cpu:0'): 71 | embed = tf.nn.embedding_lookup(self.char_embeding, char_inputs) 72 | return embed 73 | 74 | 75 | def bilstm_layer(self, lstm_inputs, name=None): 76 | with tf.variable_scope("char_bilstm" if not name else name): 77 | lstm_fw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True) 78 | lstm_bw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True) 79 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, dtype=tf.float32, sequence_length=self.lengths) 80 | return tf.concat(outputs, axis=2) 81 | 82 | def project_layer(self, lstm_outputs, name=None): 83 | """ 84 | """ 85 | with tf.variable_scope("project" if not name else name): 86 | with tf.variable_scope("hidden"): 87 | w_tanh = tf.get_variable("w_tanh", shape=[self.lstm_dim * 2, self.lstm_dim], 88 | dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001)) 89 | 90 | b_tanh = tf.get_variable("b_tanh", shape=[self.lstm_dim], dtype=tf.float32, 91 | initializer=tf.zeros_initializer()) 92 | 93 | output = tf.reshape(lstm_outputs, shape=[-1, self.lstm_dim * 2]) 94 | hidden = 
tf.tanh(tf.nn.xw_plus_b(output, w_tanh, b_tanh)) 95 | 96 | drop_hidden = tf.nn.dropout(hidden, self.dropout) 97 | 98 | 99 | # project to score of tags 100 | with tf.variable_scope("output"): 101 | w_out = tf.get_variable("w_out", shape=[self.lstm_dim, self.num_tags], 102 | dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001)) 103 | 104 | b_out = tf.get_variable("b_out", shape=[self.num_tags], dtype=tf.float32, 105 | initializer=tf.zeros_initializer() ) 106 | pred = tf.nn.xw_plus_b(drop_hidden, w_out, b_out, name="pred") 107 | self.logits = tf.reshape(pred, [-1, self.num_steps, self.num_tags], name="logits") 108 | 109 | 110 | def loss_layer(self, project_logits, lengths, name=None): 111 | 112 | with tf.variable_scope("crf_loss" if not name else name): 113 | small = -1000.0 114 | start_logits = tf.concat( 115 | [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])], 116 | axis=-1) 117 | 118 | pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32) 119 | logits = tf.concat([project_logits, pad_logits], axis=-1) 120 | logits = tf.concat([start_logits, logits], axis=1) 121 | targets = tf.concat( 122 | [tf.cast(self.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1) 123 | 124 | self.trans = tf.get_variable( 125 | "transitions", 126 | shape=[self.num_tags + 1, self.num_tags + 1], 127 | initializer=self.initializer) 128 | 129 | log_likelihood, self.trans = crf_log_likelihood( 130 | inputs=logits, 131 | tag_indices=targets, 132 | transition_params=self.trans, 133 | sequence_lengths=lengths + 1) 134 | 135 | return tf.reduce_mean(-log_likelihood) 136 | -------------------------------------------------------------------------------- /train/create_map_file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | import os 5 | import pickle 6 | import argparse 7 | import json 8 | 9 | 10 | OOV_STR = "" 11 | PAD_STR = "" 12 | 13 | 14 | def create_vocab(embeding_file): 15 | 16 | vocab_dict = {} 17 | vocab_dict[PAD_STR] = 0 18 | vocab_dict[OOV_STR] = 1 19 | 20 | f = open(embeding_file, errors="ignore") 21 | m, n = f.readline().split(" ") 22 | n = int(n) 23 | m = int(m) 24 | print("preembeding size : %d"%(m)) 25 | 26 | for i, line in enumerate(f): 27 | word = line.split()[0] 28 | if not word: 29 | continue 30 | if word not in vocab_dict: 31 | vocab_dict[word] = len(vocab_dict) 32 | print("vocab size : %d" % len(vocab_dict)) 33 | return vocab_dict 34 | 35 | 36 | def tag_to_map(train_file, tag_index=-1): 37 | f = open(train_file) 38 | tag_to_id = {} 39 | 40 | for i, line in enumerate(f): 41 | line = line.strip("\n") 42 | if not line: 43 | continue 44 | data = line.split("\t") 45 | # todo hand this in train file 46 | 47 | tag = data[tag_index] 48 | if tag not in tag_to_id: 49 | tag_to_id[tag] = len(tag_to_id) 50 | 51 | id_to_tag = {v: k for k, v in tag_to_id.items()} 52 | 53 | print("tag num in %s: %d "%(train_file, len(tag_to_id))) 54 | return tag_to_id, id_to_tag 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--train_file", required=True, help="train file path") 60 | parser.add_argument("--embeding_file", required=True, help="embeding file path") 61 | parser.add_argument("--map_file", required=True, help="the out dir of map file") 62 | parser.add_argument("--tag_index", required=False, type=int, default=-1, 
help="the column num of taget tag in train_file ") 63 | parser.add_argument("--size_file", required=True, help="save size to") 64 | args = parser.parse_args() 65 | 66 | # if not os.path.exists(args.out_dir): 67 | # os.mkdir(args.out_dir) 68 | 69 | with open(args.map_file, 'wb') as f: 70 | vocab = create_vocab(embeding_file=args.embeding_file) 71 | tag_to_id, id_to_tag = tag_to_map(train_file=args.train_file, tag_index=args.tag_index) 72 | print("tag map result: ") 73 | print(tag_to_id) 74 | pickle.dump((vocab, tag_to_id, id_to_tag), f) 75 | vocab_size = len(vocab) 76 | num_class = len(tag_to_id) 77 | print("vocab size :%d, num of tag : %d"%(vocab_size, num_class)) 78 | size_file = open(args.size_file, 'w') 79 | json.dump({"vocab_size": vocab_size, "num_tag":num_class}, size_file) 80 | -------------------------------------------------------------------------------- /train/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | import math 5 | import random 6 | import pickle 7 | import re 8 | import json 9 | import tensorflow as tf 10 | 11 | 12 | def load_size_file(size_filename): 13 | with open(size_filename, 'r') as f: 14 | print(size_filename) 15 | num_obj = json.load(f) 16 | return num_obj 17 | 18 | 19 | 20 | def load_sample_size(filename): 21 | with open(filename) as f: 22 | train_num, dev_num, test_num = map(int, f.readline().split(",")) 23 | return train_num, dev_num, test_num 24 | 25 | 26 | 27 | def reverse_dict(d): 28 | return {v: k for k, v in d.items()} 29 | 30 | 31 | def load_map_file(path, ): 32 | vocab, tag_to_id, id_to_tag = pickle.load(open(path, 'rb')) 33 | return vocab, tag_to_id, id_to_tag 34 | 35 | 36 | def load_sentences(path, indexs): 37 | f = open(path) 38 | sentences = [] 39 | sent = [] 40 | for i, line in enumerate(f): 41 | line = line.strip("\n") 42 | if not line: 43 | if sent: 44 | sentences.append(sent) 45 | sent = [] 46 | continue 47 | word_info = line.split("\t") 48 | word_info = [word_info[i] for i in indexs] 49 | sent.append(word_info) 50 | return sentences 51 | 52 | 53 | def create_dataset(data_list, word_to_id, tag_to_id): 54 | data_set = [] 55 | oov_count = 0 56 | total_count = 0 57 | space_word_count = 0 58 | for data in data_list: 59 | sent = [] 60 | tag = [] 61 | for w in data: 62 | total_count += 1 63 | if w[0] in word_to_id: 64 | sent.append(word_to_id[w[0]]) 65 | elif w[0] in ["\t", "\n", "\r", " "]: 66 | space_word_count += 1 67 | sent.append(word_to_id[""]) 68 | else: 69 | sent.append(word_to_id[""]) 70 | oov_count += 1 71 | tag.append(tag_to_id[w[-1]]) 72 | data_set.append([sent, tag]) 73 | print("space word count: %d "%(space_word_count)) 74 | print("dataset oov count:%d percent: %f"%(oov_count, 1.0 * oov_count / total_count)) 75 | return data_set 76 | 77 | en_p = re.compile('[a-zA-Z]', re.U) 78 | re_han = re.compile("([\u4E00-\u9FD5]+)") 79 | 80 | def get_char_type(ch): 81 | """ 82 | 0, 汉字 83 | 1, 英文字母 84 | 2. 数字 85 | 3. 
其他 86 | """ 87 | if re.match(en_p, ch): 88 | return 1 89 | elif re.match("\d+", ch): 90 | return 2 91 | elif re.match(re_han, ch): 92 | return 3 93 | else: 94 | return 4 95 | 96 | 97 | def create_ner_dataset(data_list, word_to_id, tag_to_id, pos_to_id, seg_to_id): 98 | data_set = [] 99 | oov_count = 0 100 | total_count = 0 101 | for data in data_list: 102 | sent = [] 103 | tag = [] 104 | seg_list = [] 105 | pos_list = [] 106 | char_type_list = [] 107 | for w in data: 108 | total_count += 1 109 | if w[0] in word_to_id: 110 | sent.append(word_to_id[w[0]]) 111 | else: 112 | sent.append(word_to_id[""]) 113 | oov_count += 1 114 | tag.append(tag_to_id[w[-1]]) 115 | seg_label = w[1] 116 | seg_list.append(seg_to_id[seg_label]) 117 | pos_list.append(pos_to_id[w[2]]) 118 | char_type_list.append(get_char_type(w[0])) 119 | data_set.append([sent, seg_list, pos_list, char_type_list, tag]) 120 | print("dataset oov count:%d percent: %f" % (oov_count, 1.0 * oov_count / total_count)) 121 | return data_set 122 | 123 | 124 | 125 | class BatchManager(object): 126 | 127 | def __init__(self, data, batch_size): 128 | self.batch_data = self.sort_and_pad(data, batch_size) 129 | self.len_data = len(self.batch_data) 130 | 131 | def sort_and_pad(self, data, batch_size): 132 | num_batch = int(math.ceil(len(data) / batch_size)) 133 | sorted_data = sorted(data, key=lambda x: len(x[0])) 134 | batch_data = list() 135 | for i in range(num_batch): 136 | batch_data.append(self.pad_data(sorted_data[i*batch_size : (i+1)*batch_size])) 137 | return batch_data 138 | 139 | @staticmethod 140 | def pad_data(data): 141 | strings = [] 142 | targets = [] 143 | max_length = max([len(sentence[0]) for sentence in data]) 144 | for line in data: 145 | string, target = line 146 | padding = [0] * (max_length - len(string)) 147 | strings.append(string + padding) 148 | targets.append(target + padding) 149 | return [strings, targets] 150 | 151 | 152 | def iter_batch(self, shuffle=False): 153 | if shuffle: 154 | random.shuffle(self.batch_data) 155 | 156 | for idx in range(self.len_data): 157 | yield self.batch_data[idx] 158 | 159 | 160 | 161 | class NERBatchManager(object): 162 | 163 | def __init__(self, data, batch_size): 164 | self.batch_data = self.sort_and_pad(data, batch_size) 165 | self.len_data = len(self.batch_data) 166 | 167 | def sort_and_pad(self, data, batch_size): 168 | num_batch = int(math.ceil(len(data) / batch_size)) 169 | sorted_data = sorted(data, key=lambda x: len(x[0])) 170 | batch_data = list() 171 | for i in range(num_batch): 172 | batch_data.append(self.pad_data(sorted_data[i*batch_size : (i+1)*batch_size])) 173 | return batch_data 174 | 175 | @staticmethod 176 | def pad_data(data): 177 | strings = [] 178 | segs = [] 179 | pos_list = [] 180 | targets = [] 181 | char_types = [] 182 | 183 | max_length = max([len(sentence[0]) for sentence in data]) 184 | for line in data: 185 | string, seg, pos, char_type, target = line 186 | padding = [0] * (max_length - len(string)) 187 | strings.append(string + padding) 188 | segs.append(seg + padding) 189 | pos_list.append(pos + padding) 190 | targets.append(target + padding) 191 | char_types.append(char_type + padding) 192 | return [strings, segs, pos_list, char_types, targets] 193 | 194 | 195 | def iter_batch(self, shuffle=False): 196 | if shuffle: 197 | random.shuffle(self.batch_data) 198 | 199 | for idx in range(self.len_data): 200 | yield self.batch_data[idx] 201 | 202 | class SegBatcher(object): 203 | def __init__(self, record_file_name, batch_size, num_epochs=None): 204 | 
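        # SegBatcher wraps a TF 1.x queue-based input pipeline: a
        # string_input_producer feeds a TFRecordReader, each
        # tf.train.SequenceExample is parsed into (labels, char_list, sent_len),
        # and tf.train.batch(..., dynamic_pad=True, allow_smaller_final_batch=True)
        # pads every batch to the length of its longest sentence. Callers must
        # start queue runners before evaluating next_batch_op (see
        # norm_train_recoard.py).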
self._batch_size = batch_size 205 | self._epoch = 0 206 | self._step = 1. 207 | self.num_epochs = num_epochs 208 | self.next_batch_op = self.input_pipeline(record_file_name, self._batch_size, self.num_epochs) 209 | 210 | 211 | def example_parser(self, filename_queue): 212 | reader = tf.TFRecordReader() 213 | key, record_string = reader.read(filename_queue) 214 | 215 | features = { 216 | 'labels': tf.FixedLenSequenceFeature([], tf.int64), 217 | 'char_list': tf.FixedLenSequenceFeature([], tf.int64), 218 | 'sent_len': tf.FixedLenSequenceFeature([], tf.int64), 219 | } 220 | 221 | _, example = tf.parse_single_sequence_example(serialized=record_string, sequence_features=features) 222 | labels = example['labels'] 223 | char_list = example['char_list'] 224 | sent_len = example['sent_len'] 225 | return labels, char_list, sent_len 226 | 227 | def input_pipeline(self, filenames, batch_size, num_epochs=None): 228 | filename_queue = tf.train.string_input_producer([filenames], num_epochs=num_epochs, shuffle=True) 229 | labels, char_list, sent_len = self.example_parser(filename_queue) 230 | 231 | min_after_dequeue = 10000 232 | capacity = min_after_dequeue + 12 * batch_size 233 | next_batch = tf.train.batch([labels, char_list, sent_len], batch_size=batch_size, capacity=capacity, 234 | dynamic_pad=True, allow_smaller_final_batch=True) 235 | return next_batch 236 | 237 | 238 | -------------------------------------------------------------------------------- /train/decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | from tensorflow.contrib.crf import viterbi_decode 5 | 6 | import numpy as np 7 | 8 | # def decode(logits, trans, lengths, num_tags): 9 | # small = -1000 10 | # 11 | # start = np.asarray([[small] * num_tags + [0 , small]]) 12 | # end = np.asarray([[small] * num_tags + [small, 0]]) 13 | # path = [] 14 | # 15 | # for logit, length in zip(logits,lengths): 16 | # logit = logit[ :length] 17 | # pad = small * np.ones([length, 2]) 18 | # logit = np.concatenate([logit, pad], axis=-1) 19 | # 20 | # logit = np.concatenate([start, logit, end], axis = 0) 21 | # viterbi , viterbi_score = viterbi_decode(logit, trans) 22 | # path.append(np.array(viterbi[1: -1])) 23 | # 24 | # return path 25 | 26 | 27 | def vdecode(logits, trans, sequence_lengths, tag_num): 28 | viterbi_sequences = [] 29 | small = -1000.0 30 | start = np.asarray([[small] * tag_num + [0]]) 31 | for logit, length in zip(logits, sequence_lengths): 32 | score = logit[:length] 33 | pad = small * np.ones([length, 1]) 34 | score = np.concatenate([score, pad], axis=1) 35 | score = np.concatenate([start, score], axis=0) 36 | viterbi_seq, viterbi_score = viterbi_decode(score, trans) 37 | viterbi_sequences.append(viterbi_seq[1:]) 38 | 39 | return viterbi_sequences -------------------------------------------------------------------------------- /train/export_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | 5 | import tensorflow as tf 6 | from tensorflow.contrib.crf import viterbi_decode 7 | import numpy as np 8 | 9 | def decode(logits, trans, sequence_lengths, tag_num): 10 | viterbi_sequences = [] 11 | small = -1000.0 12 | start = np.asarray([[small] * tag_num + [0]]) 13 | for logit, length in zip(logits, sequence_lengths): 14 | score = logit[:length] 15 | pad = small * np.ones([length, 1]) 16 | logits = np.concatenate([score, pad], axis=1) 17 | logits = 
np.concatenate([start, logits], axis=0) 18 | viterbi_seq, viterbi_score = viterbi_decode(logits, trans) 19 | viterbi_sequences += [viterbi_seq] 20 | return viterbi_sequences 21 | 22 | def save_to_binary(checkpoints_path, out_model_path, out_put_names): 23 | checkpoint_dir = checkpoints_path 24 | graph = tf.Graph() 25 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 26 | print(checkpoint_file) 27 | with graph.as_default(): 28 | session_conf = tf.ConfigProto( 29 | allow_soft_placement=True, 30 | log_device_placement=False 31 | ) 32 | sess = tf.Session(config=session_conf) 33 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 34 | saver.restore(sess, checkpoint_file) 35 | 36 | output_graph_def = tf.graph_util.convert_variables_to_constants( 37 | sess, 38 | tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 39 | output_node_names= out_put_names 40 | ) 41 | with tf.gfile.FastGFile(out_model_path, mode='wb') as f: 42 | f.write(output_graph_def.SerializeToString()) 43 | 44 | 45 | import pickle 46 | def load_map(path): 47 | with open(path, 'rb') as f: 48 | char_to_id, tag_to_id, id_to_tag = pickle.load(f) 49 | return char_to_id, id_to_tag 50 | 51 | 52 | def load_graph(path): 53 | with tf.gfile.GFile(path, "rb") as f: 54 | graph_def = tf.GraphDef() 55 | graph_def.ParseFromString(f.read()) 56 | with tf.Graph().as_default() as graph: 57 | tf.import_graph_def(graph_def, name="prefix") 58 | return graph 59 | 60 | 61 | 62 | class Predictor(object): 63 | def __init__(self, map_path, model_path): 64 | self.char_to_id, self.id_to_tag = load_map(map_path) 65 | self.graph = load_graph(model_path) 66 | self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0") 67 | self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0") 68 | self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0") 69 | self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0") 70 | self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0") 71 | 72 | self.sess = tf.Session(graph=self.graph) 73 | self.sess.as_default() 74 | self.num_class = len(self.id_to_tag) 75 | 76 | 77 | def predict(self, text): 78 | inputs = [] 79 | for w in text: 80 | if w in self.char_to_id: 81 | inputs.append(self.char_to_id[w]) 82 | else: 83 | inputs.append(self.char_to_id[""]) 84 | inputs = np.array(inputs, dtype=np.int32).reshape(1, len(inputs)) 85 | lengths=[len(text)] 86 | feed_dict = { 87 | self.input_x: inputs, 88 | self.lengths: lengths, 89 | self.dropout: 1.0 90 | } 91 | logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict) 92 | path = decode(logits, trans, [inputs.shape[1]], self.num_class) 93 | path = path[0][1:] 94 | tags = [self.id_to_tag[p] for p in path] 95 | print(tags) 96 | 97 | 98 | 99 | if __name__ == '__main__': 100 | import os 101 | import argparse 102 | 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--checkpoint_dir", required=True, help="checkpoint dir") 105 | parser.add_argument("--out_dir", required=True, help="out dir ") 106 | args = parser.parse_args() 107 | out_names = ['project/output/pred', 'project/logits', "crf_loss/transitions"] 108 | save_to_binary(args.checkpoint_dir, os.path.join(args.out_dir, "modle.pb"), out_names) 109 | 110 | -------------------------------------------------------------------------------- /train/load_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*-coding:utf-8-*- 3 | 4 | 
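# load_model.py is a smoke test for an exported model: it loads the frozen
# graph produced by export_model.py through fool.model.Model and tags one
# sample sentence. The paths below assume the demo pipeline in main.sh has
# already been run ("vec", "map", "data", "train", then "export").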
from fool.model import Model 5 | 6 | map_file = "./datasets/demo/maps.pkl" 7 | checkpoint_ifle = "./results/demo_seg/modle.pb" 8 | 9 | smodel = Model(map_path=map_file, model_path=checkpoint_ifle) 10 | tags = smodel.predict(list("北京欢迎你")) 11 | print(tags) -------------------------------------------------------------------------------- /train/main.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # train file 5 | TRAIN_FILE=./datasets/demo/train.txt 6 | # dev file 7 | DEV_FILE=./datasets/demo/dev.txt 8 | # test file 9 | TEST_FILE=./datasets/demo/test.txt 10 | 11 | # data out dir 12 | DATA_OUT_DIR=./datasets/demo 13 | 14 | # model save dir 15 | MODEL_OUT_DIR=./results/demo_seg/ 16 | 17 | # label tag column index 18 | TAG_INDEX=1 19 | 20 | # max length of sentlen 21 | MAX_LENGTH=100 22 | 23 | # 模型输出路径 24 | 25 | 26 | # word2vec 27 | WORD2VEC=./third_party/word2vec/word2vec 28 | 29 | ########################################################################### 30 | # 31 | ########################################################################### 32 | 33 | 34 | VEC_OUT_DIR="$DATA_OUT_DIR"/vec/ 35 | 36 | CHAR_PRE_TRAIN_FILE="$VEC_OUT_DIR"/char_vec_train.txt 37 | EMBEDING_FILE="$VEC_OUT_DIR"/char_vec.txt 38 | 39 | 40 | # ckpt 41 | CHECKPOINT_DIR="$MODEL_OUT_DIR"/ckpt/ 42 | 43 | # map file 44 | MAP_FILE="$DATA_OUT_DIR"/maps.pkl 45 | # train, file 46 | TRAIN_FILE_TF="$DATA_OUT_DIR"/train.tfrecord 47 | # dev file 48 | DEV_FILE_TF="$DATA_OUT_DIR"/dev.tfrecord 49 | # test tf record 50 | TEST_FILE_TF="$DATA_OUT_DIR"/test.tfrecord 51 | # 保存部分大小参数 52 | SIZE_FILE="$DATA_OUT_DIR"/size.json 53 | 54 | if [ ! -d $DATA_OUT_DIR ]; 55 | then mkdir -p $DATA_OUT_DIR; 56 | fi; 57 | 58 | 59 | if [ ! -d $VEC_OUT_DIR ]; 60 | then mkdir -p $VEC_OUT_DIR 61 | fi; 62 | 63 | if [ ! -d $MODEL_OUT_DIR ]; 64 | then mkdir -p $MODEL_OUT_DIR 65 | fi; 66 | 67 | if [ ! 
-d $CHECKPOINT_DIR ]; 68 | then mkdir -p $CHECKPOINT_DIR 69 | fi; 70 | 71 | 72 | 73 | if [ "$1" = "vec" ]; then 74 | echo "train vec " 75 | python prepare_vec.py --train_file "$TRAIN_FILE" --dev_file "$DEV_FILE" --test_file "$TEST_FILE" --out_file "$CHAR_PRE_TRAIN_FILE" 76 | time $WORD2VEC -train "$CHAR_PRE_TRAIN_FILE" -output "$EMBEDING_FILE" -cbow 1 -size 100 -window 8 -negative 25 -hs 0 \ 77 | -sample 1e-4 -threads 4 -binary 0 -iter 15 -min-count 5 78 | 79 | elif [ "$1" = "map" ]; then 80 | 81 | echo "create map file" 82 | python create_map_file.py --train_file "$TRAIN_FILE" --embeding_file "$EMBEDING_FILE" --map_file "$MAP_FILE" \ 83 | --size_file "$SIZE_FILE" --tag_index "$TAG_INDEX" 84 | 85 | elif [ "$1" = "data" ]; then 86 | echo "data to tfrecord"; 87 | python text_to_tfrecords.py --train_file "$TRAIN_FILE" --dev_file "$DEV_FILE" --test_file "$TEST_FILE" --map_file \ 88 | "$MAP_FILE" --size_file "$SIZE_FILE" --out_dir "$DATA_OUT_DIR" --tag_index "$TAG_INDEX" --max_length "$MAX_LENGTH" 89 | 90 | elif [ "$1" = "train" ]; then 91 | python norm_train_recoard.py --train_file "$TRAIN_FILE_TF" --dev_file "$DEV_FILE_TF" --test_file "$TEST_FILE_TF"\ 92 | --out_dir "$CHECKPOINT_DIR" --map_file "$MAP_FILE" --pre_embedding_file "$EMBEDING_FILE" --size_file "$SIZE_FILE" 93 | 94 | elif [ "$1" = "export" ]; then 95 | echo "echo model" 96 | python export_model.py --checkpoint_dir "$CHECKPOINT_DIR" --out_dir "$MODEL_OUT_DIR" 97 | 98 | else 99 | echo "param error" 100 | fi -------------------------------------------------------------------------------- /train/norm_train_recoard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf-8-*- 3 | 4 | 5 | from collections import OrderedDict 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from sklearn.metrics import accuracy_score 10 | from sklearn.metrics import classification_report 11 | 12 | import data_utils 13 | from bi_lstm import BiLSTM 14 | import word2vec 15 | from data_utils import SegBatcher 16 | 17 | 18 | def init_mode_config(vocab_size, tag_size): 19 | model_config = OrderedDict() 20 | 21 | model_config["char_dim"] = FLAGS.char_dim 22 | model_config["lstm_dim"] = FLAGS.lstm_dim 23 | model_config["optimizer"] = FLAGS.optimizer 24 | model_config['clip'] = FLAGS.clip 25 | model_config["lr"] = FLAGS.lr 26 | model_config['dropout'] = FLAGS.dropout 27 | model_config["input_dropout_keep"] = FLAGS.input_dropout_keep 28 | model_config["num_chars"] = vocab_size 29 | model_config["num_tags"] = tag_size 30 | 31 | return model_config 32 | 33 | 34 | def main(argv): 35 | # todo create map file 36 | word_to_id, tag_to_id, id_to_tag = data_utils.load_map_file(FLAGS.map_file) 37 | id_to_word = {v: k for k, v in word_to_id.items()} 38 | 39 | num_dict = data_utils.load_size_file(FLAGS.size_file) 40 | train_num = num_dict["train_num"] 41 | dev_num = num_dict["dev_num"] 42 | test_num = num_dict['test_num'] 43 | 44 | model_config = init_mode_config(len(word_to_id), len(tag_to_id)) 45 | print(model_config) 46 | 47 | tf_config = tf.ConfigProto() 48 | tf_config.gpu_options.allow_growth = True 49 | 50 | with tf.Graph().as_default(): 51 | 52 | print("load pre word2vec ...") 53 | wv = word2vec.Word2vec() 54 | embed = wv.load_w2v_array(FLAGS.pre_embedding_file, id_to_word) 55 | 56 | word_embedding = tf.constant(embed, dtype=tf.float32) 57 | model = BiLSTM(model_config, word_embedding) 58 | train_batcher = SegBatcher(FLAGS.train_file, FLAGS.batch_size, num_epochs=FLAGS.max_epoch) 59 | 
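        # The dev/test batchers run for a single epoch; their batches are
        # drained into in-memory lists (dev_batches / test_batches) further
        # down, so evaluation can be repeated without re-reading the
        # TFRecord files.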
dev_batcher = SegBatcher(FLAGS.dev_file, FLAGS.batch_size, num_epochs=1) 60 | test_batcher = SegBatcher(FLAGS.test_file, FLAGS.batch_size, num_epochs=1) 61 | 62 | tf.global_variables_initializer() 63 | sv = tf.train.Supervisor(logdir=FLAGS.out_dir, save_model_secs=FLAGS.save_model_secs, ) 64 | 65 | with sv.managed_session() as sess: 66 | sess.as_default() 67 | threads = tf.train.start_queue_runners(sess=sess) 68 | loss = [] 69 | 70 | 71 | def run_evaluation(dev_batches, report=False): 72 | """ 73 | Evaluates model on a dev set 74 | """ 75 | preds = [] 76 | true_tags = [] 77 | tmp_x = [] 78 | for x_batch, y_batch, sent_len in dev_batches: 79 | feed_dict = { 80 | model.char_inputs: x_batch, 81 | model.targets: y_batch, 82 | model.lengths: sent_len.reshape(-1, ), 83 | model.dropout: 1.0 84 | } 85 | 86 | step, loss, logits, lengths, trans = sess.run( 87 | [model.global_step, model.loss, model.logits, model.lengths, model.trans], feed_dict) 88 | 89 | index = 0 90 | small = -1000.0 91 | start = np.asarray([[small] * model_config["num_tags"] + [0]]) 92 | 93 | for score, length in zip(logits, lengths): 94 | score = score[:length] 95 | pad = small * np.ones([length, 1]) 96 | logit = np.concatenate([score, pad], axis=1) 97 | logit = np.concatenate([start, logit], axis=0) 98 | path, _ = tf.contrib.crf.viterbi_decode(logit, trans) 99 | preds.append(path[1:]) 100 | tmp_x.append(x_batch[index][:length]) 101 | index += 1 102 | 103 | for y, length in zip(y_batch, lengths): 104 | y = y.tolist() 105 | true_tags.append(y[: length]) 106 | 107 | if FLAGS.debug and len(tmp_x) > 5: 108 | print(tag_to_id) 109 | 110 | for j in range(5): 111 | sent = [id_to_word.get(i, "") for i in tmp_x[j]] 112 | print("".join(sent)) 113 | print("pred:", preds[j]) 114 | print("true:", true_tags[j]) 115 | 116 | preds = np.concatenate(preds, axis=0) 117 | true_tags = np.concatenate(true_tags, axis=0) 118 | 119 | if report: 120 | print(classification_report(true_tags, preds)) 121 | 122 | acc = accuracy_score(true_tags, preds) 123 | return acc 124 | 125 | def run_test(): 126 | print("start run test ......") 127 | test_batches = [] 128 | done = False 129 | print("load all test batches to memory") 130 | 131 | while not done: 132 | try: 133 | tags, chars, sent_lens = sess.run(test_batcher.next_batch_op) 134 | test_batches.append((chars, tags, sent_lens)) 135 | except: 136 | done = True 137 | test_acc = run_evaluation(test_batches, True) 138 | print("test accc %f" % (test_acc)) 139 | 140 | best_acc = 0.0 141 | dev_batches = [] 142 | done = False 143 | print("load all dev batches to memory") 144 | 145 | while not done: 146 | try: 147 | tags, chars, sent_lens = sess.run(dev_batcher.next_batch_op) 148 | dev_batches.append((chars, tags, sent_lens)) 149 | except: 150 | done = True 151 | 152 | print("start training ...") 153 | early_stop = False 154 | for step in range(FLAGS.max_epoch): 155 | if sv.should_stop(): 156 | run_test() 157 | break 158 | examples = 0 159 | 160 | while examples < train_num: 161 | if early_stop: 162 | break 163 | try: 164 | batch = sess.run(train_batcher.next_batch_op) 165 | except Exception as e: 166 | break 167 | 168 | tags, chars, sent_lens = batch 169 | feed_dict = { 170 | model.char_inputs: chars, 171 | model.targets: tags, 172 | model.dropout: FLAGS.dropout, 173 | model.lengths: sent_lens.reshape(-1, ), 174 | } 175 | global_step, batch_loss, _ = sess.run([model.global_step, model.loss, model.train_op], feed_dict) 176 | 177 | print("%d iteration %d loss: %f" % (step, global_step, batch_loss)) 178 | if global_step % 
FLAGS.eval_step == 0: 179 | print("evaluation .......") 180 | acc = run_evaluation(dev_batches) 181 | 182 | print("%d iteration , %d dev acc: %f " % (step, global_step, acc)) 183 | 184 | if best_acc - acc > 0.01: 185 | print("stop training ealy ... best dev acc " % (best_acc)) 186 | early_stop = True 187 | break 188 | 189 | elif best_acc < acc: 190 | best_acc = acc 191 | sv.saver.save(sess, FLAGS.out_dir + "model", global_step=global_step) 192 | print("%d iteration , %d global step best dev acc: %f " % (step, global_step, best_acc)) 193 | 194 | loss.append(batch_loss) 195 | examples += FLAGS.batch_size 196 | 197 | sv.saver.save(sess, FLAGS.out_dir + "model", global_step=global_step) 198 | run_test() 199 | sv.coord.request_stop() 200 | sv.coord.join(threads) 201 | sess.close() 202 | 203 | 204 | if __name__ == "__main__": 205 | tf.app.flags.DEFINE_string("train_file", "", "path of train recoard path") 206 | tf.app.flags.DEFINE_string("dev_file", "", "path of dev recoard path") 207 | tf.app.flags.DEFINE_string("test_file", "", "path of dev recoard path") 208 | 209 | tf.app.flags.DEFINE_string("pre_embedding_file", "", "vec of char or map file path") 210 | tf.app.flags.DEFINE_string("map_file", "", "map file ") 211 | tf.app.flags.DEFINE_string("out_dir", "", "log path of the supervisor") 212 | tf.app.flags.DEFINE_string("size_file", "", "size file") 213 | 214 | tf.app.flags.DEFINE_integer("max_epoch", 20, "max epoch") 215 | tf.app.flags.DEFINE_integer("batch_size", 32, "batch size") 216 | 217 | tf.app.flags.DEFINE_float("input_dropout_keep", 1.0, "input drop out ") 218 | tf.app.flags.DEFINE_float("dropout", 0.5, "dropout") 219 | 220 | tf.app.flags.DEFINE_integer("eval_step", 10, "evaluation step size") 221 | 222 | tf.app.flags.DEFINE_integer("char_dim", 100, "the embedding size of char or word") 223 | tf.app.flags.DEFINE_string("optimizer", "adam", "optimizer ") 224 | tf.app.flags.DEFINE_integer("clip", 5, "clip ") 225 | tf.app.flags.DEFINE_integer("lstm_dim", 100, "lstm dim") 226 | tf.app.flags.DEFINE_float("lr", 0.001, "learning rate") 227 | tf.app.flags.DEFINE_integer("save_model_secs", 30, "save model every second") 228 | 229 | tf.app.flags.DEFINE_boolean("debug", True, "if debug ") 230 | 231 | FLAGS = tf.flags.FLAGS 232 | tf.app.run() 233 | -------------------------------------------------------------------------------- /train/prepare_vec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | 5 | 6 | def load_file(file_name): 7 | f = open(file_name) 8 | sent = [] 9 | for line_no, line in enumerate(f): 10 | line = line.strip("\n") 11 | if not line: 12 | yield sent 13 | sent = [] 14 | else: 15 | item = line.split("\t") 16 | sent.append(item) 17 | f.close() 18 | 19 | 20 | def papre_char_vec(train_file, dev_file, test_file, out_filename): 21 | sent_counter = 0 22 | outf = open(out_filename, 'w') 23 | for file_name in [train_file, dev_file, test_file]: 24 | print("papre char vec train data from : %s" % (file_name)) 25 | for sent in load_file(file_name): 26 | sent_counter += 1 27 | s = [item[0] for item in sent] 28 | outf.write(" ".join(s) + "\n") 29 | print("all sent count %d" % (sent_counter)) 30 | 31 | 32 | def papre_word_vec(train_file, dev_file, test_file, out_filename): 33 | sent_counter = 0 34 | outf = open(out_filename, 'w') 35 | for file_name in [train_file, dev_file, test_file]: 36 | 37 | print("papre word vec train data from : %s" % (file_name)) 38 | 39 | for sent in load_file(file_name): 40 | 
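            # Rebuild words from the per-character segmentation labels:
            # "B" and "M" extend the current word, "E" closes it, and "S"
            # emits a single-character word on its own; any other label is
            # rejected with an exception.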
sent_counter += 1 41 | word = "" 42 | words = [] 43 | for item in sent: 44 | ch = item[0] 45 | seg_label = item[1] 46 | if seg_label == "B": 47 | word += ch 48 | elif seg_label == "M": 49 | word += ch 50 | elif seg_label == "S": 51 | words.append(ch) 52 | elif seg_label == "E": 53 | word += ch 54 | words.append(word) 55 | word = "" 56 | else: 57 | raise Exception("%s ignore" % (seg_label)) 58 | outf.write(" ".join(words) + "\n") 59 | 60 | 61 | if __name__ == '__main__': 62 | import argparse 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--train_file", required=True, help="train_file path") 66 | parser.add_argument("--dev_file", required=True, help="dev file path") 67 | parser.add_argument("--test_file", required=True, help="test file path") 68 | parser.add_argument("--out_file", required=True, help="out dir for vec path") 69 | args = parser.parse_args() 70 | papre_char_vec(args.train_file, args.dev_file, args.test_file, args.out_file) 71 | -------------------------------------------------------------------------------- /train/text_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import pickle 6 | import os 7 | import json 8 | 9 | 10 | def load_map_file(map_filename): 11 | vocab, tag_to_id, id_to_tag = pickle.load(open(map_filename, 'rb')) 12 | return vocab, tag_to_id, id_to_tag 13 | 14 | 15 | def seg_to_tfrecords(text_file, out_dir, map_file, out_name, indexs=[0, 1]): 16 | out_filename = os.path.join(out_dir, out_name + ".tfrecord") 17 | vocab, tag_to_id, id_to_tag = load_map_file(map_file) 18 | 19 | writer = tf.python_io.TFRecordWriter(out_filename) 20 | 21 | num_sample = 0 22 | all_oov = 0 23 | total_word = 0 24 | with open(text_file) as f: 25 | sent = [] 26 | for lineno, line in enumerate(f): 27 | line = line.strip("\n") 28 | if not line: 29 | if sent: 30 | num_sample += 1 31 | word_count, oov_count = create_one_seg_sample(writer, sent, vocab, tag_to_id) 32 | total_word += word_count 33 | all_oov += oov_count 34 | sent = [] 35 | continue 36 | 37 | word_info = line.split("\t") 38 | word_info = [word_info[i] for i in indexs] 39 | sent.append(word_info) 40 | print("oov rate : %f" % (1.0 * oov_count / total_word)) 41 | 42 | return num_sample 43 | 44 | 45 | def create_one_seg_sample(writer, sent, char_to_id, tag_to_id): 46 | char_list = [] 47 | seg_label_list = [] 48 | oov_count = 0 49 | word_count = 0 50 | for word in sent: 51 | ch = word[0] 52 | label = word[1] 53 | word_count += 1 54 | if ch in char_to_id: 55 | char_list.append(char_to_id[ch]) 56 | else: 57 | char_list.append(char_to_id[""]) 58 | oov_count += 1 59 | seg_label_list.append(tag_to_id[label]) 60 | 61 | example = tf.train.SequenceExample() 62 | 63 | sent_len = MAX_LENGTH if len(sent) > MAX_LENGTH else len(seg_label_list) 64 | 65 | fl_labels = example.feature_lists.feature_list["labels"] 66 | for l in seg_label_list[:sent_len]: 67 | fl_labels.feature.add().int64_list.value.append(l) 68 | 69 | fl_tokens = example.feature_lists.feature_list["char_list"] 70 | for t in char_list[:sent_len]: 71 | fl_tokens.feature.add().int64_list.value.append(t) 72 | 73 | fl_sent_len = example.feature_lists.feature_list["sent_len"] 74 | for t in [sent_len]: 75 | fl_sent_len.feature.add().int64_list.value.append(t) 76 | 77 | writer.write(example.SerializeToString()) 78 | return word_count, oov_count 79 | 80 | 81 | if __name__ == '__main__': 82 | import argparse 83 | 84 | parser = 
argparse.ArgumentParser() 85 | parser.add_argument("--train_file", required=True, help="train file path") 86 | parser.add_argument("--dev_file", required=True, help="dev file path") 87 | parser.add_argument("--map_file", required=True, help="map file path") 88 | parser.add_argument("--test_file", required=True, help="test file path") 89 | parser.add_argument("--out_dir", required=True, help=" out dir for tfrecord ") 90 | parser.add_argument("--size_file", required=True, help="the size file create by create map step") 91 | parser.add_argument("--tag_index", default=-1, type=int, help="column index for target label ") 92 | parser.add_argument("--max_length", default=100, type=int,help="max length of sent") 93 | args = parser.parse_args() 94 | 95 | MAX_LENGTH = args.max_length 96 | 97 | train_num = seg_to_tfrecords(args.train_file, args.out_dir, args.map_file, "train", [0, args.tag_index]) 98 | dev_num = seg_to_tfrecords(args.dev_file, args.out_dir, args.map_file, "dev", [0, args.tag_index]) 99 | test_num = seg_to_tfrecords(args.test_file, args.out_dir, args.map_file, "test", [0, args.tag_index]) 100 | 101 | print("train sample : %d" % (train_num)) 102 | print("dev sample :%d" % (dev_num)) 103 | print("test sample : %d" % (dev_num)) 104 | 105 | size_filename = args.size_file 106 | 107 | with open(size_filename, 'r') as f: 108 | size_obj = json.load(f) 109 | 110 | with open(os.path.join(size_filename), 'w') as f: 111 | size_obj['train_num'] = train_num 112 | size_obj['dev_num'] = dev_num 113 | size_obj['test_num'] = test_num 114 | json.dump(size_obj, f) 115 | -------------------------------------------------------------------------------- /train/tf_metrics.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | """ 4 | Multiclass 5 | from: 6 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 7 | 8 | """ 9 | 10 | __author__ = "Guillaume Genthial" 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 15 | 16 | 17 | def precision(labels, predictions, num_classes, pos_indices=None, 18 | weights=None, average='micro'): 19 | """Multi-class precision metric for Tensorflow 20 | Parameters 21 | ---------- 22 | labels : Tensor of tf.int32 or tf.int64 23 | The true labels 24 | predictions : Tensor of tf.int32 or tf.int64 25 | The predictions, same shape as labels 26 | num_classes : int 27 | The number of classes 28 | pos_indices : list of int, optional 29 | The indices of the positive classes, default is all 30 | weights : Tensor of tf.int32, optional 31 | Mask, must be of compatible shape with labels 32 | average : str, optional 33 | 'micro': counts the total number of true positives, false 34 | positives, and false negatives for the classes in 35 | `pos_indices` and infer the metric from it. 36 | 'macro': will compute the metric separately for each class in 37 | `pos_indices` and average. Will not account for class 38 | imbalance. 39 | 'weighted': will compute the metric separately for each class in 40 | `pos_indices` and perform a weighted average by the total 41 | number of true labels for each class. 
42 | Returns 43 | ------- 44 | tuple of (scalar float Tensor, update_op) 45 | """ 46 | cm, op = _streaming_confusion_matrix( 47 | labels, predictions, num_classes, weights) 48 | pr, _, _ = metrics_from_confusion_matrix( 49 | cm, pos_indices, average=average) 50 | op, _, _ = metrics_from_confusion_matrix( 51 | op, pos_indices, average=average) 52 | return (pr, op) 53 | 54 | 55 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 56 | average='micro'): 57 | """Multi-class recall metric for Tensorflow 58 | Parameters 59 | ---------- 60 | labels : Tensor of tf.int32 or tf.int64 61 | The true labels 62 | predictions : Tensor of tf.int32 or tf.int64 63 | The predictions, same shape as labels 64 | num_classes : int 65 | The number of classes 66 | pos_indices : list of int, optional 67 | The indices of the positive classes, default is all 68 | weights : Tensor of tf.int32, optional 69 | Mask, must be of compatible shape with labels 70 | average : str, optional 71 | 'micro': counts the total number of true positives, false 72 | positives, and false negatives for the classes in 73 | `pos_indices` and infer the metric from it. 74 | 'macro': will compute the metric separately for each class in 75 | `pos_indices` and average. Will not account for class 76 | imbalance. 77 | 'weighted': will compute the metric separately for each class in 78 | `pos_indices` and perform a weighted average by the total 79 | number of true labels for each class. 80 | Returns 81 | ------- 82 | tuple of (scalar float Tensor, update_op) 83 | """ 84 | cm, op = _streaming_confusion_matrix( 85 | labels, predictions, num_classes, weights) 86 | _, re, _ = metrics_from_confusion_matrix( 87 | cm, pos_indices, average=average) 88 | _, op, _ = metrics_from_confusion_matrix( 89 | op, pos_indices, average=average) 90 | return (re, op) 91 | 92 | 93 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 94 | average='micro'): 95 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 96 | average) 97 | 98 | 99 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 100 | average='micro', beta=1): 101 | """Multi-class fbeta metric for Tensorflow 102 | Parameters 103 | ---------- 104 | labels : Tensor of tf.int32 or tf.int64 105 | The true labels 106 | predictions : Tensor of tf.int32 or tf.int64 107 | The predictions, same shape as labels 108 | num_classes : int 109 | The number of classes 110 | pos_indices : list of int, optional 111 | The indices of the positive classes, default is all 112 | weights : Tensor of tf.int32, optional 113 | Mask, must be of compatible shape with labels 114 | average : str, optional 115 | 'micro': counts the total number of true positives, false 116 | positives, and false negatives for the classes in 117 | `pos_indices` and infer the metric from it. 118 | 'macro': will compute the metric separately for each class in 119 | `pos_indices` and average. Will not account for class 120 | imbalance. 121 | 'weighted': will compute the metric separately for each class in 122 | `pos_indices` and perform a weighted average by the total 123 | number of true labels for each class. 
124 | beta : int, optional 125 | Weight of precision in harmonic mean 126 | Returns 127 | ------- 128 | tuple of (scalar float Tensor, update_op) 129 | """ 130 | cm, op = _streaming_confusion_matrix( 131 | labels, predictions, num_classes, weights) 132 | _, _, fbeta = metrics_from_confusion_matrix( 133 | cm, pos_indices, average=average, beta=beta) 134 | _, _, op = metrics_from_confusion_matrix( 135 | op, pos_indices, average=average, beta=beta) 136 | return (fbeta, op) 137 | 138 | 139 | def safe_div(numerator, denominator): 140 | """Safe division, return 0 if denominator is 0""" 141 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 142 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 143 | denominator_is_zero = tf.equal(denominator, zeros) 144 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 145 | 146 | 147 | def pr_re_fbeta(cm, pos_indices, beta=1): 148 | """Uses a confusion matrix to compute precision, recall and fbeta""" 149 | num_classes = cm.shape[0] 150 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 151 | cm_mask = np.ones([num_classes, num_classes]) 152 | cm_mask[neg_indices, neg_indices] = 0 153 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 154 | 155 | cm_mask = np.ones([num_classes, num_classes]) 156 | cm_mask[:, neg_indices] = 0 157 | tot_pred = tf.reduce_sum(cm * cm_mask) 158 | 159 | cm_mask = np.ones([num_classes, num_classes]) 160 | cm_mask[neg_indices, :] = 0 161 | tot_gold = tf.reduce_sum(cm * cm_mask) 162 | 163 | pr = safe_div(diag_sum, tot_pred) 164 | re = safe_div(diag_sum, tot_gold) 165 | fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re) 166 | 167 | return pr, re, fbeta 168 | 169 | 170 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 171 | beta=1): 172 | """Precision, Recall and F1 from the confusion matrix 173 | Parameters 174 | ---------- 175 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 176 | The streaming confusion matrix. 
177 | pos_indices : list of int, optional 178 | The indices of the positive classes 179 | beta : int, optional 180 | Weight of precision in harmonic mean 181 | average : str, optional 182 | 'micro', 'macro' or 'weighted' 183 | """ 184 | num_classes = cm.shape[0] 185 | if pos_indices is None: 186 | pos_indices = [i for i in range(num_classes)] 187 | 188 | if average == 'micro': 189 | return pr_re_fbeta(cm, pos_indices, beta) 190 | elif average in {'macro', 'weighted'}: 191 | precisions, recalls, fbetas, n_golds = [], [], [], [] 192 | for idx in pos_indices: 193 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 194 | precisions.append(pr) 195 | recalls.append(re) 196 | fbetas.append(fbeta) 197 | cm_mask = np.zeros([num_classes, num_classes]) 198 | cm_mask[idx, :] = 1 199 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 200 | 201 | if average == 'macro': 202 | pr = tf.reduce_mean(precisions) 203 | re = tf.reduce_mean(recalls) 204 | fbeta = tf.reduce_mean(fbetas) 205 | return pr, re, fbeta 206 | if average == 'weighted': 207 | n_gold = tf.reduce_sum(n_golds) 208 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 209 | pr = safe_div(pr_sum, n_gold) 210 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 211 | re = safe_div(re_sum, n_gold) 212 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 213 | fbeta = safe_div(fbeta_sum, n_gold) 214 | return pr, re, fbeta 215 | 216 | else: 217 | raise NotImplementedError() -------------------------------------------------------------------------------- /train/third_party/word2vec/.gitignore: -------------------------------------------------------------------------------- 1 | compute-accuracy 2 | distance 3 | word2phrase 4 | word2vec 5 | word-analogy -------------------------------------------------------------------------------- /train/third_party/word2vec/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /train/third_party/word2vec/README.txt: -------------------------------------------------------------------------------- 1 | Tools for computing distributed representtion of words 2 | ------------------------------------------------------ 3 | 4 | We provide an implementation of the Continuous Bag-of-Words (CBOW) and the Skip-gram model (SG), as well as several demo scripts. 
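In this project the tool is invoked from train/main.sh with -binary 0, and the resulting
text-format vectors are read back by train/create_map_file.py and train/word2vec.py. As a
rough sketch (not part of the original word2vec distribution; the helper name below is made
up for illustration), that text format -- a "<vocab_size> <dim>" header line followed by one
"<word> <v1> ... <v_dim>" line per word -- can be parsed like this:

    # Illustrative sketch only; hypothetical helper, not part of this repo.
    import numpy as np

    def load_text_vectors(path):
        """Read word2vec's -binary 0 output into a {word: np.ndarray} dict."""
        vectors = {}
        with open(path, errors="ignore") as f:
            vocab_size, dim = map(int, f.readline().split())
            for line in f:
                parts = line.rstrip().split(" ")
                if len(parts) < dim + 1:
                    continue  # skip malformed or truncated lines
                vectors[parts[0]] = np.asarray(parts[1:dim + 1], dtype=np.float32)
        return vectors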
5 | 6 | Given a text corpus, the word2vec tool learns a vector for every word in the vocabulary using the Continuous 7 | Bag-of-Words or the Skip-Gram neural network architectures. The user should to specify the following: 8 | - desired vector dimensionality 9 | - the size of the context window for either the Skip-Gram or the Continuous Bag-of-Words model 10 | - training algorithm: hierarchical softmax and / or negative sampling 11 | - threshold for downsampling the frequent words 12 | - number of threads to use 13 | - the format of the output word vector file (text or binary) 14 | 15 | Usually, the other hyper-parameters such as the learning rate do not need to be tuned for different training sets. 16 | 17 | The script demo-word.sh downloads a small (100MB) text corpus from the web, and trains a small word vector model. After the training 18 | is finished, the user can interactively explore the similarity of the words. 19 | 20 | More information about the scripts is provided at https://code.google.com/p/word2vec/ 21 | 22 | -------------------------------------------------------------------------------- /train/third_party/word2vec/compute-accuracy.c: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | const long long max_size = 2000; // max length of strings 23 | const long long N = 1; // number of closest words 24 | const long long max_w = 50; // max length of vocabulary entries 25 | 26 | int main(int argc, char **argv) 27 | { 28 | FILE *f; 29 | char st1[max_size], st2[max_size], st3[max_size], st4[max_size], bestw[N][max_size], file_name[max_size], ch; 30 | float dist, len, bestd[N], vec[max_size]; 31 | long long words, size, a, b, c, d, b1, b2, b3, threshold = 0; 32 | float *M; 33 | char *vocab; 34 | int TCN, CCN = 0, TACN = 0, CACN = 0, SECN = 0, SYCN = 0, SEAC = 0, SYAC = 0, QID = 0, TQ = 0, TQS = 0; 35 | if (argc < 2) { 36 | printf("Usage: ./compute-accuracy \nwhere FILE contains word projections, and threshold is used to reduce vocabulary of the model for fast approximate evaluation (0 = off, otherwise typical value is 30000)\n"); 37 | return 0; 38 | } 39 | strcpy(file_name, argv[1]); 40 | if (argc > 2) threshold = atoi(argv[2]); 41 | f = fopen(file_name, "rb"); 42 | if (f == NULL) { 43 | printf("Input file not found\n"); 44 | return -1; 45 | } 46 | fscanf(f, "%lld", &words); 47 | if (threshold) if (words > threshold) words = threshold; 48 | fscanf(f, "%lld", &size); 49 | vocab = (char *)malloc(words * max_w * sizeof(char)); 50 | M = (float *)malloc(words * size * sizeof(float)); 51 | if (M == NULL) { 52 | printf("Cannot allocate memory: %lld MB\n", words * size * sizeof(float) / 1048576); 53 | return -1; 54 | } 55 | for (b = 0; b < words; b++) { 56 | a = 0; 57 | while (1) { 58 | vocab[b * max_w + a] = fgetc(f); 59 | if (feof(f) || (vocab[b * max_w + a] == ' ')) break; 60 | if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++; 61 | } 62 | vocab[b * max_w + a] = 0; 63 | for (a = 0; a < max_w; a++) vocab[b * max_w + a] = toupper(vocab[b * max_w + a]); 64 | for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f); 65 | len = 0; 66 | for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size]; 67 | len = sqrt(len); 68 | for (a = 0; a < size; a++) M[a + b * size] /= len; 69 | } 70 | fclose(f); 71 | TCN = 0; 72 | while (1) { 73 | for (a = 0; a < N; a++) bestd[a] = 0; 74 | for (a = 0; a < N; a++) bestw[a][0] = 0; 75 | scanf("%s", st1); 76 | for (a = 0; a < strlen(st1); a++) st1[a] = toupper(st1[a]); 77 | if ((!strcmp(st1, ":")) || (!strcmp(st1, "EXIT")) || feof(stdin)) { 78 | if (TCN == 0) TCN = 1; 79 | if (QID != 0) { 80 | printf("ACCURACY TOP1: %.2f %% (%d / %d)\n", CCN / (float)TCN * 100, CCN, TCN); 81 | printf("Total accuracy: %.2f %% Semantic accuracy: %.2f %% Syntactic accuracy: %.2f %% \n", CACN / (float)TACN * 100, SEAC / (float)SECN * 100, SYAC / (float)SYCN * 100); 82 | } 83 | QID++; 84 | scanf("%s", st1); 85 | if (feof(stdin)) break; 86 | printf("%s:\n", st1); 87 | TCN = 0; 88 | CCN = 0; 89 | continue; 90 | } 91 | if (!strcmp(st1, "EXIT")) break; 92 | scanf("%s", st2); 93 | for (a = 0; a < strlen(st2); a++) st2[a] = toupper(st2[a]); 94 | scanf("%s", st3); 95 | for (a = 0; a bestd[a]) { 122 | for (d = N - 1; d > a; d--) { 123 | bestd[d] = bestd[d - 1]; 124 | strcpy(bestw[d], bestw[d - 1]); 125 | } 126 | bestd[a] = dist; 127 | strcpy(bestw[a], &vocab[c * max_w]); 128 | break; 129 | } 130 | } 131 | } 132 | if (!strcmp(st4, bestw[0])) { 133 | CCN++; 134 | CACN++; 135 | if (QID <= 5) SEAC++; else SYAC++; 136 | } 137 | if (QID <= 5) SECN++; else SYCN++; 138 | TCN++; 139 | TACN++; 140 | } 141 | printf("Questions seen / total: %d %d %.2f %% \n", TQS, TQ, 
TQS/(float)TQ*100); 142 | return 0; 143 | } 144 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-analogy.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! -e text8 ]; then 3 | wget http://mattmahoney.net/dc/text8.zip -O text8.gz 4 | gzip -d text8.gz -f 5 | fi 6 | echo --------------------------------------------------------------------------------------------------- 7 | echo Note that for the word analogy to perform well, the model should be trained on much larger data set 8 | echo Example input: paris france berlin 9 | echo --------------------------------------------------------------------------------------------------- 10 | time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -iter 15 11 | ./word-analogy vectors.bin 12 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-classes.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! -e text8 ]; then 3 | wget http://mattmahoney.net/dc/text8.zip -O text8.gz 4 | gzip -d text8.gz -f 5 | fi 6 | time ./word2vec -train text8 -output classes.txt -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -iter 15 -classes 500 7 | sort classes.txt -k 2 -n > classes.sorted.txt 8 | echo The word classes were saved to file classes.sorted.txt 9 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-phrase-accuracy.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! -e news.2012.en.shuffled ]; then 3 | wget http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz 4 | gzip -d news.2012.en.shuffled.gz -f 5 | fi 6 | sed -e "s/’/'/g" -e "s/′/'/g" -e "s/''/ /g" < news.2012.en.shuffled | tr -c "A-Za-z'_ \n" " " > news.2012.en.shuffled-norm0 7 | time ./word2phrase -train news.2012.en.shuffled-norm0 -output news.2012.en.shuffled-norm0-phrase0 -threshold 200 -debug 2 8 | time ./word2phrase -train news.2012.en.shuffled-norm0-phrase0 -output news.2012.en.shuffled-norm0-phrase1 -threshold 100 -debug 2 9 | tr A-Z a-z < news.2012.en.shuffled-norm0-phrase1 > news.2012.en.shuffled-norm1-phrase1 10 | time ./word2vec -train news.2012.en.shuffled-norm1-phrase1 -output vectors-phrase.bin -cbow 1 -size 200 -window 10 -negative 25 -hs 0 -sample 1e-5 -threads 20 -binary 1 -iter 15 11 | ./compute-accuracy vectors-phrase.bin < questions-phrases.txt 12 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-phrases.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! 
-e news.2012.en.shuffled ]; then 3 | wget http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz 4 | gzip -d news.2012.en.shuffled.gz -f 5 | fi 6 | sed -e "s/’/'/g" -e "s/′/'/g" -e "s/''/ /g" < news.2012.en.shuffled | tr -c "A-Za-z'_ \n" " " > news.2012.en.shuffled-norm0 7 | time ./word2phrase -train news.2012.en.shuffled-norm0 -output news.2012.en.shuffled-norm0-phrase0 -threshold 200 -debug 2 8 | time ./word2phrase -train news.2012.en.shuffled-norm0-phrase0 -output news.2012.en.shuffled-norm0-phrase1 -threshold 100 -debug 2 9 | tr A-Z a-z < news.2012.en.shuffled-norm0-phrase1 > news.2012.en.shuffled-norm1-phrase1 10 | time ./word2vec -train news.2012.en.shuffled-norm1-phrase1 -output vectors-phrase.bin -cbow 1 -size 200 -window 10 -negative 25 -hs 0 -sample 1e-5 -threads 20 -binary 1 -iter 15 11 | ./distance vectors-phrase.bin 12 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-train-big-model-v1.sh: -------------------------------------------------------------------------------- 1 | ############################################################################################### 2 | # 3 | # Script for training good word and phrase vector model using public corpora, version 1.0. 4 | # The training time will be from several hours to about a day. 5 | # 6 | # Downloads about 8 billion words, makes phrases using two runs of word2phrase, trains 7 | # a 500-dimensional vector model and evaluates it on word and phrase analogy tasks. 8 | # 9 | ############################################################################################### 10 | 11 | # This function will convert text to lowercase and remove special characters 12 | normalize_text() { 13 | awk '{print tolower($0);}' | sed -e "s/’/'/g" -e "s/′/'/g" -e "s/''/ /g" -e "s/'/ ' /g" -e "s/“/\"/g" -e "s/”/\"/g" \ 14 | -e 's/"/ " /g' -e 's/\./ \. /g' -e 's/
/ /g' -e 's/, / , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \ 15 | -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' -e 's/-/ - /g' -e 's/=/ /g' -e 's/=/ /g' -e 's/*/ /g' -e 's/|/ /g' \ 16 | -e 's/«/ /g' | tr 0-9 " " 17 | } 18 | 19 | mkdir word2vec 20 | cd word2vec 21 | 22 | wget http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz 23 | wget http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz 24 | gzip -d news.2012.en.shuffled.gz 25 | gzip -d news.2013.en.shuffled.gz 26 | normalize_text < news.2012.en.shuffled > data.txt 27 | normalize_text < news.2013.en.shuffled >> data.txt 28 | 29 | wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz 30 | tar -xvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz 31 | for i in `ls 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled`; do 32 | normalize_text < 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/$i >> data.txt 33 | done 34 | 35 | wget http://ebiquity.umbc.edu/redirect/to/resource/id/351/UMBC-webbase-corpus 36 | tar -zxvf umbc_webbase_corpus.tar.gz webbase_all/*.txt 37 | for i in `ls webbase_all`; do 38 | normalize_text < webbase_all/$i >> data.txt 39 | done 40 | 41 | wget http://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 42 | bzip2 -c -d enwiki-latest-pages-articles.xml.bz2 | awk '{print tolower($0);}' | perl -e ' 43 | # Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase 44 | # letters (a-z, converted from A-Z), and spaces (never consecutive)... 45 | # All other characters are converted to spaces. Only text which normally appears. 46 | # in the web browser is displayed. Tables are removed. Image captions are. 47 | # preserved. Links are converted to normal text. Digits are spelled out. 48 | # *** Modified to not spell digits or throw away non-ASCII characters *** 49 | 50 | # Written by Matt Mahoney, June 10, 2006. This program is released to the public domain. 51 | 52 | $/=">"; # input record separator 53 | while (<>) { 54 | if (/ ... 55 | if (/#redirect/i) {$text=0;} # remove #REDIRECT 56 | if ($text) { 57 | 58 | # Remove any text not normally visible 59 | if (/<\/text>/) {$text=0;} 60 | s/<.*>//; # remove xml tags 61 | s/&/&/g; # decode URL encoded chars 62 | s/<//g; 64 | s///g; # remove references ... 
65 | s/<[^>]*>//g; # remove xhtml tags 66 | s/\[http:[^] ]*/[/g; # remove normal url, preserve visible text 67 | s/\|thumb//ig; # remove images links, preserve caption 68 | s/\|left//ig; 69 | s/\|right//ig; 70 | s/\|\d+px//ig; 71 | s/\[\[image:[^\[\]]*\|//ig; 72 | s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig; # show categories without markup 73 | s/\[\[[a-z\-]*:[^\]]*\]\]//g; # remove links to other languages 74 | s/\[\[[^\|\]]*\|/[[/g; # remove wiki url, preserve visible text 75 | s/{{[^}]*}}//g; # remove {{icons}} and {tables} 76 | s/{[^}]*}//g; 77 | s/\[//g; # remove [ and ] 78 | s/\]//g; 79 | s/&[^;]*;/ /g; # remove URL encoded chars 80 | 81 | $_=" $_ "; 82 | chop; 83 | print $_; 84 | } 85 | } 86 | ' | normalize_text | awk '{if (NF>1) print;}' >> data.txt 87 | 88 | wget http://word2vec.googlecode.com/svn/trunk/word2vec.c 89 | wget http://word2vec.googlecode.com/svn/trunk/word2phrase.c 90 | wget http://word2vec.googlecode.com/svn/trunk/compute-accuracy.c 91 | wget http://word2vec.googlecode.com/svn/trunk/questions-words.txt 92 | wget http://word2vec.googlecode.com/svn/trunk/questions-phrases.txt 93 | gcc word2vec.c -o word2vec -lm -pthread -O3 -march=native -funroll-loops 94 | gcc word2phrase.c -o word2phrase -lm -pthread -O3 -march=native -funroll-loops 95 | gcc compute-accuracy.c -o compute-accuracy -lm -pthread -O3 -march=native -funroll-loops 96 | ./word2phrase -train data.txt -output data-phrase.txt -threshold 200 -debug 2 97 | ./word2phrase -train data-phrase.txt -output data-phrase2.txt -threshold 100 -debug 2 98 | ./word2vec -train data-phrase2.txt -output vectors.bin -cbow 1 -size 500 -window 10 -negative 10 -hs 0 -sample 1e-5 -threads 40 -binary 1 -iter 3 -min-count 10 99 | ./compute-accuracy vectors.bin 400000 < questions-words.txt # should get to almost 78% accuracy on 99.7% of questions 100 | ./compute-accuracy vectors.bin 1000000 < questions-phrases.txt # about 78% accuracy with 77% coverage 101 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-word-accuracy.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! -e text8 ]; then 3 | wget http://mattmahoney.net/dc/text8.zip -O text8.gz 4 | gzip -d text8.gz -f 5 | fi 6 | time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -iter 15 7 | ./compute-accuracy vectors.bin 30000 < questions-words.txt 8 | # to compute accuracy with the full vocabulary, use: ./compute-accuracy vectors.bin < questions-words.txt 9 | -------------------------------------------------------------------------------- /train/third_party/word2vec/demo-word.sh: -------------------------------------------------------------------------------- 1 | make 2 | if [ ! -e text8 ]; then 3 | wget http://mattmahoney.net/dc/text8.zip -O text8.gz 4 | gzip -d text8.gz -f 5 | fi 6 | time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -iter 15 7 | ./distance vectors.bin 8 | -------------------------------------------------------------------------------- /train/third_party/word2vec/distance.c: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | const long long max_size = 2000; // max length of strings 21 | const long long N = 40; // number of closest words that will be shown 22 | const long long max_w = 50; // max length of vocabulary entries 23 | 24 | int main(int argc, char **argv) { 25 | FILE *f; 26 | char st1[max_size]; 27 | char *bestw[N]; 28 | char file_name[max_size], st[100][max_size]; 29 | float dist, len, bestd[N], vec[max_size]; 30 | long long words, size, a, b, c, d, cn, bi[100]; 31 | char ch; 32 | float *M; 33 | char *vocab; 34 | if (argc < 2) { 35 | printf("Usage: ./distance \nwhere FILE contains word projections in the BINARY FORMAT\n"); 36 | return 0; 37 | } 38 | strcpy(file_name, argv[1]); 39 | f = fopen(file_name, "rb"); 40 | if (f == NULL) { 41 | printf("Input file not found\n"); 42 | return -1; 43 | } 44 | fscanf(f, "%lld", &words); 45 | fscanf(f, "%lld", &size); 46 | vocab = (char *)malloc((long long)words * max_w * sizeof(char)); 47 | for (a = 0; a < N; a++) bestw[a] = (char *)malloc(max_size * sizeof(char)); 48 | M = (float *)malloc((long long)words * (long long)size * sizeof(float)); 49 | if (M == NULL) { 50 | printf("Cannot allocate memory: %lld MB %lld %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size); 51 | return -1; 52 | } 53 | for (b = 0; b < words; b++) { 54 | a = 0; 55 | while (1) { 56 | vocab[b * max_w + a] = fgetc(f); 57 | if (feof(f) || (vocab[b * max_w + a] == ' ')) break; 58 | if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++; 59 | } 60 | vocab[b * max_w + a] = 0; 61 | for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f); 62 | len = 0; 63 | for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size]; 64 | len = sqrt(len); 65 | for (a = 0; a < size; a++) M[a + b * size] /= len; 66 | } 67 | fclose(f); 68 | while (1) { 69 | for (a = 0; a < N; a++) bestd[a] = 0; 70 | for (a = 0; a < N; a++) bestw[a][0] = 0; 71 | printf("Enter word or sentence (EXIT to break): "); 72 | a = 0; 73 | while (1) { 74 | st1[a] = fgetc(stdin); 75 | if ((st1[a] == '\n') || (a >= max_size - 1)) { 76 | st1[a] = 0; 77 | break; 78 | } 79 | a++; 80 | } 81 | if (!strcmp(st1, "EXIT")) break; 82 | cn = 0; 83 | b = 0; 84 | c = 0; 85 | while (1) { 86 | st[cn][b] = st1[c]; 87 | b++; 88 | c++; 89 | st[cn][b] = 0; 90 | if (st1[c] == 0) break; 91 | if (st1[c] == ' ') { 92 | cn++; 93 | b = 0; 94 | c++; 95 | } 96 | } 97 | cn++; 98 | for (a = 0; a < cn; a++) { 99 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st[a])) break; 100 | if (b == words) b = -1; 101 | bi[a] = b; 102 | printf("\nWord: %s Position in vocabulary: %lld\n", st[a], bi[a]); 103 | if (b == -1) { 104 | printf("Out of dictionary word!\n"); 105 | break; 106 | } 107 | } 108 | if (b == -1) continue; 109 | printf("\n Word Cosine distance\n------------------------------------------------------------------------\n"); 110 | for (a = 0; a < size; a++) vec[a] = 0; 111 | for (b = 0; b < cn; b++) { 112 | if (bi[b] == -1) continue; 113 | for (a = 0; a < size; a++) vec[a] += M[a + bi[b] * size]; 114 | } 
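    // The query vector is the sum of the vectors of the entered words. It is normalized
    // to unit length below; since the vocabulary vectors were normalized when loaded,
    // the dot products computed against them are cosine similarities, and the N largest
    // are kept in bestd/bestw.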
115 | len = 0; 116 | for (a = 0; a < size; a++) len += vec[a] * vec[a]; 117 | len = sqrt(len); 118 | for (a = 0; a < size; a++) vec[a] /= len; 119 | for (a = 0; a < N; a++) bestd[a] = -1; 120 | for (a = 0; a < N; a++) bestw[a][0] = 0; 121 | for (c = 0; c < words; c++) { 122 | a = 0; 123 | for (b = 0; b < cn; b++) if (bi[b] == c) a = 1; 124 | if (a == 1) continue; 125 | dist = 0; 126 | for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size]; 127 | for (a = 0; a < N; a++) { 128 | if (dist > bestd[a]) { 129 | for (d = N - 1; d > a; d--) { 130 | bestd[d] = bestd[d - 1]; 131 | strcpy(bestw[d], bestw[d - 1]); 132 | } 133 | bestd[a] = dist; 134 | strcpy(bestw[a], &vocab[c * max_w]); 135 | break; 136 | } 137 | } 138 | } 139 | for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]); 140 | } 141 | return 0; 142 | } 143 | -------------------------------------------------------------------------------- /train/third_party/word2vec/makefile: -------------------------------------------------------------------------------- 1 | CC = gcc 2 | #Using -Ofast instead of -O3 might result in faster code, but is supported only by newer GCC versions 3 | CFLAGS = -lm -pthread -O3 -march=native -Wall -funroll-loops -Wno-unused-result 4 | 5 | all: word2vec word2phrase distance word-analogy compute-accuracy 6 | 7 | word2vec : word2vec.c 8 | $(CC) word2vec.c -o word2vec $(CFLAGS) 9 | word2phrase : word2phrase.c 10 | $(CC) word2phrase.c -o word2phrase $(CFLAGS) 11 | distance : distance.c 12 | $(CC) distance.c -o distance $(CFLAGS) 13 | word-analogy : word-analogy.c 14 | $(CC) word-analogy.c -o word-analogy $(CFLAGS) 15 | compute-accuracy : compute-accuracy.c 16 | $(CC) compute-accuracy.c -o compute-accuracy $(CFLAGS) 17 | chmod +x *.sh 18 | 19 | clean: 20 | rm -rf word2vec word2phrase distance word-analogy compute-accuracy -------------------------------------------------------------------------------- /train/third_party/word2vec/word-analogy.c: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
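// word-analogy answers questions of the form "A is to B as C is to ?": for input
// words A B C it builds the query vector vec(B) - vec(A) + vec(C) and prints the
// vocabulary words whose normalized vectors have the largest dot product with it.
// Typical use, matching demo-analogy.sh:
//
//   ./word-analogy vectors.bin
//   (example input: paris france berlin)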
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | const long long max_size = 2000; // max length of strings 21 | const long long N = 40; // number of closest words that will be shown 22 | const long long max_w = 50; // max length of vocabulary entries 23 | 24 | int main(int argc, char **argv) { 25 | FILE *f; 26 | char st1[max_size]; 27 | char bestw[N][max_size]; 28 | char file_name[max_size], st[100][max_size]; 29 | float dist, len, bestd[N], vec[max_size]; 30 | long long words, size, a, b, c, d, cn, bi[100]; 31 | char ch; 32 | float *M; 33 | char *vocab; 34 | if (argc < 2) { 35 | printf("Usage: ./word-analogy \nwhere FILE contains word projections in the BINARY FORMAT\n"); 36 | return 0; 37 | } 38 | strcpy(file_name, argv[1]); 39 | f = fopen(file_name, "rb"); 40 | if (f == NULL) { 41 | printf("Input file not found\n"); 42 | return -1; 43 | } 44 | fscanf(f, "%lld", &words); 45 | fscanf(f, "%lld", &size); 46 | vocab = (char *)malloc((long long)words * max_w * sizeof(char)); 47 | M = (float *)malloc((long long)words * (long long)size * sizeof(float)); 48 | if (M == NULL) { 49 | printf("Cannot allocate memory: %lld MB %lld %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size); 50 | return -1; 51 | } 52 | for (b = 0; b < words; b++) { 53 | a = 0; 54 | while (1) { 55 | vocab[b * max_w + a] = fgetc(f); 56 | if (feof(f) || (vocab[b * max_w + a] == ' ')) break; 57 | if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++; 58 | } 59 | vocab[b * max_w + a] = 0; 60 | for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f); 61 | len = 0; 62 | for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size]; 63 | len = sqrt(len); 64 | for (a = 0; a < size; a++) M[a + b * size] /= len; 65 | } 66 | fclose(f); 67 | while (1) { 68 | for (a = 0; a < N; a++) bestd[a] = 0; 69 | for (a = 0; a < N; a++) bestw[a][0] = 0; 70 | printf("Enter three words (EXIT to break): "); 71 | a = 0; 72 | while (1) { 73 | st1[a] = fgetc(stdin); 74 | if ((st1[a] == '\n') || (a >= max_size - 1)) { 75 | st1[a] = 0; 76 | break; 77 | } 78 | a++; 79 | } 80 | if (!strcmp(st1, "EXIT")) break; 81 | cn = 0; 82 | b = 0; 83 | c = 0; 84 | while (1) { 85 | st[cn][b] = st1[c]; 86 | b++; 87 | c++; 88 | st[cn][b] = 0; 89 | if (st1[c] == 0) break; 90 | if (st1[c] == ' ') { 91 | cn++; 92 | b = 0; 93 | c++; 94 | } 95 | } 96 | cn++; 97 | if (cn < 3) { 98 | printf("Only %lld words were entered.. 
three words are needed at the input to perform the calculation\n", cn); 99 | continue; 100 | } 101 | for (a = 0; a < cn; a++) { 102 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st[a])) break; 103 | if (b == words) b = 0; 104 | bi[a] = b; 105 | printf("\nWord: %s Position in vocabulary: %lld\n", st[a], bi[a]); 106 | if (b == 0) { 107 | printf("Out of dictionary word!\n"); 108 | break; 109 | } 110 | } 111 | if (b == 0) continue; 112 | printf("\n Word Distance\n------------------------------------------------------------------------\n"); 113 | for (a = 0; a < size; a++) vec[a] = M[a + bi[1] * size] - M[a + bi[0] * size] + M[a + bi[2] * size]; 114 | len = 0; 115 | for (a = 0; a < size; a++) len += vec[a] * vec[a]; 116 | len = sqrt(len); 117 | for (a = 0; a < size; a++) vec[a] /= len; 118 | for (a = 0; a < N; a++) bestd[a] = 0; 119 | for (a = 0; a < N; a++) bestw[a][0] = 0; 120 | for (c = 0; c < words; c++) { 121 | if (c == bi[0]) continue; 122 | if (c == bi[1]) continue; 123 | if (c == bi[2]) continue; 124 | a = 0; 125 | for (b = 0; b < cn; b++) if (bi[b] == c) a = 1; 126 | if (a == 1) continue; 127 | dist = 0; 128 | for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size]; 129 | for (a = 0; a < N; a++) { 130 | if (dist > bestd[a]) { 131 | for (d = N - 1; d > a; d--) { 132 | bestd[d] = bestd[d - 1]; 133 | strcpy(bestw[d], bestw[d - 1]); 134 | } 135 | bestd[a] = dist; 136 | strcpy(bestw[a], &vocab[c * max_w]); 137 | break; 138 | } 139 | } 140 | } 141 | for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]); 142 | } 143 | return 0; 144 | } 145 | -------------------------------------------------------------------------------- /train/third_party/word2vec/word2phrase.c: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
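// word2phrase joins frequently co-occurring word pairs into single '_'-separated
// tokens. For each bigram "prev cur", TrainModel below computes
//
//   score = (count(prev_cur) - min_count) / (count(prev) * count(cur)) * train_words
//
// and joins the pair with '_' whenever score exceeds the threshold. Typical use,
// taken from the built-in help text and the demo scripts:
//
//   ./word2phrase -train text.txt -output phrases.txt -threshold 100 -debug 2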
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #define MAX_STRING 60 22 | 23 | const int vocab_hash_size = 500000000; // Maximum 500M entries in the vocabulary 24 | 25 | typedef float real; // Precision of float numbers 26 | 27 | struct vocab_word { 28 | long long cn; 29 | char *word; 30 | }; 31 | 32 | char train_file[MAX_STRING], output_file[MAX_STRING]; 33 | struct vocab_word *vocab; 34 | int debug_mode = 2, min_count = 5, *vocab_hash, min_reduce = 1; 35 | long long vocab_max_size = 10000, vocab_size = 0; 36 | long long train_words = 0; 37 | real threshold = 100; 38 | 39 | unsigned long long next_random = 1; 40 | 41 | // Reads a single word from a file, assuming space + tab + EOL to be word boundaries 42 | void ReadWord(char *word, FILE *fin) { 43 | int a = 0, ch; 44 | while (!feof(fin)) { 45 | ch = fgetc(fin); 46 | if (ch == 13) continue; 47 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) { 48 | if (a > 0) { 49 | if (ch == '\n') ungetc(ch, fin); 50 | break; 51 | } 52 | if (ch == '\n') { 53 | strcpy(word, (char *)""); 54 | return; 55 | } else continue; 56 | } 57 | word[a] = ch; 58 | a++; 59 | if (a >= MAX_STRING - 1) a--; // Truncate too long words 60 | } 61 | word[a] = 0; 62 | } 63 | 64 | // Returns hash value of a word 65 | int GetWordHash(char *word) { 66 | unsigned long long a, hash = 1; 67 | for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a]; 68 | hash = hash % vocab_hash_size; 69 | return hash; 70 | } 71 | 72 | // Returns position of a word in the vocabulary; if the word is not found, returns -1 73 | int SearchVocab(char *word) { 74 | unsigned int hash = GetWordHash(word); 75 | while (1) { 76 | if (vocab_hash[hash] == -1) return -1; 77 | if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash]; 78 | hash = (hash + 1) % vocab_hash_size; 79 | } 80 | return -1; 81 | } 82 | 83 | // Reads a word and returns its index in the vocabulary 84 | int ReadWordIndex(FILE *fin) { 85 | char word[MAX_STRING]; 86 | ReadWord(word, fin); 87 | if (feof(fin)) return -1; 88 | return SearchVocab(word); 89 | } 90 | 91 | // Adds a word to the vocabulary 92 | int AddWordToVocab(char *word) { 93 | unsigned int hash, length = strlen(word) + 1; 94 | if (length > MAX_STRING) length = MAX_STRING; 95 | vocab[vocab_size].word = (char *)calloc(length, sizeof(char)); 96 | strcpy(vocab[vocab_size].word, word); 97 | vocab[vocab_size].cn = 0; 98 | vocab_size++; 99 | // Reallocate memory if needed 100 | if (vocab_size + 2 >= vocab_max_size) { 101 | vocab_max_size += 10000; 102 | vocab=(struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word)); 103 | } 104 | hash = GetWordHash(word); 105 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 106 | vocab_hash[hash]=vocab_size - 1; 107 | return vocab_size - 1; 108 | } 109 | 110 | // Used later for sorting by word counts 111 | int VocabCompare(const void *a, const void *b) { 112 | return ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn; 113 | } 114 | 115 | // Sorts the vocabulary by frequency using word counts 116 | void SortVocab() { 117 | int a; 118 | unsigned int hash; 119 | // Sort the vocabulary and keep at the first position 120 | qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare); 121 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 122 | for (a = 0; a < vocab_size; a++) { 123 | // Words occuring less than min_count times will be discarded from the vocab 124 | if (vocab[a].cn < min_count) { 125 | vocab_size--; 126 | 
free(vocab[vocab_size].word); 127 | } else { 128 | // Hash will be re-computed, as after the sorting it is not actual 129 | hash = GetWordHash(vocab[a].word); 130 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 131 | vocab_hash[hash] = a; 132 | } 133 | } 134 | vocab = (struct vocab_word *)realloc(vocab, vocab_size * sizeof(struct vocab_word)); 135 | } 136 | 137 | // Reduces the vocabulary by removing infrequent tokens 138 | void ReduceVocab() { 139 | int a, b = 0; 140 | unsigned int hash; 141 | for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) { 142 | vocab[b].cn = vocab[a].cn; 143 | vocab[b].word = vocab[a].word; 144 | b++; 145 | } else free(vocab[a].word); 146 | vocab_size = b; 147 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 148 | for (a = 0; a < vocab_size; a++) { 149 | // Hash will be re-computed, as it is not actual 150 | hash = GetWordHash(vocab[a].word); 151 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 152 | vocab_hash[hash] = a; 153 | } 154 | fflush(stdout); 155 | min_reduce++; 156 | } 157 | 158 | void LearnVocabFromTrainFile() { 159 | char word[MAX_STRING], last_word[MAX_STRING], bigram_word[MAX_STRING * 2]; 160 | FILE *fin; 161 | long long a, i, start = 1; 162 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 163 | fin = fopen(train_file, "rb"); 164 | if (fin == NULL) { 165 | printf("ERROR: training data file not found!\n"); 166 | exit(1); 167 | } 168 | vocab_size = 0; 169 | AddWordToVocab((char *)""); 170 | while (1) { 171 | ReadWord(word, fin); 172 | if (feof(fin)) break; 173 | if (!strcmp(word, "")) { 174 | start = 1; 175 | continue; 176 | } else start = 0; 177 | train_words++; 178 | if ((debug_mode > 1) && (train_words % 100000 == 0)) { 179 | printf("Words processed: %lldK Vocab size: %lldK %c", train_words / 1000, vocab_size / 1000, 13); 180 | fflush(stdout); 181 | } 182 | i = SearchVocab(word); 183 | if (i == -1) { 184 | a = AddWordToVocab(word); 185 | vocab[a].cn = 1; 186 | } else vocab[i].cn++; 187 | if (start) continue; 188 | sprintf(bigram_word, "%s_%s", last_word, word); 189 | bigram_word[MAX_STRING - 1] = 0; 190 | strcpy(last_word, word); 191 | i = SearchVocab(bigram_word); 192 | if (i == -1) { 193 | a = AddWordToVocab(bigram_word); 194 | vocab[a].cn = 1; 195 | } else vocab[i].cn++; 196 | if (vocab_size > vocab_hash_size * 0.7) ReduceVocab(); 197 | } 198 | SortVocab(); 199 | if (debug_mode > 0) { 200 | printf("\nVocab size (unigrams + bigrams): %lld\n", vocab_size); 201 | printf("Words in train file: %lld\n", train_words); 202 | } 203 | fclose(fin); 204 | } 205 | 206 | void TrainModel() { 207 | long long pa = 0, pb = 0, pab = 0, oov, i, li = -1, cn = 0; 208 | char word[MAX_STRING], last_word[MAX_STRING], bigram_word[MAX_STRING * 2]; 209 | real score; 210 | FILE *fo, *fin; 211 | printf("Starting training using file %s\n", train_file); 212 | LearnVocabFromTrainFile(); 213 | fin = fopen(train_file, "rb"); 214 | fo = fopen(output_file, "wb"); 215 | word[0] = 0; 216 | while (1) { 217 | strcpy(last_word, word); 218 | ReadWord(word, fin); 219 | if (feof(fin)) break; 220 | if (!strcmp(word, "")) { 221 | fprintf(fo, "\n"); 222 | continue; 223 | } 224 | cn++; 225 | if ((debug_mode > 1) && (cn % 100000 == 0)) { 226 | printf("Words written: %lldK%c", cn / 1000, 13); 227 | fflush(stdout); 228 | } 229 | oov = 0; 230 | i = SearchVocab(word); 231 | if (i == -1) oov = 1; else pb = vocab[i].cn; 232 | if (li == -1) oov = 1; 233 | li = i; 234 | sprintf(bigram_word, "%s_%s", last_word, word); 235 | 
bigram_word[MAX_STRING - 1] = 0; 236 | i = SearchVocab(bigram_word); 237 | if (i == -1) oov = 1; else pab = vocab[i].cn; 238 | if (pa < min_count) oov = 1; 239 | if (pb < min_count) oov = 1; 240 | if (oov) score = 0; else score = (pab - min_count) / (real)pa / (real)pb * (real)train_words; 241 | if (score > threshold) { 242 | fprintf(fo, "_%s", word); 243 | pb = 0; 244 | } else fprintf(fo, " %s", word); 245 | pa = pb; 246 | } 247 | fclose(fo); 248 | fclose(fin); 249 | } 250 | 251 | int ArgPos(char *str, int argc, char **argv) { 252 | int a; 253 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) { 254 | if (a == argc - 1) { 255 | printf("Argument missing for %s\n", str); 256 | exit(1); 257 | } 258 | return a; 259 | } 260 | return -1; 261 | } 262 | 263 | int main(int argc, char **argv) { 264 | int i; 265 | if (argc == 1) { 266 | printf("WORD2PHRASE tool v0.1a\n\n"); 267 | printf("Options:\n"); 268 | printf("Parameters for training:\n"); 269 | printf("\t-train \n"); 270 | printf("\t\tUse text data from to train the model\n"); 271 | printf("\t-output \n"); 272 | printf("\t\tUse to save the resulting word vectors / word clusters / phrases\n"); 273 | printf("\t-min-count \n"); 274 | printf("\t\tThis will discard words that appear less than times; default is 5\n"); 275 | printf("\t-threshold \n"); 276 | printf("\t\t The value represents threshold for forming the phrases (higher means less phrases); default 100\n"); 277 | printf("\t-debug \n"); 278 | printf("\t\tSet the debug mode (default = 2 = more info during training)\n"); 279 | printf("\nExamples:\n"); 280 | printf("./word2phrase -train text.txt -output phrases.txt -threshold 100 -debug 2\n\n"); 281 | return 0; 282 | } 283 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]); 284 | if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]); 285 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]); 286 | if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]); 287 | if ((i = ArgPos((char *)"-threshold", argc, argv)) > 0) threshold = atof(argv[i + 1]); 288 | vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word)); 289 | vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int)); 290 | TrainModel(); 291 | return 0; 292 | } 293 | -------------------------------------------------------------------------------- /train/third_party/word2vec/word2vec.c: -------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
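// word2vec trains word vectors with either the Continuous Bag-of-Words model
// (-cbow 1, the default) or Skip-gram (-cbow 0), using hierarchical softmax (-hs 1)
// and/or negative sampling (-negative, default 5). The example given in the
// built-in help text:
//
//   ./word2vec -train data.txt -output vec.txt -size 200 -window 5 -sample 1e-4 -negative 5 -hs 0 -binary 0 -cbow 1 -iter 3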
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #define MAX_STRING 100 22 | #define EXP_TABLE_SIZE 1000 23 | #define MAX_EXP 6 24 | #define MAX_SENTENCE_LENGTH 1000 25 | #define MAX_CODE_LENGTH 40 26 | 27 | const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary 28 | 29 | typedef float real; // Precision of float numbers 30 | 31 | struct vocab_word { 32 | long long cn; 33 | int *point; 34 | char *word, *code, codelen; 35 | }; 36 | 37 | char train_file[MAX_STRING], output_file[MAX_STRING]; 38 | char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING]; 39 | struct vocab_word *vocab; 40 | int binary = 0, cbow = 1, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1; 41 | int *vocab_hash; 42 | long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100; 43 | long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0; 44 | real alpha = 0.025, starting_alpha, sample = 1e-3; 45 | real *syn0, *syn1, *syn1neg, *expTable; 46 | clock_t start; 47 | 48 | int hs = 0, negative = 5; 49 | const int table_size = 1e8; 50 | int *table; 51 | 52 | void InitUnigramTable() { 53 | int a, i; 54 | double train_words_pow = 0; 55 | double d1, power = 0.75; 56 | table = (int *)malloc(table_size * sizeof(int)); 57 | for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power); 58 | i = 0; 59 | d1 = pow(vocab[i].cn, power) / train_words_pow; 60 | for (a = 0; a < table_size; a++) { 61 | table[a] = i; 62 | if (a / (double)table_size > d1) { 63 | i++; 64 | d1 += pow(vocab[i].cn, power) / train_words_pow; 65 | } 66 | if (i >= vocab_size) i = vocab_size - 1; 67 | } 68 | } 69 | 70 | // Reads a single word from a file, assuming space + tab + EOL to be word boundaries 71 | void ReadWord(char *word, FILE *fin) { 72 | int a = 0, ch; 73 | while (!feof(fin)) { 74 | ch = fgetc(fin); 75 | if (ch == 13) continue; 76 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) { 77 | if (a > 0) { 78 | if (ch == '\n') ungetc(ch, fin); 79 | break; 80 | } 81 | if (ch == '\n') { 82 | strcpy(word, (char *)""); 83 | return; 84 | } else continue; 85 | } 86 | word[a] = ch; 87 | a++; 88 | if (a >= MAX_STRING - 1) a--; // Truncate too long words 89 | } 90 | word[a] = 0; 91 | } 92 | 93 | // Returns hash value of a word 94 | int GetWordHash(char *word) { 95 | unsigned long long a, hash = 0; 96 | for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a]; 97 | hash = hash % vocab_hash_size; 98 | return hash; 99 | } 100 | 101 | // Returns position of a word in the vocabulary; if the word is not found, returns -1 102 | int SearchVocab(char *word) { 103 | unsigned int hash = GetWordHash(word); 104 | while (1) { 105 | if (vocab_hash[hash] == -1) return -1; 106 | if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash]; 107 | hash = (hash + 1) % vocab_hash_size; 108 | } 109 | return -1; 110 | } 111 | 112 | // Reads a word and returns its index in the vocabulary 113 | int ReadWordIndex(FILE *fin) { 114 | char word[MAX_STRING]; 115 | ReadWord(word, fin); 116 | if (feof(fin)) return -1; 117 | return SearchVocab(word); 118 | } 119 | 120 | // Adds a word to the vocabulary 121 | int AddWordToVocab(char *word) { 122 | unsigned int hash, length = strlen(word) + 1; 123 | if (length > MAX_STRING) length = MAX_STRING; 124 | vocab[vocab_size].word = (char *)calloc(length, sizeof(char)); 125 | strcpy(vocab[vocab_size].word, word); 126 | vocab[vocab_size].cn = 0; 127 | vocab_size++; 128 | // Reallocate memory if 
needed 129 | if (vocab_size + 2 >= vocab_max_size) { 130 | vocab_max_size += 1000; 131 | vocab = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word)); 132 | } 133 | hash = GetWordHash(word); 134 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 135 | vocab_hash[hash] = vocab_size - 1; 136 | return vocab_size - 1; 137 | } 138 | 139 | // Used later for sorting by word counts 140 | int VocabCompare(const void *a, const void *b) { 141 | return ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn; 142 | } 143 | 144 | // Sorts the vocabulary by frequency using word counts 145 | void SortVocab() { 146 | int a, size; 147 | unsigned int hash; 148 | // Sort the vocabulary and keep at the first position 149 | qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare); 150 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 151 | size = vocab_size; 152 | train_words = 0; 153 | for (a = 0; a < size; a++) { 154 | // Words occuring less than min_count times will be discarded from the vocab 155 | if ((vocab[a].cn < min_count) && (a != 0)) { 156 | vocab_size--; 157 | free(vocab[a].word); 158 | } else { 159 | // Hash will be re-computed, as after the sorting it is not actual 160 | hash=GetWordHash(vocab[a].word); 161 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 162 | vocab_hash[hash] = a; 163 | train_words += vocab[a].cn; 164 | } 165 | } 166 | vocab = (struct vocab_word *)realloc(vocab, (vocab_size + 1) * sizeof(struct vocab_word)); 167 | // Allocate memory for the binary tree construction 168 | for (a = 0; a < vocab_size; a++) { 169 | vocab[a].code = (char *)calloc(MAX_CODE_LENGTH, sizeof(char)); 170 | vocab[a].point = (int *)calloc(MAX_CODE_LENGTH, sizeof(int)); 171 | } 172 | } 173 | 174 | // Reduces the vocabulary by removing infrequent tokens 175 | void ReduceVocab() { 176 | int a, b = 0; 177 | unsigned int hash; 178 | for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) { 179 | vocab[b].cn = vocab[a].cn; 180 | vocab[b].word = vocab[a].word; 181 | b++; 182 | } else free(vocab[a].word); 183 | vocab_size = b; 184 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 185 | for (a = 0; a < vocab_size; a++) { 186 | // Hash will be re-computed, as it is not actual 187 | hash = GetWordHash(vocab[a].word); 188 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 189 | vocab_hash[hash] = a; 190 | } 191 | fflush(stdout); 192 | min_reduce++; 193 | } 194 | 195 | // Create binary Huffman tree using the word counts 196 | // Frequent words will have short uniqe binary codes 197 | void CreateBinaryTree() { 198 | long long a, b, i, min1i, min2i, pos1, pos2, point[MAX_CODE_LENGTH]; 199 | char code[MAX_CODE_LENGTH]; 200 | long long *count = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long)); 201 | long long *binary = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long)); 202 | long long *parent_node = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long)); 203 | for (a = 0; a < vocab_size; a++) count[a] = vocab[a].cn; 204 | for (a = vocab_size; a < vocab_size * 2; a++) count[a] = 1e15; 205 | pos1 = vocab_size - 1; 206 | pos2 = vocab_size; 207 | // Following algorithm constructs the Huffman tree by adding one node at a time 208 | for (a = 0; a < vocab_size - 1; a++) { 209 | // First, find two smallest nodes 'min1, min2' 210 | if (pos1 >= 0) { 211 | if (count[pos1] < count[pos2]) { 212 | min1i = pos1; 213 | pos1--; 214 | } else { 215 | min1i = pos2; 216 | pos2++; 217 | } 218 
| } else { 219 | min1i = pos2; 220 | pos2++; 221 | } 222 | if (pos1 >= 0) { 223 | if (count[pos1] < count[pos2]) { 224 | min2i = pos1; 225 | pos1--; 226 | } else { 227 | min2i = pos2; 228 | pos2++; 229 | } 230 | } else { 231 | min2i = pos2; 232 | pos2++; 233 | } 234 | count[vocab_size + a] = count[min1i] + count[min2i]; 235 | parent_node[min1i] = vocab_size + a; 236 | parent_node[min2i] = vocab_size + a; 237 | binary[min2i] = 1; 238 | } 239 | // Now assign binary code to each vocabulary word 240 | for (a = 0; a < vocab_size; a++) { 241 | b = a; 242 | i = 0; 243 | while (1) { 244 | code[i] = binary[b]; 245 | point[i] = b; 246 | i++; 247 | b = parent_node[b]; 248 | if (b == vocab_size * 2 - 2) break; 249 | } 250 | vocab[a].codelen = i; 251 | vocab[a].point[0] = vocab_size - 2; 252 | for (b = 0; b < i; b++) { 253 | vocab[a].code[i - b - 1] = code[b]; 254 | vocab[a].point[i - b] = point[b] - vocab_size; 255 | } 256 | } 257 | free(count); 258 | free(binary); 259 | free(parent_node); 260 | } 261 | 262 | void LearnVocabFromTrainFile() { 263 | char word[MAX_STRING]; 264 | FILE *fin; 265 | long long a, i; 266 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 267 | fin = fopen(train_file, "rb"); 268 | if (fin == NULL) { 269 | printf("ERROR: training data file not found!\n"); 270 | exit(1); 271 | } 272 | vocab_size = 0; 273 | AddWordToVocab((char *)""); 274 | while (1) { 275 | ReadWord(word, fin); 276 | if (feof(fin)) break; 277 | train_words++; 278 | if ((debug_mode > 1) && (train_words % 100000 == 0)) { 279 | printf("%lldK%c", train_words / 1000, 13); 280 | fflush(stdout); 281 | } 282 | i = SearchVocab(word); 283 | if (i == -1) { 284 | a = AddWordToVocab(word); 285 | vocab[a].cn = 1; 286 | } else vocab[i].cn++; 287 | if (vocab_size > vocab_hash_size * 0.7) ReduceVocab(); 288 | } 289 | SortVocab(); 290 | if (debug_mode > 0) { 291 | printf("Vocab size: %lld\n", vocab_size); 292 | printf("Words in train file: %lld\n", train_words); 293 | } 294 | file_size = ftell(fin); 295 | fclose(fin); 296 | } 297 | 298 | void SaveVocab() { 299 | long long i; 300 | FILE *fo = fopen(save_vocab_file, "wb"); 301 | for (i = 0; i < vocab_size; i++) fprintf(fo, "%s %lld\n", vocab[i].word, vocab[i].cn); 302 | fclose(fo); 303 | } 304 | 305 | void ReadVocab() { 306 | long long a, i = 0; 307 | char c; 308 | char word[MAX_STRING]; 309 | FILE *fin = fopen(read_vocab_file, "rb"); 310 | if (fin == NULL) { 311 | printf("Vocabulary file not found\n"); 312 | exit(1); 313 | } 314 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 315 | vocab_size = 0; 316 | while (1) { 317 | ReadWord(word, fin); 318 | if (feof(fin)) break; 319 | a = AddWordToVocab(word); 320 | fscanf(fin, "%lld%c", &vocab[a].cn, &c); 321 | i++; 322 | } 323 | SortVocab(); 324 | if (debug_mode > 0) { 325 | printf("Vocab size: %lld\n", vocab_size); 326 | printf("Words in train file: %lld\n", train_words); 327 | } 328 | fin = fopen(train_file, "rb"); 329 | if (fin == NULL) { 330 | printf("ERROR: training data file not found!\n"); 331 | exit(1); 332 | } 333 | fseek(fin, 0, SEEK_END); 334 | file_size = ftell(fin); 335 | fclose(fin); 336 | } 337 | 338 | void InitNet() { 339 | long long a, b; 340 | unsigned long long next_random = 1; 341 | a = posix_memalign((void **)&syn0, 128, (long long)vocab_size * layer1_size * sizeof(real)); 342 | if (syn0 == NULL) {printf("Memory allocation failed\n"); exit(1);} 343 | if (hs) { 344 | a = posix_memalign((void **)&syn1, 128, (long long)vocab_size * layer1_size * sizeof(real)); 345 | if (syn1 == NULL) {printf("Memory 
allocation failed\n"); exit(1);} 346 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) 347 | syn1[a * layer1_size + b] = 0; 348 | } 349 | if (negative>0) { 350 | a = posix_memalign((void **)&syn1neg, 128, (long long)vocab_size * layer1_size * sizeof(real)); 351 | if (syn1neg == NULL) {printf("Memory allocation failed\n"); exit(1);} 352 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) 353 | syn1neg[a * layer1_size + b] = 0; 354 | } 355 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) { 356 | next_random = next_random * (unsigned long long)25214903917 + 11; 357 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5) / layer1_size; 358 | } 359 | CreateBinaryTree(); 360 | } 361 | 362 | void *TrainModelThread(void *id) { 363 | long long a, b, d, cw, word, last_word, sentence_length = 0, sentence_position = 0; 364 | long long word_count = 0, last_word_count = 0, sen[MAX_SENTENCE_LENGTH + 1]; 365 | long long l1, l2, c, target, label, local_iter = iter; 366 | unsigned long long next_random = (long long)id; 367 | real f, g; 368 | clock_t now; 369 | real *neu1 = (real *)calloc(layer1_size, sizeof(real)); 370 | real *neu1e = (real *)calloc(layer1_size, sizeof(real)); 371 | FILE *fi = fopen(train_file, "rb"); 372 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET); 373 | while (1) { 374 | if (word_count - last_word_count > 10000) { 375 | word_count_actual += word_count - last_word_count; 376 | last_word_count = word_count; 377 | if ((debug_mode > 1)) { 378 | now=clock(); 379 | printf("%cAlpha: %f Progress: %.2f%% Words/thread/sec: %.2fk ", 13, alpha, 380 | word_count_actual / (real)(iter * train_words + 1) * 100, 381 | word_count_actual / ((real)(now - start + 1) / (real)CLOCKS_PER_SEC * 1000)); 382 | fflush(stdout); 383 | } 384 | alpha = starting_alpha * (1 - word_count_actual / (real)(iter * train_words + 1)); 385 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001; 386 | } 387 | if (sentence_length == 0) { 388 | while (1) { 389 | word = ReadWordIndex(fi); 390 | if (feof(fi)) break; 391 | if (word == -1) continue; 392 | word_count++; 393 | if (word == 0) break; 394 | // The subsampling randomly discards frequent words while keeping the ranking same 395 | if (sample > 0) { 396 | real ran = (sqrt(vocab[word].cn / (sample * train_words)) + 1) * (sample * train_words) / vocab[word].cn; 397 | next_random = next_random * (unsigned long long)25214903917 + 11; 398 | if (ran < (next_random & 0xFFFF) / (real)65536) continue; 399 | } 400 | sen[sentence_length] = word; 401 | sentence_length++; 402 | if (sentence_length >= MAX_SENTENCE_LENGTH) break; 403 | } 404 | sentence_position = 0; 405 | } 406 | if (feof(fi) || (word_count > train_words / num_threads)) { 407 | word_count_actual += word_count - last_word_count; 408 | local_iter--; 409 | if (local_iter == 0) break; 410 | word_count = 0; 411 | last_word_count = 0; 412 | sentence_length = 0; 413 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET); 414 | continue; 415 | } 416 | word = sen[sentence_position]; 417 | if (word == -1) continue; 418 | for (c = 0; c < layer1_size; c++) neu1[c] = 0; 419 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 420 | next_random = next_random * (unsigned long long)25214903917 + 11; 421 | b = next_random % window; 422 | if (cbow) { //train the cbow architecture 423 | // in -> hidden 424 | cw = 0; 425 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 426 | c = sentence_position 
- window + a; 427 | if (c < 0) continue; 428 | if (c >= sentence_length) continue; 429 | last_word = sen[c]; 430 | if (last_word == -1) continue; 431 | for (c = 0; c < layer1_size; c++) neu1[c] += syn0[c + last_word * layer1_size]; 432 | cw++; 433 | } 434 | if (cw) { 435 | for (c = 0; c < layer1_size; c++) neu1[c] /= cw; 436 | if (hs) for (d = 0; d < vocab[word].codelen; d++) { 437 | f = 0; 438 | l2 = vocab[word].point[d] * layer1_size; 439 | // Propagate hidden -> output 440 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1[c + l2]; 441 | if (f <= -MAX_EXP) continue; 442 | else if (f >= MAX_EXP) continue; 443 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 444 | // 'g' is the gradient multiplied by the learning rate 445 | g = (1 - vocab[word].code[d] - f) * alpha; 446 | // Propagate errors output -> hidden 447 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2]; 448 | // Learn weights hidden -> output 449 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * neu1[c]; 450 | } 451 | // NEGATIVE SAMPLING 452 | if (negative > 0) for (d = 0; d < negative + 1; d++) { 453 | if (d == 0) { 454 | target = word; 455 | label = 1; 456 | } else { 457 | next_random = next_random * (unsigned long long)25214903917 + 11; 458 | target = table[(next_random >> 16) % table_size]; 459 | if (target == 0) target = next_random % (vocab_size - 1) + 1; 460 | if (target == word) continue; 461 | label = 0; 462 | } 463 | l2 = target * layer1_size; 464 | f = 0; 465 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1neg[c + l2]; 466 | if (f > MAX_EXP) g = (label - 1) * alpha; 467 | else if (f < -MAX_EXP) g = (label - 0) * alpha; 468 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha; 469 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 470 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * neu1[c]; 471 | } 472 | // hidden -> in 473 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 474 | c = sentence_position - window + a; 475 | if (c < 0) continue; 476 | if (c >= sentence_length) continue; 477 | last_word = sen[c]; 478 | if (last_word == -1) continue; 479 | for (c = 0; c < layer1_size; c++) syn0[c + last_word * layer1_size] += neu1e[c]; 480 | } 481 | } 482 | } else { //train skip-gram 483 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 484 | c = sentence_position - window + a; 485 | if (c < 0) continue; 486 | if (c >= sentence_length) continue; 487 | last_word = sen[c]; 488 | if (last_word == -1) continue; 489 | l1 = last_word * layer1_size; 490 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 491 | // HIERARCHICAL SOFTMAX 492 | if (hs) for (d = 0; d < vocab[word].codelen; d++) { 493 | f = 0; 494 | l2 = vocab[word].point[d] * layer1_size; 495 | // Propagate hidden -> output 496 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1[c + l2]; 497 | if (f <= -MAX_EXP) continue; 498 | else if (f >= MAX_EXP) continue; 499 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 500 | // 'g' is the gradient multiplied by the learning rate 501 | g = (1 - vocab[word].code[d] - f) * alpha; 502 | // Propagate errors output -> hidden 503 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2]; 504 | // Learn weights hidden -> output 505 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * syn0[c + l1]; 506 | } 507 | // NEGATIVE SAMPLING 508 | if (negative > 0) for (d = 0; d < negative + 1; d++) { 509 | if (d == 0) { 510 | target = word; 511 | label 
= 1; 512 | } else { 513 | next_random = next_random * (unsigned long long)25214903917 + 11; 514 | target = table[(next_random >> 16) % table_size]; 515 | if (target == 0) target = next_random % (vocab_size - 1) + 1; 516 | if (target == word) continue; 517 | label = 0; 518 | } 519 | l2 = target * layer1_size; 520 | f = 0; 521 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2]; 522 | if (f > MAX_EXP) g = (label - 1) * alpha; 523 | else if (f < -MAX_EXP) g = (label - 0) * alpha; 524 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha; 525 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 526 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1]; 527 | } 528 | // Learn weights input -> hidden 529 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c]; 530 | } 531 | } 532 | sentence_position++; 533 | if (sentence_position >= sentence_length) { 534 | sentence_length = 0; 535 | continue; 536 | } 537 | } 538 | fclose(fi); 539 | free(neu1); 540 | free(neu1e); 541 | pthread_exit(NULL); 542 | } 543 | 544 | void TrainModel() { 545 | long a, b, c, d; 546 | FILE *fo; 547 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 548 | printf("Starting training using file %s\n", train_file); 549 | starting_alpha = alpha; 550 | if (read_vocab_file[0] != 0) ReadVocab(); else LearnVocabFromTrainFile(); 551 | if (save_vocab_file[0] != 0) SaveVocab(); 552 | if (output_file[0] == 0) return; 553 | InitNet(); 554 | if (negative > 0) InitUnigramTable(); 555 | start = clock(); 556 | for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, TrainModelThread, (void *)a); 557 | for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL); 558 | fo = fopen(output_file, "wb"); 559 | if (classes == 0) { 560 | // Save the word vectors 561 | fprintf(fo, "%lld %lld\n", vocab_size, layer1_size); 562 | for (a = 0; a < vocab_size; a++) { 563 | fprintf(fo, "%s ", vocab[a].word); 564 | if (binary) for (b = 0; b < layer1_size; b++) fwrite(&syn0[a * layer1_size + b], sizeof(real), 1, fo); 565 | else for (b = 0; b < layer1_size; b++) fprintf(fo, "%lf ", syn0[a * layer1_size + b]); 566 | fprintf(fo, "\n"); 567 | } 568 | } else { 569 | // Run K-means on the word vectors 570 | int clcn = classes, iter = 10, closeid; 571 | int *centcn = (int *)malloc(classes * sizeof(int)); 572 | int *cl = (int *)calloc(vocab_size, sizeof(int)); 573 | real closev, x; 574 | real *cent = (real *)calloc(classes * layer1_size, sizeof(real)); 575 | for (a = 0; a < vocab_size; a++) cl[a] = a % clcn; 576 | for (a = 0; a < iter; a++) { 577 | for (b = 0; b < clcn * layer1_size; b++) cent[b] = 0; 578 | for (b = 0; b < clcn; b++) centcn[b] = 1; 579 | for (c = 0; c < vocab_size; c++) { 580 | for (d = 0; d < layer1_size; d++) cent[layer1_size * cl[c] + d] += syn0[c * layer1_size + d]; 581 | centcn[cl[c]]++; 582 | } 583 | for (b = 0; b < clcn; b++) { 584 | closev = 0; 585 | for (c = 0; c < layer1_size; c++) { 586 | cent[layer1_size * b + c] /= centcn[b]; 587 | closev += cent[layer1_size * b + c] * cent[layer1_size * b + c]; 588 | } 589 | closev = sqrt(closev); 590 | for (c = 0; c < layer1_size; c++) cent[layer1_size * b + c] /= closev; 591 | } 592 | for (c = 0; c < vocab_size; c++) { 593 | closev = -10; 594 | closeid = 0; 595 | for (d = 0; d < clcn; d++) { 596 | x = 0; 597 | for (b = 0; b < layer1_size; b++) x += cent[layer1_size * d + b] * syn0[c * layer1_size + b]; 598 | if (x > closev) { 599 | closev = x; 600 | closeid = d; 601 | } 602 | } 
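        // Reassign word c to the centroid with the largest dot product found above.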
603 | cl[c] = closeid; 604 | } 605 | } 606 | // Save the K-means classes 607 | for (a = 0; a < vocab_size; a++) fprintf(fo, "%s %d\n", vocab[a].word, cl[a]); 608 | free(centcn); 609 | free(cent); 610 | free(cl); 611 | } 612 | fclose(fo); 613 | } 614 | 615 | int ArgPos(char *str, int argc, char **argv) { 616 | int a; 617 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) { 618 | if (a == argc - 1) { 619 | printf("Argument missing for %s\n", str); 620 | exit(1); 621 | } 622 | return a; 623 | } 624 | return -1; 625 | } 626 | 627 | int main(int argc, char **argv) { 628 | int i; 629 | if (argc == 1) { 630 | printf("WORD VECTOR estimation toolkit v 0.1c\n\n"); 631 | printf("Options:\n"); 632 | printf("Parameters for training:\n"); 633 | printf("\t-train \n"); 634 | printf("\t\tUse text data from to train the model\n"); 635 | printf("\t-output \n"); 636 | printf("\t\tUse to save the resulting word vectors / word clusters\n"); 637 | printf("\t-size \n"); 638 | printf("\t\tSet size of word vectors; default is 100\n"); 639 | printf("\t-window \n"); 640 | printf("\t\tSet max skip length between words; default is 5\n"); 641 | printf("\t-sample \n"); 642 | printf("\t\tSet threshold for occurrence of words. Those that appear with higher frequency in the training data\n"); 643 | printf("\t\twill be randomly down-sampled; default is 1e-3, useful range is (0, 1e-5)\n"); 644 | printf("\t-hs \n"); 645 | printf("\t\tUse Hierarchical Softmax; default is 0 (not used)\n"); 646 | printf("\t-negative \n"); 647 | printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10 (0 = not used)\n"); 648 | printf("\t-threads \n"); 649 | printf("\t\tUse threads (default 12)\n"); 650 | printf("\t-iter \n"); 651 | printf("\t\tRun more training iterations (default 5)\n"); 652 | printf("\t-min-count \n"); 653 | printf("\t\tThis will discard words that appear less than times; default is 5\n"); 654 | printf("\t-alpha \n"); 655 | printf("\t\tSet the starting learning rate; default is 0.025 for skip-gram and 0.05 for CBOW\n"); 656 | printf("\t-classes \n"); 657 | printf("\t\tOutput word classes rather than word vectors; default number of classes is 0 (vectors are written)\n"); 658 | printf("\t-debug \n"); 659 | printf("\t\tSet the debug mode (default = 2 = more info during training)\n"); 660 | printf("\t-binary \n"); 661 | printf("\t\tSave the resulting vectors in binary moded; default is 0 (off)\n"); 662 | printf("\t-save-vocab \n"); 663 | printf("\t\tThe vocabulary will be saved to \n"); 664 | printf("\t-read-vocab \n"); 665 | printf("\t\tThe vocabulary will be read from , not constructed from the training data\n"); 666 | printf("\t-cbow \n"); 667 | printf("\t\tUse the continuous bag of words model; default is 1 (use 0 for skip-gram model)\n"); 668 | printf("\nExamples:\n"); 669 | printf("./word2vec -train data.txt -output vec.txt -size 200 -window 5 -sample 1e-4 -negative 5 -hs 0 -binary 0 -cbow 1 -iter 3\n\n"); 670 | return 0; 671 | } 672 | output_file[0] = 0; 673 | save_vocab_file[0] = 0; 674 | read_vocab_file[0] = 0; 675 | if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]); 676 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]); 677 | if ((i = ArgPos((char *)"-save-vocab", argc, argv)) > 0) strcpy(save_vocab_file, argv[i + 1]); 678 | if ((i = ArgPos((char *)"-read-vocab", argc, argv)) > 0) strcpy(read_vocab_file, argv[i + 1]); 679 | if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]); 680 | if ((i = 
ArgPos((char *)"-binary", argc, argv)) > 0) binary = atoi(argv[i + 1]); 681 | if ((i = ArgPos((char *)"-cbow", argc, argv)) > 0) cbow = atoi(argv[i + 1]); 682 | if (cbow) alpha = 0.05; 683 | if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]); 684 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]); 685 | if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]); 686 | if ((i = ArgPos((char *)"-sample", argc, argv)) > 0) sample = atof(argv[i + 1]); 687 | if ((i = ArgPos((char *)"-hs", argc, argv)) > 0) hs = atoi(argv[i + 1]); 688 | if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]); 689 | if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]); 690 | if ((i = ArgPos((char *)"-iter", argc, argv)) > 0) iter = atoi(argv[i + 1]); 691 | if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]); 692 | if ((i = ArgPos((char *)"-classes", argc, argv)) > 0) classes = atoi(argv[i + 1]); 693 | vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word)); 694 | vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int)); 695 | expTable = (real *)malloc((EXP_TABLE_SIZE + 1) * sizeof(real)); 696 | for (i = 0; i < EXP_TABLE_SIZE; i++) { 697 | expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table 698 | expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1) 699 | } 700 | TrainModel(); 701 | return 0; 702 | } 703 | -------------------------------------------------------------------------------- /train/train_bert.sh: -------------------------------------------------------------------------------- 1 | 2 | python ./train_bert_ner.py --data_dir=data/bid_train_data \ 3 | --bert_config_file=./pretrainmodel/bert_config.json \ 4 | --init_checkpoint=./pretrainmodel/bert_model.ckpt \ 5 | --vocab_file=vocab.txt \ 6 | --output_dir=./output/all_bid_result_dir/ --do_predict --do_export 7 | 8 | -------------------------------------------------------------------------------- /train/train_bert_ner.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | 4 | """ 5 | Copyright 2018 The Google AI Language Team Authors. 6 | BASED ON Google_BERT. 7 | @Author:zhoukaiyin 8 | Adjust code for chinese ner 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import collections 15 | import os 16 | from bert import modeling 17 | from bert import optimization 18 | from bert import tokenization 19 | import tensorflow as tf 20 | from sklearn.metrics import f1_score, precision_score, recall_score 21 | from tensorflow.python.ops import math_ops 22 | from train import tf_metrics 23 | import pickle 24 | 25 | flags = tf.flags 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | flags.DEFINE_string( 30 | "data_dir", None, 31 | "The input datadir.", 32 | ) 33 | 34 | flags.DEFINE_string( 35 | "bert_config_file", None, 36 | "The config json file corresponding to the pre-trained BERT model." 37 | ) 38 | 39 | flags.DEFINE_string( 40 | "task_name", "NER", "The name of the task to train." 41 | ) 42 | flags.DEFINE_bool("crf", False, "use crf!") 43 | flags.DEFINE_string( 44 | "output_dir", None, 45 | "The output directory where the model checkpoints will be written." 
46 | ) 47 | 48 | ## Other parameters 49 | flags.DEFINE_string( 50 | "init_checkpoint", None, 51 | "Initial checkpoint (usually from a pre-trained BERT model)." 52 | ) 53 | 54 | flags.DEFINE_bool( 55 | "do_lower_case", True, 56 | "Whether to lower case the input text." 57 | ) 58 | 59 | flags.DEFINE_integer( 60 | "max_seq_length", 256, 61 | "The maximum total input sequence length after WordPiece tokenization." 62 | ) 63 | 64 | flags.DEFINE_bool( 65 | "do_train", True, 66 | "Whether to run training." 67 | ) 68 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 69 | 70 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 71 | 72 | flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on the test set.") 73 | 74 | flags.DEFINE_bool('do_export', False, 'Export model ') 75 | 76 | flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.") 77 | 78 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 79 | 80 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 81 | 82 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 83 | 84 | flags.DEFINE_float("num_train_epochs", 4.0, "Total number of training epochs to perform.") 85 | 86 | flags.DEFINE_float( 87 | "warmup_proportion", 0.1, 88 | "Proportion of training to perform linear learning rate warmup for. " 89 | "E.g., 0.1 = 10% of training.") 90 | 91 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 92 | "How often to save the model checkpoint.") 93 | 94 | flags.DEFINE_integer("iterations_per_loop", 1000, 95 | "How many steps to make in each estimator call.") 96 | 97 | flags.DEFINE_string("vocab_file", None, 98 | "The vocabulary file that the BERT model was trained on.") 99 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 100 | flags.DEFINE_integer( 101 | "num_tpu_cores", 8, 102 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 103 | 104 | 105 | class InputExample(object): 106 | """A single training/test example for simple sequence classification.""" 107 | 108 | def __init__(self, guid, text, label=None): 109 | """Constructs a InputExample. 110 | 111 | Args: 112 | guid: Unique id for the example. 113 | text_a: string. The untokenized text of the first sequence. For single 114 | sequence tasks, only this sequence must be specified. 115 | label: (Optional) string. The label of the example. This should be 116 | specified for train and dev examples, but not for test examples. 
117 | """ 118 | self.guid = guid 119 | self.text = text 120 | self.label = label 121 | 122 | 123 | class InputFeatures(object): 124 | """A single set of features of data.""" 125 | 126 | def __init__(self, input_ids, input_mask, segment_ids, label_ids, ): 127 | self.input_ids = input_ids 128 | self.input_mask = input_mask 129 | self.segment_ids = segment_ids 130 | self.label_ids = label_ids 131 | # self.label_mask = label_mask 132 | 133 | 134 | class DataProcessor(object): 135 | """Base class for data converters for sequence classification data sets.""" 136 | 137 | def get_train_examples(self, data_dir): 138 | """Gets a collection of `InputExample`s for the train set.""" 139 | raise NotImplementedError() 140 | 141 | def get_dev_examples(self, data_dir): 142 | """Gets a collection of `InputExample`s for the dev set.""" 143 | raise NotImplementedError() 144 | 145 | def get_labels(self): 146 | """Gets the list of labels for this data set.""" 147 | raise NotImplementedError() 148 | 149 | @classmethod 150 | def _read_data(cls, input_file): 151 | """Reads a BIO data.""" 152 | with open(input_file) as f: 153 | lines = [] 154 | words = [] 155 | labels = [] 156 | for line in f: 157 | contends = line.strip() 158 | word = line.strip().split(' ')[0] 159 | label = line.strip().split(' ')[-1] 160 | if contends.startswith("-DOCSTART-"): 161 | words.append('') 162 | continue 163 | # if len(contends) == 0 and words[-1] == '。': 164 | if len(contends) == 0: 165 | l = ' '.join([label for label in labels if len(label) > 0]) 166 | w = ' '.join([word for word in words if len(word) > 0]) 167 | lines.append([l, w]) 168 | words = [] 169 | labels = [] 170 | continue 171 | words.append(word) 172 | labels.append(label) 173 | return lines 174 | 175 | 176 | class NerProcessor(DataProcessor): 177 | def get_train_examples(self, data_dir): 178 | return self._create_example( 179 | self._read_data(os.path.join(data_dir, "train.txt")), "train" 180 | ) 181 | 182 | def get_dev_examples(self, data_dir): 183 | return self._create_example( 184 | self._read_data(os.path.join(data_dir, "dev.txt")), "dev" 185 | ) 186 | 187 | def get_test_examples(self, data_dir): 188 | return self._create_example( 189 | self._read_data(os.path.join(data_dir, "test.txt")), "test") 190 | 191 | def get_labels(self): 192 | # prevent potential bug for chinese text mixed with english text 193 | # return ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]","[SEP]"] 194 | # labels = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X","[CLS]","[SEP]"] 195 | labels = ['O', "B-COMAPNY", "I-COMAPNY", "B-REAL", "I-REAL", "B-AMOUT", "I-AMOUT", "[CLS]", "[SEP]"] 196 | return labels 197 | 198 | def _create_example(self, lines, set_type): 199 | examples = [] 200 | for (i, line) in enumerate(lines): 201 | guid = "%s-%s" % (set_type, i) 202 | text = tokenization.convert_to_unicode(line[1]) 203 | label = tokenization.convert_to_unicode(line[0]) 204 | examples.append(InputExample(guid=guid, text=text, label=label)) 205 | return examples 206 | 207 | 208 | def write_tokens(tokens, mode): 209 | if mode == "test": 210 | path = os.path.join(FLAGS.output_dir, "token_" + mode + ".txt") 211 | wf = open(path, 'a') 212 | for token in tokens: 213 | if token != "**NULL**": 214 | wf.write(token + '\n') 215 | wf.close() 216 | 217 | 218 | def convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, mode): 219 | textlist = example.text.split(' ') 220 | labellist = example.label.split(' ') 221 | tokens = [] 222 | labels = [] 223 | # 
print(textlist) 224 | for i, word in enumerate(textlist): 225 | token = tokenizer.tokenize(word) 226 | # print(token) 227 | tokens.extend(token) 228 | label_1 = labellist[i] 229 | # print(label_1) 230 | for m in range(len(token)): 231 | if m == 0: 232 | labels.append(label_1) 233 | else: 234 | labels.append("X") 235 | # print(tokens, labels) 236 | # tokens = tokenizer.tokenize(example.text) 237 | if len(tokens) >= max_seq_length - 1: 238 | tokens = tokens[0:(max_seq_length - 2)] 239 | labels = labels[0:(max_seq_length - 2)] 240 | ntokens = [] 241 | segment_ids = [] 242 | label_ids = [] 243 | ntokens.append("[CLS]") 244 | segment_ids.append(0) 245 | # append("O") or append("[CLS]") not sure! 246 | label_ids.append(label_map["[CLS]"]) 247 | for i, token in enumerate(tokens): 248 | ntokens.append(token) 249 | segment_ids.append(0) 250 | label_ids.append(label_map[labels[i]]) 251 | ntokens.append("[SEP]") 252 | segment_ids.append(0) 253 | # append("O") or append("[SEP]") not sure! 254 | label_ids.append(label_map["[SEP]"]) 255 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 256 | input_mask = [1] * len(input_ids) 257 | # label_mask = [1] * len(input_ids) 258 | while len(input_ids) < max_seq_length: 259 | input_ids.append(0) 260 | input_mask.append(0) 261 | segment_ids.append(0) 262 | # we don't concerned about it! 263 | label_ids.append(0) 264 | ntokens.append("**NULL**") 265 | # label_mask.append(0) 266 | # print(len(input_ids)) 267 | assert len(input_ids) == max_seq_length 268 | assert len(input_mask) == max_seq_length 269 | assert len(segment_ids) == max_seq_length 270 | assert len(label_ids) == max_seq_length 271 | # assert len(label_mask) == max_seq_length 272 | 273 | if ex_index < 5: 274 | tf.logging.info("*** Example ***") 275 | tf.logging.info("guid: %s" % (example.guid)) 276 | tf.logging.info("tokens: %s" % " ".join( 277 | [tokenization.printable_text(x) for x in tokens])) 278 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 279 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 280 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 281 | tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids])) 282 | # tf.logging.info("label_mask: %s" % " ".join([str(x) for x in label_mask])) 283 | 284 | feature = InputFeatures( 285 | input_ids=input_ids, 286 | input_mask=input_mask, 287 | segment_ids=segment_ids, 288 | label_ids=label_ids, 289 | # label_mask = label_mask 290 | ) 291 | write_tokens(ntokens, mode) 292 | return feature 293 | 294 | 295 | def filed_based_convert_examples_to_features( 296 | examples, label_list, max_seq_length, tokenizer, output_file, mode=None 297 | ): 298 | if os.path.exists(output_file): 299 | return 300 | label_map = {} 301 | for (i, label) in enumerate(label_list, 1): 302 | label_map[label] = i 303 | with open('./output/label2id.pkl', 'wb') as w: 304 | pickle.dump(label_map, w) 305 | 306 | writer = tf.python_io.TFRecordWriter(output_file) 307 | for (ex_index, example) in enumerate(examples): 308 | if ex_index % 5000 == 0: 309 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 310 | feature = convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, mode) 311 | 312 | def create_int_feature(values): 313 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 314 | return f 315 | 316 | features = collections.OrderedDict() 317 | features["input_ids"] = create_int_feature(feature.input_ids) 318 | 
features["input_mask"] = create_int_feature(feature.input_mask) 319 | features["segment_ids"] = create_int_feature(feature.segment_ids) 320 | features["label_ids"] = create_int_feature(feature.label_ids) 321 | # features["label_mask"] = create_int_feature(feature.label_mask) 322 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 323 | writer.write(tf_example.SerializeToString()) 324 | 325 | 326 | def serving_input_receiver_fn(): 327 | input_str = tf.placeholder(tf.string, name='inputs') 328 | seq_length = 256 329 | name_to_features = { 330 | "input_ids": tf.VarLenFeature(tf.int64), 331 | "input_mask": tf.VarLenFeature(tf.int64), 332 | "segment_ids": tf.VarLenFeature(tf.int64), 333 | # "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 334 | } 335 | 336 | def _decode_record(record, name_to_features): 337 | example = tf.parse_single_example(record, name_to_features) 338 | for name in list(example.keys()): 339 | t = example[name] 340 | t = tf.sparse_tensor_to_dense(t, default_value=0) 341 | t = tf.expand_dims(t, axis=0) 342 | if t.dtype == tf.int64: 343 | t = tf.to_int32(t) 344 | example[name] = t 345 | return example 346 | 347 | features = _decode_record(input_str, name_to_features) 348 | receiver_tensors = {"inputs": input_str} 349 | return tf.estimator.export.ServingInputReceiver(features, receiver_tensors) 350 | 351 | 352 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): 353 | name_to_features = { 354 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 355 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 356 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 357 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64), 358 | # "label_ids":tf.VarLenFeature(tf.int64), 359 | # "label_mask": tf.FixedLenFeature([seq_length], tf.int64), 360 | } 361 | 362 | def _decode_record(record, name_to_features): 363 | example = tf.parse_single_example(record, name_to_features) 364 | for name in list(example.keys()): 365 | t = example[name] 366 | if t.dtype == tf.int64: 367 | t = tf.to_int32(t) 368 | example[name] = t 369 | return example 370 | 371 | def input_fn(params): 372 | batch_size = params["batch_size"] 373 | d = tf.data.TFRecordDataset(input_file) 374 | if is_training: 375 | d = d.repeat() 376 | d = d.shuffle(buffer_size=100) 377 | d = d.apply(tf.contrib.data.map_and_batch( 378 | lambda record: _decode_record(record, name_to_features), 379 | batch_size=batch_size, 380 | drop_remainder=drop_remainder 381 | )) 382 | return d 383 | 384 | return input_fn 385 | 386 | 387 | def create_model(bert_config, is_training, input_ids, input_mask, 388 | segment_ids, num_labels, use_one_hot_embeddings): 389 | model = modeling.BertModel( 390 | config=bert_config, 391 | is_training=is_training, 392 | input_ids=input_ids, 393 | input_mask=input_mask, 394 | token_type_ids=segment_ids, 395 | use_one_hot_embeddings=use_one_hot_embeddings 396 | ) 397 | 398 | output_layer = model.get_sequence_output() 399 | 400 | hidden_size = output_layer.shape[-1].value 401 | 402 | output_weight = tf.get_variable( 403 | "output_weights", [num_labels, hidden_size], 404 | initializer=tf.truncated_normal_initializer(stddev=0.02) 405 | ) 406 | output_bias = tf.get_variable( 407 | "output_bias", [num_labels], initializer=tf.zeros_initializer() 408 | ) 409 | if is_training: 410 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 411 | output_layer = tf.reshape(output_layer, [-1, hidden_size]) 412 | logits = tf.matmul(output_layer, 
output_weight, transpose_b=True) 413 | logits = tf.nn.bias_add(logits, output_bias) 414 | print("logits====>", logits) 415 | return logits 416 | ########################################################################## 417 | 418 | 419 | def crf_loss(logits, labels, mask, num_labels, mask2len): 420 | """ 421 | :param logits: 422 | :param labels: 423 | :param mask2len:each sample's length 424 | :return: 425 | """ 426 | # TODO 427 | with tf.variable_scope("crf_loss"): 428 | trans = tf.get_variable( 429 | "transition", 430 | shape=[num_labels, num_labels], 431 | initializer=tf.contrib.layers.xavier_initializer() 432 | ) 433 | 434 | log_likelihood, transition = tf.contrib.crf.crf_log_likelihood(logits, labels, transition_params=trans, 435 | sequence_lengths=mask2len) 436 | loss = tf.math.reduce_mean(-log_likelihood) 437 | return loss 438 | 439 | 440 | def softmax_loss(logits, labels, num_labels, mask): 441 | print("logits ===>", logits) 442 | print("labels===>", labels) 443 | logits = tf.reshape(logits, [-1, num_labels]) 444 | labels = tf.reshape(labels, [-1]) 445 | mask = tf.cast(mask, dtype=tf.float32) 446 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 447 | loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=one_hot_labels) 448 | loss *= tf.reshape(mask, [-1]) 449 | loss = tf.reduce_sum(loss) 450 | total_size = tf.reduce_sum(mask) 451 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 452 | loss /= total_size 453 | return loss 454 | 455 | 456 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 457 | num_train_steps, num_warmup_steps, use_tpu, 458 | use_one_hot_embeddings): 459 | def model_fn(features, mode, params): 460 | tf.logging.info("*** Features ***") 461 | for name in sorted(features.keys()): 462 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 463 | input_ids = features["input_ids"] 464 | input_mask = features["input_mask"] 465 | segment_ids = features["segment_ids"] 466 | # label_mask = features["label_mask"] 467 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 468 | if mode != tf.estimator.ModeKeys.PREDICT: 469 | label_ids = features["label_ids"] 470 | 471 | logits = create_model( 472 | bert_config, is_training, input_ids, input_mask, segment_ids, 473 | num_labels, use_one_hot_embeddings) 474 | if mode != tf.estimator.ModeKeys.PREDICT: 475 | logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, num_labels]) 476 | else: 477 | shapes = input_ids.get_shape().as_list() 478 | print("input shapes:", shapes) 479 | logits = tf.reshape(logits, [1, -1, num_labels]) 480 | tvars = tf.trainable_variables() 481 | scaffold_fn = None 482 | if init_checkpoint: 483 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 484 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 485 | if use_tpu: 486 | def tpu_scaffold(): 487 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 488 | return tf.train.Scaffold() 489 | 490 | scaffold_fn = tpu_scaffold 491 | else: 492 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 493 | tf.logging.info("**** Trainable Variables ****") 494 | 495 | for var in tvars: 496 | init_string = "" 497 | if var.name in initialized_variable_names: 498 | init_string = ", *INIT_FROM_CKPT*" 499 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 500 | init_string) 501 | output_spec = None 502 | if mode != tf.estimator.ModeKeys.PREDICT: 503 | if FLAGS.crf: 
504 | mask2len = tf.reduce_sum(input_mask, axis=1) 505 | loss = crf_loss(logits, label_ids, input_mask, num_labels, mask2len) 506 | else: 507 | loss = softmax_loss(logits, label_ids, num_labels, input_mask) 508 | if mode == tf.estimator.ModeKeys.TRAIN: 509 | train_op = optimization.create_optimizer( 510 | loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 511 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 512 | mode=mode, 513 | loss=loss, 514 | train_op=train_op, 515 | scaffold_fn=scaffold_fn) 516 | elif mode == tf.estimator.ModeKeys.EVAL: 517 | def metric_fn(label_ids, logits): 518 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 519 | precision = tf_metrics.precision(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro") 520 | recall = tf_metrics.recall(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro") 521 | f = tf_metrics.f1(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro") 522 | return { 523 | "eval_precision": precision, 524 | "eval_recall": recall, 525 | "eval_f": f, 526 | # "eval_loss": loss, 527 | } 528 | 529 | eval_metrics = (metric_fn, [label_ids, logits]) 530 | # eval_metrics = (metric_fn, [label_ids, logits]) 531 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 532 | mode=mode, 533 | loss=loss, 534 | eval_metrics=eval_metrics, 535 | scaffold_fn=scaffold_fn) 536 | else: 537 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 538 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 539 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn 540 | ) 541 | return output_spec 542 | 543 | return model_fn 544 | 545 | 546 | def main(_): 547 | tf.logging.set_verbosity(tf.logging.INFO) 548 | processors = { 549 | "ner": NerProcessor 550 | } 551 | if not FLAGS.do_train and not FLAGS.do_eval: 552 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 553 | 554 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 555 | 556 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 557 | raise ValueError( 558 | "Cannot use sequence length %d because the BERT model " 559 | "was only trained up to sequence length %d" % 560 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 561 | 562 | task_name = FLAGS.task_name.lower() 563 | if task_name not in processors: 564 | raise ValueError("Task not found: %s" % (task_name)) 565 | processor = processors[task_name]() 566 | 567 | label_list = processor.get_labels() 568 | 569 | tokenizer = tokenization.FullTokenizer( 570 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 571 | tpu_cluster_resolver = None 572 | if FLAGS.use_tpu and FLAGS.tpu_name: 573 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 574 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 575 | 576 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 577 | if not os.path.exists(FLAGS.output_dir): 578 | os.mkdir(FLAGS.output_dir) 579 | run_config = tf.contrib.tpu.RunConfig( 580 | cluster=tpu_cluster_resolver, 581 | master=FLAGS.master, 582 | model_dir=FLAGS.output_dir, 583 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 584 | tpu_config=tf.contrib.tpu.TPUConfig( 585 | iterations_per_loop=FLAGS.iterations_per_loop, 586 | num_shards=FLAGS.num_tpu_cores, 587 | per_host_input_for_training=is_per_host)) 588 | 589 | train_examples = None 590 | num_train_steps = None 591 | num_warmup_steps = None 592 | 593 | if FLAGS.do_train: 594 | train_examples = 
processor.get_train_examples(FLAGS.data_dir) 595 | num_train_steps = int( 596 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 597 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 598 | 599 | model_fn = model_fn_builder( 600 | bert_config=bert_config, 601 | num_labels=len(label_list) + 1, 602 | init_checkpoint=FLAGS.init_checkpoint, 603 | learning_rate=FLAGS.learning_rate, 604 | num_train_steps=num_train_steps, 605 | num_warmup_steps=num_warmup_steps, 606 | use_tpu=FLAGS.use_tpu, 607 | use_one_hot_embeddings=FLAGS.use_tpu) 608 | 609 | estimator = tf.contrib.tpu.TPUEstimator( 610 | use_tpu=FLAGS.use_tpu, 611 | model_fn=model_fn, 612 | config=run_config, 613 | train_batch_size=FLAGS.train_batch_size, 614 | eval_batch_size=FLAGS.eval_batch_size, 615 | predict_batch_size=FLAGS.predict_batch_size) 616 | 617 | if FLAGS.do_train: 618 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 619 | filed_based_convert_examples_to_features( 620 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 621 | tf.logging.info("***** Running training *****") 622 | tf.logging.info(" Num examples = %d", len(train_examples)) 623 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 624 | tf.logging.info(" Num steps = %d", num_train_steps) 625 | train_input_fn = file_based_input_fn_builder( 626 | input_file=train_file, 627 | seq_length=FLAGS.max_seq_length, 628 | is_training=True, 629 | drop_remainder=True) 630 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 631 | if FLAGS.do_eval: 632 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 633 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 634 | filed_based_convert_examples_to_features( 635 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 636 | 637 | tf.logging.info("***** Running evaluation *****") 638 | tf.logging.info(" Num examples = %d", len(eval_examples)) 639 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 640 | eval_steps = None 641 | if FLAGS.use_tpu: 642 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) 643 | eval_drop_remainder = True if FLAGS.use_tpu else False 644 | eval_input_fn = file_based_input_fn_builder( 645 | input_file=eval_file, 646 | seq_length=FLAGS.max_seq_length, 647 | is_training=False, 648 | drop_remainder=eval_drop_remainder) 649 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 650 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 651 | with open(output_eval_file, "w") as writer: 652 | tf.logging.info("***** Eval results *****") 653 | for key in sorted(result.keys()): 654 | tf.logging.info(" %s = %s", key, str(result[key])) 655 | writer.write("%s = %s\n" % (key, str(result[key]))) 656 | if FLAGS.do_predict: 657 | token_path = os.path.join(FLAGS.output_dir, "token_test.txt") 658 | with open('./output/label2id.pkl', 'rb') as rf: 659 | label2id = pickle.load(rf) 660 | id2label = {value: key for key, value in label2id.items()} 661 | if os.path.exists(token_path): 662 | os.remove(token_path) 663 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 664 | 665 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 666 | filed_based_convert_examples_to_features(predict_examples, label_list, 667 | FLAGS.max_seq_length, tokenizer, 668 | predict_file, mode="test") 669 | 670 | tf.logging.info("***** Running prediction*****") 671 | tf.logging.info(" Num examples = %d", len(predict_examples)) 672 | 
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 673 | if FLAGS.use_tpu: 674 | # Warning: According to tpu_estimator.py Prediction on TPU is an 675 | # experimental feature and hence not supported here 676 | raise ValueError("Prediction in TPU not supported") 677 | predict_drop_remainder = True if FLAGS.use_tpu else False 678 | predict_input_fn = file_based_input_fn_builder( 679 | input_file=predict_file, 680 | seq_length=FLAGS.max_seq_length, 681 | is_training=False, 682 | drop_remainder=predict_drop_remainder) 683 | 684 | result = estimator.predict(input_fn=predict_input_fn) 685 | output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt") 686 | with open(output_predict_file, 'w') as writer: 687 | for prediction in result: 688 | output_line = "\n".join(id2label[id] for id in prediction if id != 0) + "\n" 689 | writer.write(output_line) 690 | if FLAGS.do_export and FLAGS.do_predict: 691 | estimator.export_saved_model('./export_models', serving_input_receiver_fn) 692 | 693 | 694 | if __name__ == "__main__": 695 | flags.mark_flag_as_required("data_dir") 696 | flags.mark_flag_as_required("task_name") 697 | flags.mark_flag_as_required("vocab_file") 698 | flags.mark_flag_as_required("bert_config_file") 699 | flags.mark_flag_as_required("output_dir") 700 | tf.app.run() 701 | 702 | -------------------------------------------------------------------------------- /train/word2vec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding=utf-8-*- 3 | 4 | import numpy as np 5 | 6 | class Word2vec(object): 7 | def __init__(self): 8 | self.wv = {} 9 | 10 | 11 | def load_w2v_array(self, path, id_to_word, is_binary=False): 12 | """ 13 | 14 | :param path: 15 | :param vocab_index2word: vocab index to word 16 | :param is_binary: 17 | :return: 18 | """ 19 | 20 | if not is_binary: 21 | f = open(path, errors="ignore") 22 | m, n = f.readline().split() 23 | dim = int(n) 24 | 25 | print("%s words dim : %s"% (m, n )) 26 | for i, line in enumerate(f): 27 | line = line.strip("\n").strip().split(" ") 28 | word = line[0] 29 | vec =[float(v) for v in line[1:]] 30 | if len(vec)!= dim: 31 | continue 32 | 33 | self.wv[word] = vec 34 | 35 | vocab_size = len(id_to_word) 36 | embedding = [] 37 | 38 | bound = np.sqrt(6.0) / np.sqrt(vocab_size) 39 | word2vec_oov_count = 0 40 | 41 | for i in range(vocab_size): 42 | word = id_to_word.get(i) 43 | if word in self.wv: 44 | embedding.append(self.wv.get(word)) 45 | else: 46 | # todo 随机赋值为何? 47 | word2vec_oov_count += 1 48 | embedding.append(np.random.uniform(-bound, bound, dim)); 49 | 50 | print("word2vec oov count: %d"%(word2vec_oov_count,)) 51 | return np.array(embedding) 52 | 53 | 54 | --------------------------------------------------------------------------------