├── Extra ├── Bi-LSTM_CNN.html ├── Bi_LSTM_CNN.jpg └── Bi_LSTM_CNN_CRF.jpg ├── LICENSE.txt ├── PRF_Score.py ├── README.md ├── preprocess.py ├── sycws ├── __init__.py ├── data_iterator.py ├── indices.txt ├── main_body.py ├── model.py ├── model_helper.py ├── prf_script.py └── sycws.py └── third_party ├── compile_w2v.sh ├── word2vec └── word2vec.c /Extra/Bi-LSTM_CNN.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Bi-LSTM_CNN_CRF 6 | 7 | 8 |
9 | 10 | 11 | -------------------------------------------------------------------------------- /Extra/Bi_LSTM_CNN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/Extra/Bi_LSTM_CNN.jpg -------------------------------------------------------------------------------- /Extra/Bi_LSTM_CNN_CRF.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/Extra/Bi_LSTM_CNN_CRF.jpg -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /PRF_Score.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | # Perceptron word segment for Chinese sentences 4 | # 5 | # File usage: 6 | # Results score. 
PRF 7 | # 8 | # Author: 9 | # Synrey Yee 10 | 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import codecs 15 | import sys 16 | 17 | e = 0 # wrong words number 18 | c = 0 # correct words number 19 | N = 0 # gold words number 20 | TN = 0 # test words number 21 | 22 | test_file = sys.argv[1] 23 | gold_file = sys.argv[2] 24 | 25 | test_raw = [] 26 | with codecs.open(test_file, 'r', "utf-8") as inpt1: 27 | for line in inpt1: 28 | sent = line.strip().split() 29 | if sent: 30 | test_raw.append(sent) 31 | 32 | gold_raw = [] 33 | with codecs.open(gold_file, 'r', "utf-8") as inpt2: 34 | for line in inpt2: 35 | sent = line.strip().split() 36 | if sent: 37 | gold_raw.append(sent) 38 | N += len(sent) 39 | 40 | for i, gold_sent in enumerate(gold_raw): 41 | test_sent = test_raw[i] 42 | 43 | ig = 0 44 | it = 0 45 | glen = len(gold_sent) 46 | tlen = len(test_sent) 47 | while True: 48 | if ig >= glen or it >= tlen: 49 | break 50 | 51 | gword = gold_sent[ig] 52 | tword = test_sent[it] 53 | if gword == tword: 54 | c += 1 55 | else: 56 | lg = len(gword) 57 | lt = len(tword) 58 | while lg != lt: 59 | try: 60 | if lg < lt: 61 | ig += 1 62 | gword = gold_sent[ig] 63 | lg += len(gword) 64 | else: 65 | it += 1 66 | tword = test_sent[it] 67 | lt += len(tword) 68 | except Exception as e: 69 | # pdb.set_trace() 70 | print ("Line: %d" % (i + 1)) 71 | print ("\nIt is the user's responsibility that a sentence in must", end = ' ') 72 | print ("have the SAME LENGTH with its corresponding sentence in .\n") 73 | raise e 74 | 75 | ig += 1 76 | it += 1 77 | 78 | TN += len(test_sent) 79 | 80 | e = TN - c 81 | precision = c / TN 82 | recall = c / N 83 | F = 2 * precision * recall / (precision + recall) 84 | error_rate = e / N 85 | 86 | print ("Correct words: %d"%c) 87 | print ("Error words: %d"%e) 88 | print ("Gold words: %d\n"%N) 89 | print ("precision: %f"%precision) 90 | print ("recall: %f"%recall) 91 | print ("F-Value: %f"%F) 92 | print ("error_rate: %f"%error_rate) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSTM-(CNN)-CRF for CWS 2 | [![Python 2&3](https://img.shields.io/badge/python-2&3-brightgreen.svg)](https://www.python.org/) 3 | [![Tensorflow-1.7](https://img.shields.io/badge/tensorflow-1.7-orange.svg)](https://www.tensorflow.org/)
4 | Bi-LSTM+CNN+CRF for Chinese word segmentation.

5 | The **new version** has arrived. However, the old version is still available on another branch. 6 | 7 | ## Usage 8 | ***What's new?*** 9 | * The new system is organized **more cleanly**; 10 | * The CNN model has been tweaked; 11 | * The limit on maximum sentence length has been removed, although you can still set one; 12 | * Gradient clipping has been added; 13 | * Pre-training is optional (you decide whether to use the pretrained embeddings), although I did not see a non-trivial gain from it in my experiments; 14 | * The system can save the best model during training, scored by F-value. 15 | ### Command Step by Step 16 | * Preprocessing
17 | Generates the training files from corpora such as [**People 2014**](http://www.all-terms.com/bbs/thread-7977-1-1.html) and [**icwb2-data**](http://sighan.cs.uchicago.edu/bakeoff2005/). See the source code or run *python preprocess.py -h* for more details.
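(For intuition: the preprocessor turns every segmented sentence into one line of characters plus one line of per-character tags, using the S/B/M/E scheme encoded as 0/1/2/3 — see *analyze_line* in preprocess.py. The snippet below is only an illustrative sketch of that mapping, with a made-up helper name; it is not part of the repository.)

```python
# -*- coding: utf-8 -*-
# Sketch of the tagging scheme used by preprocess.analyze_line:
# S(ingle) = 0, B(egin) = 1, M(iddle) = 2, E(nd) = 3, one tag per character.
def to_chars_and_tags(words):
    chars, tags = [], []
    for word in words:
        if len(word) == 1:
            chars.append(word)
            tags.append(u'0')  # a single-character word is tagged S
        else:
            chars.extend(word)
            tags.extend([u'1'] + [u'2'] * (len(word) - 2) + [u'3'])  # B, M..., E
    return chars, tags

# to_chars_and_tags([u"北京", u"是", u"首都"])
# -> ([u'北', u'京', u'是', u'首', u'都'], [u'1', u'3', u'0', u'1', u'3'])
```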
18 | 19 | For example, for the *People* data, simply use the default arguments (the input file is just *--all_corpora*; the others are output files).
20 | 21 | For the icwb2-data, such as PKU (the input files are *--all_corpora* and *--gold_file*):
22 | *python3 preprocess.py --all_corpora /home/synrey/data/icwb2-data/training/pku_training.utf8 --vob_path /home/synrey/data/icwb2-data/data-pku/vocab.txt --char_file /home/synrey/data/icwb2-data/data-pku/chars.txt --train_file_pre /home/synrey/data/icwb2-data/data-pku/train --eval_file_pre /home/synrey/data/icwb2-data/data-pku/eval --gold_file /home/synrey/data/icwb2-data/gold/pku_test_gold.utf8 --is_people False --word_freq 2* 23 | 24 | * Pretraining
25 | You may first need to run third_party/compile_w2v.sh to compile word2vec.c.
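(If you prefer to compile it by hand instead, a plain gcc build along the lines of *gcc third_party/word2vec.c -o third_party/word2vec -lm -pthread -O3* should work; compile_w2v.sh may use additional flags, so treat this as a rough equivalent rather than the script's exact contents.)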
26 | For the PKU corpus:
27 | *./third_party/word2vec -train /home/synrey/data/icwb2-data/data-pku/chars.txt -output /home/synrey/data/icwb2-data/data-pku/char_vec.txt -size 100 -sample 1e-4 -negative 0 -hs 1 -min-count 2* 28 | 29 | For the People corpus:
30 | *./third_party/word2vec -train /home/synrey/data/cws-v2-data/chars.txt -output /home/synrey/data/cws-v2-data/char_vec.txt -size 100 -sample 1e-4 -negative 0 -hs 1 -min-count 3* 31 | 32 | * Training
33 | For example:
34 | 35 | *python3 -m sycws.sycws --train_prefix /home/synrey/data/cws-v2-data/train --eval_prefix /home/synrey/data/cws-v2-data/eval --vocab_file /home/synrey/data/cws-v2-data/vocab.txt --out_dir /home/synrey/data/cws-v2-data/model --model CNN-CRF* 36 | 37 | If you want to use the pretrained embeddings, set the argument **--embed_file** to the path of your embeddings, such as *--embed_file /home/synrey/data/cws-v2-data/char_vec.txt*
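(For reference: *model_helper._load_embed_txt* expects the embeddings in word2vec's plain-text layout — a header line giving the number of vectors and their size, then one character followed by its floats per line — which is what the pretraining commands above produce. The numbers below are made up, purely to illustrate the shape of the file:)

```
4532 100
unk 0.062 -0.118 0.274 ...   (one character, then its 100 floats, per line)
的 0.231 0.004 -0.177 ...
```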
38 | 39 | See the source code for the configuration of the other arguments. The system should perform well with the default parameters, but you may of course try other settings. 40 | 41 | ## About the Models 42 | ### Bi-LSTM-SL-CRF 43 | See [Guillaume Lample, Miguel Ballesteros, Sandeep Subramanian, Kazuya Kawakami, and Chris Dyer. Neural Architectures for Named Entity Recognition. In Proc. NAACL. 2016.](http://www.aclweb.org/anthology/N16-1030)

44 | In this model, a *single layer* (SL) sits between the Bi-LSTM and the CRF. 45 | 46 | ### Bi-LSTM-CNN-CRF 47 | See [here](http://htmlpreview.github.io/?https://github.com/MeteorYee/LSTM-CNN-CWS/blob/master/Extra/Bi-LSTM_CNN.html).
48 | Here the single layer between the Bi-LSTM and the CRF is replaced by a CNN layer. 49 | 50 | ### Comparison 51 | Experiments on the [**People 2014**](http://www.all-terms.com/bbs/thread-7977-1-1.html) corpus. 52 | 53 | | Models | Bi-LSTM-SL-CRF | Bi-LSTM-CNN-CRF | 54 | | :-----------: | :--------------: | :---------------: | 55 | | Precision | 96.25% | 96.30% | 56 | | Recall | 95.34% | 95.70% | 57 | | F-value | 95.79% | **96.00%** | 58 | 59 | ## Segmentation 60 | * Inference
61 | For example, to use the **Bi-LSTM-CNN-CRF** model for decoding:
62 | 63 | *python3 -m sycws.sycws --vocab_file /home/synrey/data/cws-v2-data/vocab.txt --out_dir /home/synrey/data/cws-v2-data/model/best_Fvalue --inference_input_file /home/synrey/data/cws-v2-data/test.txt --inference_output_file /home/synrey/data/cws-v2-data/result.txt* 64 | 65 | Set *--model CRF* to use the **Bi-LSTM-SL-CRF** model for inference. 66 | Note that even if you use pretrained embeddings, the inference command stays the same. 67 | 68 | * PRF Scoring
69 | 70 | *python3 PRF_Score.py * 71 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/22/2018 6 | # 7 | # Description: preprocessing for people 2014 corpora 8 | # 9 | # Last Modified at: 05/20/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import print_function 30 | 31 | from collections import defaultdict 32 | 33 | import codecs 34 | import argparse 35 | import os 36 | 37 | 38 | NE_LEFT = u'[' 39 | NE_RIGHT = u']' 40 | DIVIDER = u'/' 41 | SPACE = u' ' 42 | UNK = u"unk" 43 | 44 | WORD_S = u'0' 45 | WORD_B = u'1' 46 | WORD_M = u'2' 47 | WORD_E = u'3' 48 | 49 | 50 | def clean_sentence(line_list): 51 | new_line_list = [] 52 | for token in line_list: 53 | div_id = token.rfind(DIVIDER) 54 | 55 | if div_id < 1: 56 | # the div_id shouldn't be lower than 1 57 | # if it does, give up the word 58 | continue 59 | 60 | word = token[ : div_id] 61 | tag = token[(div_id + 1) : ] 62 | 63 | if word[0] == NE_LEFT and len(word) > 1: 64 | new_line_list.append(word[1 : ]) 65 | 66 | elif word[-1] == NE_RIGHT and len(word) > 1: 67 | div_id = word.rfind(DIVIDER) 68 | if div_id < 1: 69 | new_line_list.append(word[ : (len(word)-1)]) 70 | else: 71 | new_line_list.append(word[ : div_id]) 72 | 73 | else: 74 | new_line_list.append(word) 75 | 76 | return new_line_list 77 | 78 | 79 | def write_line(line_list, outstream, sep = SPACE): 80 | line = sep.join(line_list) 81 | outstream.write(line + u'\n') 82 | 83 | 84 | def analyze_line(line_list, vob_dict): 85 | char_list = [] 86 | label_list = [] 87 | 88 | for word in line_list: 89 | length = len(word) 90 | if length == 1: 91 | char_list.append(word) 92 | label_list.append(WORD_S) 93 | vob_dict[word] += 1 94 | 95 | else: 96 | for pos, char in enumerate(word): 97 | if pos == 0: 98 | label_list.append(WORD_B) 99 | elif pos == (length - 1): 100 | label_list.append(WORD_E) 101 | else: 102 | label_list.append(WORD_M) 103 | 104 | char_list.append(char) 105 | vob_dict[char] += 1 106 | 107 | assert len(char_list) == len(label_list) 108 | return char_list, label_list 109 | 110 | 111 | def generate_files(corpora, vob_path, char_file, train_word_file, 112 | train_label_file, eval_word_file, eval_label_file, eval_gold_file, 113 | test_file, gold_file, step, freq, max_len): 114 | 115 | inp = codecs.open(corpora, 'r', "utf-8") 116 | 117 | tr_wd_wr = codecs.open(train_word_file, 'w', "utf-8") 118 | tr_lb_wr = codecs.open(train_label_file, 'w', "utf-8") 119 | ev_wd_wr = codecs.open(eval_word_file, 'w', "utf-8") 120 | ev_lb_wr = codecs.open(eval_label_file, 'w', 
"utf-8") 121 | 122 | ev_gold_wr = codecs.open(eval_gold_file, 'w', "utf-8") 123 | test_wr = codecs.open(test_file, 'w', "utf-8") 124 | gold_wr = codecs.open(gold_file, 'w', "utf-8") 125 | 126 | dump_cnt = 0 127 | vob_dict = defaultdict(int) 128 | isEval = True 129 | 130 | with inp, tr_wd_wr, tr_lb_wr, ev_wd_wr, ev_lb_wr, ev_gold_wr, test_wr, gold_wr: 131 | for ind, line in enumerate(inp): 132 | line_list = line.strip().split() 133 | if len(line_list) > max_len: 134 | dump_cnt += 1 135 | continue 136 | 137 | cleaned_line = clean_sentence(line_list) 138 | if not cleaned_line: 139 | dump_cnt += 1 140 | continue 141 | 142 | char_list, label_list = analyze_line(cleaned_line, vob_dict) 143 | 144 | if ind % step == 0: 145 | if isEval: 146 | write_line(char_list, ev_wd_wr) 147 | write_line(label_list, ev_lb_wr) 148 | write_line(cleaned_line, ev_gold_wr) 149 | isEval = False 150 | else: 151 | write_line(cleaned_line, test_wr, sep = u'') 152 | write_line(cleaned_line, gold_wr) 153 | isEval = True 154 | else: 155 | write_line(char_list, tr_wd_wr) 156 | write_line(label_list, tr_lb_wr) 157 | 158 | inp = codecs.open(corpora, 'r', "utf-8") 159 | ch_wr = codecs.open(char_file, 'w', "utf-8") 160 | with inp, ch_wr: 161 | for line in inp: 162 | line_list = line.strip().split() 163 | if len(line_list) > max_len: 164 | continue 165 | 166 | cleaned_line = clean_sentence(line_list) 167 | if not cleaned_line: 168 | continue 169 | char_list = [] 170 | for phr in cleaned_line: 171 | for ch in phr: 172 | if vob_dict[ch] < freq: 173 | char_list.append(UNK) 174 | else: 175 | char_list.append(ch) 176 | 177 | write_line(char_list, ch_wr) 178 | 179 | word_cnt = 0 180 | with codecs.open(vob_path, 'w', "utf-8") as vob_wr: 181 | vob_wr.write(UNK + u'\n') 182 | for word, fq in vob_dict.items(): 183 | if fq >= freq: 184 | vob_wr.write(word + u'\n') 185 | word_cnt += 1 186 | 187 | print("Finished, give up %d sentences." 
% dump_cnt) 188 | print("Select %d chars from the original %d chars" % (word_cnt, len(vob_dict))) 189 | 190 | 191 | # used for people corpora 192 | def people_main(args): 193 | corpora = args.all_corpora 194 | assert os.path.exists(corpora) 195 | 196 | total_line = 0 197 | # count the total number of lines 198 | with open(corpora, 'rb') as inp: 199 | for line in inp: 200 | total_line += 1 201 | 202 | base = 2 * args.line_cnt 203 | assert base < total_line 204 | step = total_line // base 205 | 206 | train_word_file = args.train_file_pre + ".txt" 207 | train_label_file = args.train_file_pre + ".lb" 208 | eval_word_file = args.eval_file_pre + ".txt" 209 | eval_label_file = args.eval_file_pre + ".lb" 210 | 211 | generate_files(corpora, args.vob_path, args.char_file, 212 | train_word_file, train_label_file, eval_word_file, 213 | eval_label_file, args.eval_gold_file, args.test_file, 214 | args.gold_file, step, args.word_freq, args.max_len) 215 | 216 | 217 | def analyze_write(inp, word_writer, label_writer, 218 | vob_dict = defaultdict(int)): 219 | with inp, word_writer, label_writer: 220 | for line in inp: 221 | line_list = line.strip().split() 222 | if len(line_list) < 1: 223 | continue 224 | 225 | char_list, label_list = analyze_line(line_list, vob_dict) 226 | write_line(char_list, word_writer) 227 | write_line(label_list, label_writer) 228 | 229 | 230 | # used for icwb2 data 231 | def icwb_main(args): 232 | corpora = args.all_corpora 233 | assert os.path.exists(corpora) 234 | gold_file = args.gold_file 235 | assert os.path.exists(gold_file) 236 | freq = args.word_freq 237 | 238 | train_word_file = args.train_file_pre + ".txt" 239 | train_label_file = args.train_file_pre + ".lb" 240 | eval_word_file = args.eval_file_pre + ".txt" 241 | eval_label_file = args.eval_file_pre + ".lb" 242 | 243 | train_inp = codecs.open(corpora, 'r', "utf-8") 244 | gold_inp = codecs.open(gold_file, 'r', "utf-8") 245 | 246 | ch_wr = codecs.open(args.char_file, 'w', "utf-8") 247 | tr_wd_wr = codecs.open(train_word_file, 'w', "utf-8") 248 | tr_lb_wr = codecs.open(train_label_file, 'w', "utf-8") 249 | ev_wd_wr = codecs.open(eval_word_file, 'w', "utf-8") 250 | ev_lb_wr = codecs.open(eval_label_file, 'w', "utf-8") 251 | 252 | vob_dict = defaultdict(int) 253 | analyze_write(train_inp, tr_wd_wr, tr_lb_wr, vob_dict) 254 | analyze_write(gold_inp, ev_wd_wr, ev_lb_wr) 255 | 256 | train_inp = codecs.open(corpora, 'r', "utf-8") 257 | with train_inp, ch_wr: 258 | for line in train_inp: 259 | phrases = line.strip().split() 260 | char_list = [] 261 | for phr in phrases: 262 | for ch in phr: 263 | if vob_dict[ch] < freq: 264 | char_list.append(UNK) 265 | else: 266 | char_list.append(ch) 267 | 268 | write_line(char_list, ch_wr) 269 | 270 | word_cnt = 0 271 | with codecs.open(args.vob_path, 'w', "utf-8") as vob_wr: 272 | vob_wr.write(UNK + u'\n') 273 | for word, fq in vob_dict.items(): 274 | if fq >= freq: 275 | vob_wr.write(word + u'\n') 276 | word_cnt += 1 277 | 278 | print("Finished, handling icwb2 data.") 279 | print("Select %d chars from the original %d chars" % (word_cnt, len(vob_dict))) 280 | 281 | 282 | if __name__ == '__main__': 283 | parser = argparse.ArgumentParser() 284 | parser.register("type", "bool", lambda v: v.lower() == "true") 285 | 286 | # input 287 | parser.add_argument( 288 | "--all_corpora", 289 | type = str, 290 | default = "/home/synrey/data/people2014All.txt", 291 | help = "all the corpora") 292 | 293 | # output 294 | parser.add_argument( 295 | "--vob_path", 296 | type = str, 297 | default = 
"/home/synrey/data/cws-v2-data/vocab.txt", 298 | help = "vocabulary's path") 299 | 300 | parser.add_argument( 301 | "--char_file", 302 | type = str, 303 | default = "/home/synrey/data/cws-v2-data/chars.txt", 304 | help = "the file used for word2vec pretraining") 305 | 306 | parser.add_argument( 307 | "--train_file_pre", 308 | type = str, 309 | default = "/home/synrey/data/cws-v2-data/train", 310 | help = "training file's prefix") 311 | 312 | parser.add_argument( 313 | "--eval_file_pre", 314 | type = str, 315 | default = "/home/synrey/data/cws-v2-data/eval", 316 | help = "eval file's prefix") 317 | 318 | parser.add_argument( 319 | "--eval_gold_file", 320 | type = str, 321 | default = "/home/synrey/data/cws-v2-data/eval_gold.txt", 322 | help = """gold file, used for the evaluation during training, \ 323 | only generated for the 'people' corpus""") 324 | 325 | parser.add_argument( 326 | "--test_file", 327 | type = str, 328 | default = "/home/synrey/data/cws-v2-data/test.txt", 329 | help = "test file, raw sentences") 330 | 331 | parser.add_argument( 332 | "--gold_file", 333 | type = str, 334 | default = "/home/synrey/data/cws-v2-data/gold.txt", 335 | help = "gold file, segmented sentences") 336 | 337 | # parameters 338 | parser.add_argument( 339 | "--word_freq", 340 | type = int, 341 | default = 3, 342 | help = "word frequency") 343 | 344 | parser.add_argument( 345 | "--line_cnt", 346 | type = int, 347 | default = 8000, 348 | help = "the number of lines in eval or test file") 349 | 350 | # NOTE: It is the max length of word sequence, not char. 351 | parser.add_argument( 352 | "--max_len", 353 | type = int, 354 | default = 120, 355 | help = "deprecate the sentences longer than ") 356 | 357 | parser.add_argument( 358 | "--is_people", 359 | type = "bool", 360 | default = True, 361 | help = "Whether it is handling with People corpora") 362 | 363 | args = parser.parse_args() 364 | if args.is_people: 365 | people_main(args) 366 | else: 367 | icwb_main(args) -------------------------------------------------------------------------------- /sycws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/sycws/__init__.py -------------------------------------------------------------------------------- /sycws/data_iterator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/24/2018 6 | # 7 | # Description: The iterator of training and inference data 8 | # 9 | # Last Modified at: 04/03/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 
26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import print_function 30 | 31 | from collections import namedtuple 32 | 33 | import tensorflow as tf 34 | 35 | __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"] 36 | 37 | 38 | class BatchedInput( 39 | namedtuple("BatchedInput", 40 | ("initializer", "text", "label", 41 | "text_raw", "sequence_length"))): 42 | pass 43 | 44 | 45 | def get_infer_iterator(src_dataset, 46 | vocab_table, 47 | index_table, 48 | batch_size, 49 | max_len = None): 50 | 51 | src_dataset = src_dataset.map(lambda src : tf.string_split([src]).values) 52 | 53 | if max_len: 54 | src_dataset = src_dataset.map(lambda src : src[:max_len]) 55 | # Convert the word strings to ids 56 | src_dataset = src_dataset.map( 57 | lambda src : (src, tf.cast(vocab_table.lookup(src), tf.int32))) 58 | # Add in the word counts. 59 | # Add the fake labels, refer to model_helper.py line 160. 60 | src_dataset = src_dataset.map(lambda src_raw, src : (src_raw, src, 61 | tf.cast(index_table.lookup(src_raw), tf.int32), tf.size(src))) 62 | 63 | def batching_func(x): 64 | return x.padded_batch( 65 | batch_size, 66 | # The entry is the source line rows; 67 | # this has unknown-length vectors. The last entry is 68 | # the source row size; this is a scalar. 69 | padded_shapes=( 70 | tf.TensorShape([None]), # src_raw 71 | tf.TensorShape([None]), # src_ids 72 | tf.TensorShape([None]), # fake label ids 73 | tf.TensorShape([])), # src_len 74 | # Pad the source sequences with eos tokens. 75 | # (Though notice we don't generally need to do this since 76 | # later on we will be masking out calculations past the true sequence. 77 | 78 | # padding_values = 0, default value 79 | ) 80 | 81 | batched_dataset = batching_func(src_dataset) 82 | batched_iter = batched_dataset.make_initializable_iterator() 83 | (src_raw, src_ids, lb_ids, src_seq_len) = batched_iter.get_next() 84 | return BatchedInput( 85 | initializer = batched_iter.initializer, 86 | text = src_ids, 87 | label = lb_ids, 88 | text_raw = src_raw, 89 | sequence_length = src_seq_len) 90 | 91 | 92 | def get_iterator(txt_dataset, 93 | lb_dataset, 94 | vocab_table, 95 | index_table, 96 | batch_size, 97 | num_buckets, 98 | max_len = None, 99 | output_buffer_size = None, 100 | num_parallel_calls = 4): 101 | 102 | if not output_buffer_size: 103 | output_buffer_size = batch_size * 1000 104 | txt_lb_dataset = tf.data.Dataset.zip((txt_dataset, lb_dataset)) 105 | txt_lb_dataset = txt_lb_dataset.shuffle(output_buffer_size) 106 | 107 | txt_lb_dataset = txt_lb_dataset.map( 108 | lambda txt, lb : ( 109 | tf.string_split([txt]).values, tf.string_split([lb]).values), 110 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size) 111 | 112 | # Filter zero length input sequences. 113 | txt_lb_dataset = txt_lb_dataset.filter( 114 | lambda txt, lb : tf.logical_and(tf.size(txt) > 0, tf.size(lb) > 0)) 115 | 116 | if max_len: 117 | txt_lb_dataset = txt_lb_dataset.map( 118 | lambda txt, lb : (txt[ : max_len], lb[ : max_len]), 119 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size) 120 | # Convert the word strings to ids. Word strings that are not in the 121 | # vocab get the lookup table's default_value integer. 
122 | 123 | txt_lb_dataset = txt_lb_dataset.map( 124 | lambda txt, lb : ( 125 | tf.cast(vocab_table.lookup(txt), tf.int32), 126 | tf.cast(index_table.lookup(lb), tf.int32)), 127 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size) 128 | 129 | # Add in sequence lengths. 130 | txt_lb_dataset = txt_lb_dataset.map( 131 | lambda txt, lb : ( 132 | txt, lb, tf.size(txt)), 133 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size) 134 | 135 | # Bucket by sequence length (buckets for lengths 0-9, 10-19, ...) 136 | def batching_func(x): 137 | return x.padded_batch( 138 | batch_size, 139 | # The first two entries are the text and label line rows; 140 | # these have unknown-length vectors. The last entry is 141 | # the row sizes; these are scalars. 142 | padded_shapes = ( 143 | tf.TensorShape([None]), # txt 144 | tf.TensorShape([None]), # lb 145 | tf.TensorShape([])), # length 146 | # Pad the text and label sequences with eos tokens. 147 | # (Though notice we don't generally need to do this since 148 | # later on we will be masking out calculations past the true sequence. 149 | 150 | # padding_values = (0, 0, 0) # default values, 151 | # 0 for length--unused though 152 | ) 153 | 154 | if num_buckets > 1: 155 | 156 | def key_func(unused_1, unused_2, seq_len): 157 | # Calculate bucket_width by maximum text sequence length. 158 | # Pairs with length [0, bucket_width) go to bucket 0, length 159 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length 160 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket. 161 | if max_len: 162 | bucket_width = (max_len + num_buckets - 1) // num_buckets 163 | else: 164 | bucket_width = 25 165 | 166 | # Bucket sentence pairs by the length of their text sentence and label 167 | # sentence. 168 | bucket_id = seq_len // bucket_width 169 | return tf.to_int64(tf.minimum(num_buckets, bucket_id)) 170 | 171 | def reduce_func(unused_key, windowed_data): 172 | return batching_func(windowed_data) 173 | 174 | batched_dataset = txt_lb_dataset.apply( 175 | tf.contrib.data.group_by_window( 176 | key_func = key_func, reduce_func = reduce_func, window_size = batch_size)) 177 | 178 | # One batch could have multiple windows, although there is just one window 179 | # in a batch. 180 | 181 | else: 182 | batched_dataset = batching_func(txt_lb_dataset) 183 | batched_iter = batched_dataset.make_initializable_iterator() 184 | (txt_ids, lb_ids, seq_len) = batched_iter.get_next() 185 | return BatchedInput( 186 | initializer = batched_iter.initializer, 187 | text = txt_ids, 188 | label = lb_ids, 189 | text_raw = None, 190 | sequence_length = seq_len) -------------------------------------------------------------------------------- /sycws/indices.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | -------------------------------------------------------------------------------- /sycws/main_body.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/24/2018 6 | # 7 | # Description: Training, evaluation and inference 8 | # 9 | # Last Modified at: 05/21/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 
14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import absolute_import 30 | from __future__ import print_function 31 | from __future__ import division 32 | 33 | from . import model_helper 34 | from . import prf_script 35 | 36 | import tensorflow as tf 37 | import numpy as np 38 | 39 | import time 40 | import os 41 | import codecs 42 | 43 | 44 | TAG_S = 0 45 | TAG_B = 1 46 | TAG_M = 2 47 | TAG_E = 3 48 | 49 | 50 | def train(hparams, model_creator): 51 | out_dir = hparams.out_dir 52 | num_train_steps = hparams.num_train_steps 53 | steps_per_stats = hparams.steps_per_stats 54 | steps_per_external_eval = hparams.steps_per_external_eval 55 | if not steps_per_external_eval: 56 | steps_per_external_eval = 10 * steps_per_stats 57 | 58 | train_model = model_helper.create_train_model(hparams, model_creator) 59 | eval_model = model_helper.create_eval_model(hparams, model_creator) 60 | infer_model = model_helper.create_infer_model(hparams, model_creator) 61 | 62 | eval_txt_file = "%s.%s" % (hparams.eval_prefix, "txt") 63 | eval_lb_file = "%s.%s" % (hparams.eval_prefix, "lb") 64 | eval_iterator_feed_dict = { 65 | eval_model.txt_file_placeholder : eval_txt_file, 66 | eval_model.lb_file_placeholder : eval_lb_file 67 | } 68 | 69 | model_dir = hparams.out_dir 70 | 71 | # TensorFlow model 72 | train_sess = tf.Session(graph = train_model.graph) 73 | eval_sess = tf.Session(graph = eval_model.graph) 74 | infer_sess = tf.Session(graph = infer_model.graph) 75 | 76 | # Read the infer data 77 | with codecs.getreader("utf-8")( 78 | tf.gfile.GFile(eval_txt_file, mode="rb")) as f: 79 | infer_data = f.read().splitlines() 80 | 81 | with train_model.graph.as_default(): 82 | loaded_train_model, global_step = model_helper.create_or_load_model( 83 | train_model.model, model_dir, train_sess, "train", init = True) 84 | 85 | print("First evaluation:") 86 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model, 87 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model, 88 | infer_sess, infer_data, global_step, init = True) 89 | 90 | print("# Initialize train iterator...") 91 | train_sess.run(train_model.iterator.initializer) 92 | 93 | process_time = 0.0 94 | while global_step < num_train_steps: 95 | # train a batch 96 | start_time = time.time() 97 | try: 98 | step_result = loaded_train_model.train(train_sess) 99 | process_time += time.time() - start_time 100 | except tf.errors.OutOfRangeError: 101 | # finish one epoch 102 | print( 103 | "# Finished an epoch, step %d. 
Perform evaluation" % 104 | global_step) 105 | 106 | # Save checkpoint 107 | loaded_train_model.saver.save( 108 | train_sess, 109 | os.path.join(out_dir, "segmentation.ckpt"), 110 | global_step = global_step) 111 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model, 112 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model, 113 | infer_sess, infer_data, global_step, init = False) 114 | 115 | train_sess.run(train_model.iterator.initializer) 116 | continue 117 | 118 | _, train_loss, global_step, batch_size = step_result 119 | if global_step % steps_per_stats == 0: 120 | avg_time = process_time / steps_per_stats 121 | # print loss info 122 | print("[%d][loss]: %f, time per step: %.2fs" % (global_step, 123 | train_loss, avg_time)) 124 | process_time = 0.0 125 | 126 | if global_step % steps_per_external_eval == 0: 127 | # Save checkpoint 128 | loaded_train_model.saver.save( 129 | train_sess, 130 | os.path.join(out_dir, "segmentation.ckpt"), 131 | global_step = global_step) 132 | 133 | print("External Evaluation:") 134 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model, 135 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model, 136 | infer_sess, infer_data, global_step, init = False) 137 | 138 | 139 | def evaluation(eval_model, model_dir, eval_sess, 140 | eval_iterator_feed_dict, init = True): 141 | with eval_model.graph.as_default(): 142 | loaded_eval_model, global_step = model_helper.create_or_load_model( 143 | eval_model.model, model_dir, eval_sess, "eval", init) 144 | 145 | eval_sess.run(eval_model.iterator.initializer, 146 | feed_dict = eval_iterator_feed_dict) 147 | 148 | total_char_cnt = 0 149 | total_right_cnt = 0 150 | total_line = 0 151 | while True: 152 | try: 153 | (batch_char_cnt, batch_right_cnt, batch_size, 154 | batch_lens) = loaded_eval_model.eval(eval_sess) 155 | 156 | for right_cnt, length in zip(batch_right_cnt, batch_lens): 157 | total_right_cnt += np.sum(right_cnt[ : length]) 158 | 159 | total_char_cnt += batch_char_cnt 160 | total_line += batch_size 161 | except tf.errors.OutOfRangeError: 162 | # finish the evaluation 163 | break 164 | 165 | precision = total_right_cnt / total_char_cnt 166 | print("Tagging precision: %.3f, of total %d lines" % (precision, total_line)) 167 | 168 | 169 | def _eval_inference(infer_model, infer_sess, infer_data, model_dir, hparams, init): 170 | with infer_model.graph.as_default(): 171 | loaded_infer_model, global_step = model_helper.create_or_load_model( 172 | infer_model.model, model_dir, infer_sess, "infer", init) 173 | 174 | infer_sess.run( 175 | infer_model.iterator.initializer, 176 | feed_dict = { 177 | infer_model.txt_placeholder: infer_data, 178 | infer_model.batch_size_placeholder: hparams.infer_batch_size 179 | }) 180 | 181 | test_list = [] 182 | while True: 183 | try: 184 | text_raw, decoded_tags, seq_lens = loaded_infer_model.infer(infer_sess) 185 | except tf.errors.OutOfRangeError: 186 | # finish the evaluation 187 | break 188 | _decode_by_function(lambda x : test_list.append(x), text_raw, decoded_tags, seq_lens) 189 | 190 | gold_file = hparams.eval_gold_file 191 | score = prf_script.get_prf_score(test_list, gold_file) 192 | return score 193 | 194 | 195 | def _run_full_eval(hparams, loaded_train_model, train_sess, eval_model, 196 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model, 197 | infer_sess, infer_data, global_step, init): 198 | 199 | evaluation(eval_model, model_dir, eval_sess, 200 | eval_iterator_feed_dict) 201 | score = _eval_inference(infer_model, infer_sess, 
infer_data, 202 | model_dir, hparams, init) 203 | # save the best model 204 | if score > getattr(hparams, "best_Fvalue"): 205 | setattr(hparams, "best_Fvalue", score) 206 | loaded_train_model.saver.save( 207 | train_sess, 208 | os.path.join( 209 | getattr(hparams, "best_Fvalue_dir"), "segmentation.ckpt"), 210 | global_step = global_step) 211 | 212 | 213 | def load_data(inference_input_file): 214 | # Load inference data. 215 | inference_data = [] 216 | with codecs.getreader("utf-8")( 217 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 218 | for line in f: 219 | line = line.strip() 220 | if line: 221 | inference_data.append(u' '.join(list(line))) 222 | 223 | return inference_data 224 | 225 | 226 | def _decode_by_function(writer_function, text_raw, decoded_tags, seq_lens): 227 | assert len(text_raw) == len(decoded_tags) 228 | assert len(seq_lens) == len(decoded_tags) 229 | 230 | for text_line, tags_line, length in zip(text_raw, decoded_tags, seq_lens): 231 | text_line = text_line[ : length] 232 | tags_line = tags_line[ : length] 233 | newline = u"" 234 | 235 | for char, tag in zip(text_line, tags_line): 236 | char = char.decode("utf-8") 237 | if tag == TAG_S or tag == TAG_B: 238 | newline += u' ' + char 239 | else: 240 | newline += char 241 | 242 | newline = newline.strip() 243 | writer_function(newline + u'\n') 244 | 245 | 246 | def inference(ckpt, input_file, trans_file, hparams, model_creator): 247 | infer_model = model_helper.create_infer_model(hparams, model_creator) 248 | 249 | infer_sess = tf.Session(graph = infer_model.graph) 250 | with infer_model.graph.as_default(): 251 | loaded_infer_model = model_helper.load_model(infer_model.model, 252 | ckpt, infer_sess, "infer", init = True) 253 | 254 | # Read data 255 | infer_data = load_data(input_file) 256 | infer_sess.run( 257 | infer_model.iterator.initializer, 258 | feed_dict = { 259 | infer_model.txt_placeholder: infer_data, 260 | infer_model.batch_size_placeholder: hparams.infer_batch_size 261 | }) 262 | 263 | with codecs.getwriter("utf-8")( 264 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f: 265 | while True: 266 | try: 267 | text_raw, decoded_tags, seq_lens = loaded_infer_model.infer(infer_sess) 268 | except tf.errors.OutOfRangeError: 269 | # finish the evaluation 270 | break 271 | _decode_by_function(lambda x : trans_f.write(x), text_raw, decoded_tags, seq_lens) -------------------------------------------------------------------------------- /sycws/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/24/2018 6 | # 7 | # Description: Segmentation models 8 | # 9 | # Last Modified at: 05/21/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 
26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import absolute_import 30 | from __future__ import print_function 31 | 32 | from . import model_helper 33 | from . import data_iterator 34 | 35 | import tensorflow as tf 36 | 37 | 38 | class BasicModel(object): 39 | """Bi-LSTM + single layer + CRF""" 40 | 41 | def __init__(self, hparams, mode, iterator, vocab_table): 42 | assert isinstance(iterator, data_iterator.BatchedInput) 43 | 44 | self.iterator = iterator 45 | self.mode = mode 46 | self.vocab_table = vocab_table 47 | self.vocab_size = hparams.vocab_size 48 | self.time_major = hparams.time_major 49 | 50 | # Initializer 51 | self.initializer = tf.truncated_normal_initializer 52 | 53 | # Embeddings 54 | self.init_embeddings(hparams) 55 | # Because usually the last batch may be less than the batch 56 | # size we set, we should get batch_size dynamically. 57 | self.batch_size = tf.size(self.iterator.sequence_length) 58 | 59 | ## Train graph 60 | loss = self.build_graph(hparams) 61 | 62 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN: 63 | self.train_loss = loss 64 | 65 | elif self.mode == tf.contrib.learn.ModeKeys.EVAL: 66 | self.char_count = tf.reduce_sum(self.iterator.sequence_length) 67 | self.right_count = self._calculate_right() 68 | 69 | elif self.mode == tf.contrib.learn.ModeKeys.INFER: 70 | self.decode_tags = self._decode() 71 | 72 | self.global_step = tf.Variable(0, trainable = False) 73 | params = tf.trainable_variables() 74 | 75 | # Gradients and SGD update operation for training the model. 76 | # Arrage for the embedding vars to appear at the beginning. 77 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN: 78 | self.learning_rate = tf.constant(hparams.learning_rate) 79 | 80 | # Optimizer, SGD 81 | opt = tf.train.AdamOptimizer(self.learning_rate) 82 | 83 | # Gradients 84 | gradients = tf.gradients(self.train_loss, params) 85 | 86 | clipped_grads, grad_norm = tf.clip_by_global_norm( 87 | gradients, hparams.max_gradient_norm) 88 | 89 | # self.grad_norm = grad_norm 90 | self.update = opt.apply_gradients( 91 | zip(clipped_grads, params), global_step = self.global_step) 92 | 93 | # Saver, saves 5 checkpoints by default 94 | self.saver = tf.train.Saver( 95 | tf.global_variables(), max_to_keep = 5) 96 | 97 | # Print trainable variables 98 | print("# Trainable variables") 99 | for param in params: 100 | print(" %s, %s" % (param.name, str(param.get_shape()))) 101 | 102 | 103 | def init_embeddings(self, hparams, dtype = tf.float32): 104 | with tf.variable_scope("embeddings", dtype = dtype) as scope: 105 | if hparams.embed_file: 106 | self.char_embedding = model_helper.create_pretrained_emb_from_txt( 107 | hparams.vocab_file, hparams.embed_file) 108 | else: 109 | self.char_embedding = tf.get_variable( 110 | "char_embedding", [self.vocab_size, hparams.num_units], dtype, 111 | initializer = self.initializer(stddev = hparams.init_std)) 112 | 113 | 114 | def build_graph(self, hparams): 115 | print("# creating %s graph ..." 
% self.mode) 116 | dtype = tf.float32 117 | 118 | with tf.variable_scope("model_body", dtype = dtype): 119 | # build Bi-LSTM 120 | encoder_outputs, encoder_state = self._encode_layer(hparams) 121 | 122 | # middle layer 123 | middle_outputs = self._middle_layer(encoder_outputs, hparams) 124 | self.middle_outputs = middle_outputs 125 | 126 | # Decoder layer 127 | xentropy = self._decode_layer(middle_outputs) 128 | 129 | # Loss 130 | if self.mode != tf.contrib.learn.ModeKeys.INFER: 131 | # get the regularization loss 132 | reg_loss = tf.losses.get_regularization_loss() 133 | loss = tf.reduce_mean(xentropy) + reg_loss 134 | else: 135 | loss = None 136 | return loss 137 | 138 | 139 | def _encode_layer(self, hparams, dtype = tf.float32): 140 | # Bi-LSTM 141 | iterator = self.iterator 142 | 143 | text = iterator.text 144 | # [batch_size, txt_ids] 145 | if self.time_major: 146 | text = tf.transpose(text) 147 | # [txt_ids, batch_size] 148 | 149 | with tf.variable_scope("encoder", dtype = dtype) as scope: 150 | # Look up embedding, emp_inp: [max_time, batch_size, num_units] 151 | encoder_emb_inp = tf.nn.embedding_lookup( 152 | self.char_embedding, text) 153 | 154 | encoder_outputs, encoder_state = ( 155 | self._build_bidirectional_rnn( 156 | inputs = encoder_emb_inp, 157 | sequence_length = iterator.sequence_length, 158 | dtype = dtype, 159 | hparams = hparams)) 160 | 161 | # Encoder_outpus (time_major): [max_time, batch_size, 2*num_units] 162 | return encoder_outputs, encoder_state 163 | 164 | 165 | def _build_bidirectional_rnn(self, inputs, sequence_length, 166 | dtype, hparams): 167 | 168 | num_units = hparams.num_units 169 | mode = self.mode 170 | dropout = hparams.dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 171 | 172 | fw_cell = tf.contrib.rnn.BasicLSTMCell(num_units) 173 | bw_cell = tf.contrib.rnn.BasicLSTMCell(num_units) 174 | 175 | if dropout > 0.0: 176 | fw_cell = tf.contrib.rnn.DropoutWrapper( 177 | cell = fw_cell, input_keep_prob = (1.0 - dropout)) 178 | 179 | bw_cell = tf.contrib.rnn.DropoutWrapper( 180 | cell = bw_cell, input_keep_prob = (1.0 - dropout)) 181 | 182 | bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn( 183 | fw_cell, 184 | bw_cell, 185 | inputs, 186 | dtype = dtype, 187 | sequence_length = sequence_length, 188 | time_major = self.time_major) 189 | 190 | # concatenate the fw and bw outputs 191 | return tf.concat(bi_outputs, -1), bi_state 192 | 193 | 194 | def _single_layer(self, encoder_outputs, hparams, dtype): 195 | num_units = hparams.num_units 196 | num_tags = hparams.num_tags 197 | 198 | with tf.variable_scope('middle', dtype = dtype) as scope: 199 | hidden_W = tf.get_variable( 200 | shape = [num_units * 2, hparams.num_tags], 201 | initializer = self.initializer(stddev = hparams.init_std), 202 | name = "weights", 203 | regularizer = tf.contrib.layers.l2_regularizer(0.001)) 204 | 205 | hidden_b = tf.Variable(tf.zeros([num_tags], name = "bias")) 206 | 207 | encoder_outputs = tf.reshape(encoder_outputs, [-1, (2 * num_units)]) 208 | middle_outputs = tf.add(tf.matmul(encoder_outputs, hidden_W), hidden_b) 209 | if self.time_major: 210 | middle_outputs = tf.reshape(middle_outputs, 211 | [-1, self.batch_size, num_tags]) 212 | middle_outputs = tf.transpose(middle_outputs, [1, 0, 2]) 213 | else: 214 | middle_outputs = tf.reshape(middle_outputs, 215 | [self.batch_size, -1, num_tags]) 216 | # [batch_size, max_time, num_tags] 217 | return middle_outputs 218 | 219 | 220 | def _cnn_layer(self, encoder_outputs, hparams, dtype): 221 | num_units = hparams.num_units 
222 | # CNN 223 | with tf.variable_scope('middle', dtype = dtype) as scope: 224 | cfilter = tf.get_variable( 225 | "cfilter", 226 | shape = [1, 2, 2 * num_units, hparams.num_tags], 227 | regularizer = tf.contrib.layers.l2_regularizer(0.0001), 228 | initializer = self.initializer(stddev = hparams.filter_init_std), 229 | dtype = tf.float32) 230 | 231 | return model_helper.create_cnn_layer(encoder_outputs, self.time_major, 232 | self.batch_size, num_units, cfilter) 233 | 234 | 235 | def _middle_layer(self, encoder_outputs, hparams, dtype = tf.float32): 236 | # single layer 237 | return self._single_layer(encoder_outputs, hparams, dtype) 238 | 239 | 240 | def _decode_layer(self, middle_outputs, dtype = tf.float32): 241 | # CRF 242 | with tf.variable_scope('decoder', dtype = dtype) as scope: 243 | log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood( 244 | middle_outputs, self.iterator.label, self.iterator.sequence_length) 245 | 246 | self.trans_params = trans_params 247 | return -log_likelihood 248 | 249 | 250 | def _decode(self): 251 | decode_tags, _ = tf.contrib.crf.crf_decode(self.middle_outputs, 252 | self.trans_params, self.iterator.sequence_length) 253 | # [batch_size, max_time] 254 | 255 | return decode_tags 256 | 257 | 258 | def _calculate_right(self): 259 | decode_tags = self._decode() 260 | sign_tensor = tf.equal(decode_tags, self.iterator.label) 261 | right_count = tf.cast(sign_tensor, tf.int32) 262 | 263 | return right_count 264 | 265 | 266 | def train(self, sess): 267 | assert self.mode == tf.contrib.learn.ModeKeys.TRAIN 268 | return sess.run([self.update, 269 | self.train_loss, 270 | self.global_step, 271 | self.batch_size]) 272 | 273 | 274 | def eval(self, sess): 275 | assert self.mode == tf.contrib.learn.ModeKeys.EVAL 276 | return sess.run([self.char_count, 277 | self.right_count, 278 | self.batch_size, 279 | self.iterator.sequence_length]) 280 | 281 | 282 | def infer(self, sess): 283 | assert self.mode == tf.contrib.learn.ModeKeys.INFER 284 | return sess.run([self.iterator.text_raw, 285 | self.decode_tags, 286 | self.iterator.sequence_length]) 287 | 288 | 289 | class CnnCrfModel(BasicModel): 290 | """Bi-LSTM + CNN + CRF""" 291 | 292 | def _middle_layer(self, encoder_outputs, hparams, dtype = tf.float32): 293 | # CNN 294 | return self._cnn_layer(encoder_outputs, hparams, dtype) -------------------------------------------------------------------------------- /sycws/model_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/24/2018 6 | # 7 | # Description: Help functions for the implementation of the models 8 | # 9 | # Last Modified at: 05/20/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 
26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import absolute_import 30 | from __future__ import print_function 31 | 32 | from collections import namedtuple 33 | from tensorflow.python.ops import lookup_ops 34 | 35 | from . import data_iterator 36 | 37 | import tensorflow as tf 38 | import numpy as np 39 | import time 40 | import codecs 41 | 42 | __all__ = [ 43 | "create_train_model", "create_eval_model", 44 | "create_infer_model", "create_cnn_layer", 45 | "create_or_load_model", "load_model", 46 | "create_pretrained_emb_from_txt" 47 | ] 48 | 49 | 50 | UNK_ID = 0 51 | 52 | 53 | # subclass to purify the __dict__ 54 | class TrainModel( 55 | namedtuple("TrainModel", ("graph", "model", 56 | "iterator"))): 57 | pass 58 | 59 | 60 | def create_train_model(hparams, model_creator): 61 | txt_file = "%s.%s" % (hparams.train_prefix, "txt") 62 | lb_file = "%s.%s" % (hparams.train_prefix, "lb") 63 | vocab_file = hparams.vocab_file 64 | index_file = hparams.index_file 65 | 66 | graph = tf.Graph() 67 | 68 | with graph.as_default(), tf.container("train"): 69 | vocab_table = lookup_ops.index_table_from_file( 70 | vocab_file, default_value = UNK_ID) 71 | # for the labels 72 | index_table = lookup_ops.index_table_from_file( 73 | index_file, default_value = 0) 74 | 75 | txt_dataset = tf.data.TextLineDataset(txt_file) 76 | lb_dataset = tf.data.TextLineDataset(lb_file) 77 | 78 | iterator = data_iterator.get_iterator( 79 | txt_dataset, 80 | lb_dataset, 81 | vocab_table, 82 | index_table, 83 | batch_size = hparams.batch_size, 84 | num_buckets = hparams.num_buckets, 85 | max_len = hparams.max_len) 86 | 87 | model = model_creator( 88 | hparams, 89 | iterator = iterator, 90 | mode = tf.contrib.learn.ModeKeys.TRAIN, 91 | vocab_table = vocab_table) 92 | 93 | return TrainModel( 94 | graph = graph, 95 | model = model, 96 | iterator = iterator) 97 | 98 | 99 | class EvalModel( 100 | namedtuple("EvalModel", 101 | ("graph", "model", "txt_file_placeholder", 102 | "lb_file_placeholder", "iterator"))): 103 | pass 104 | 105 | 106 | def create_eval_model(hparams, model_creator): 107 | vocab_file = hparams.vocab_file 108 | index_file = hparams.index_file 109 | graph = tf.Graph() 110 | 111 | with graph.as_default(), tf.container("eval"): 112 | vocab_table = lookup_ops.index_table_from_file( 113 | vocab_file, default_value = UNK_ID) 114 | # for the labels 115 | index_table = lookup_ops.index_table_from_file( 116 | index_file, default_value = 0) 117 | 118 | # the file's name 119 | txt_file_placeholder = tf.placeholder(shape = (), dtype = tf.string) 120 | lb_file_placeholder = tf.placeholder(shape = (), dtype = tf.string) 121 | txt_dataset = tf.data.TextLineDataset(txt_file_placeholder) 122 | lb_dataset = tf.data.TextLineDataset(lb_file_placeholder) 123 | 124 | iterator = data_iterator.get_iterator( 125 | txt_dataset, 126 | lb_dataset, 127 | vocab_table, 128 | index_table, 129 | batch_size = hparams.batch_size, 130 | num_buckets = hparams.num_buckets, 131 | max_len = hparams.max_len) 132 | 133 | model = model_creator( 134 | hparams, 135 | iterator = iterator, 136 | mode = tf.contrib.learn.ModeKeys.EVAL, 137 | vocab_table = vocab_table) 138 | 139 | return EvalModel( 140 | graph = graph, 141 | model = model, 142 | txt_file_placeholder = txt_file_placeholder, 143 | lb_file_placeholder = lb_file_placeholder, 144 | iterator = iterator) 145 | 146 | 147 | class InferModel( 148 | namedtuple("InferModel", 149 | ("graph", "model", "txt_placeholder", 150 |
"batch_size_placeholder", "iterator"))): 151 | pass 152 | 153 | 154 | def create_infer_model(hparams, model_creator): 155 | """Create inference model.""" 156 | graph = tf.Graph() 157 | vocab_file = hparams.vocab_file 158 | 159 | with graph.as_default(), tf.container("infer"): 160 | vocab_table = lookup_ops.index_table_from_file( 161 | vocab_file, default_value = UNK_ID) 162 | # for the labels 163 | ''' 164 | Although this is nonsense for the inference procedure, this is to ensure 165 | the labels are not None when building the model graph. 166 | (refer to model.BasicModel._decode_layer) 167 | ''' 168 | mapping_strings = tf.constant(['0']) 169 | index_table = tf.contrib.lookup.index_table_from_tensor( 170 | mapping = mapping_strings, default_value = 0) 171 | 172 | txt_placeholder = tf.placeholder(shape=[None], dtype = tf.string) 173 | batch_size_placeholder = tf.placeholder(shape = [], dtype = tf.int64) 174 | 175 | txt_dataset = tf.data.Dataset.from_tensor_slices( 176 | txt_placeholder) 177 | iterator = data_iterator.get_infer_iterator( 178 | txt_dataset, 179 | vocab_table, 180 | index_table, 181 | batch_size = batch_size_placeholder) 182 | 183 | model = model_creator( 184 | hparams, 185 | iterator = iterator, 186 | mode = tf.contrib.learn.ModeKeys.INFER, 187 | vocab_table = vocab_table) 188 | 189 | return InferModel( 190 | graph = graph, 191 | model = model, 192 | txt_placeholder = txt_placeholder, 193 | batch_size_placeholder = batch_size_placeholder, 194 | iterator = iterator) 195 | 196 | 197 | def _load_vocab(vocab_file): 198 | vocab = [] 199 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 200 | vocab_size = 0 201 | for word in f: 202 | vocab_size += 1 203 | vocab.append(word.strip()) 204 | return vocab, vocab_size 205 | 206 | 207 | def _load_embed_txt(embed_file): 208 | """Load embed_file into a python dictionary. 209 | 210 | Note: the embed_file should be a Glove formated txt file. Assuming 211 | embed_size=5, for example: 212 | 213 | the -0.071549 0.093459 0.023738 -0.090339 0.056123 214 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037 215 | and 0.20327 0.47348 0.050877 0.002103 0.060547 216 | 217 | Note: The first line stores the information of the # of embeddings and 218 | the size of an embedding. 219 | 220 | Args: 221 | embed_file: file path to the embedding file. 222 | Returns: 223 | a dictionary that maps word to vector, and the size of embedding dimensions. 224 | """ 225 | emb_dict = dict() 226 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f: 227 | emb_num, emb_size = f.readline().strip().split() 228 | emb_num = int(emb_num) 229 | emb_size = int(emb_size) 230 | for line in f: 231 | tokens = line.strip().split(" ") 232 | word = tokens[0] 233 | vec = list(map(float, tokens[1:])) 234 | emb_dict[word] = vec 235 | assert emb_size == len(vec), "All embedding size should be same." 236 | return emb_dict, emb_size 237 | 238 | 239 | def create_pretrained_emb_from_txt(vocab_file, embed_file, dtype = tf.float32): 240 | """Load pretrain embeding from embed_file, and return an embedding matrix. 241 | 242 | Args: 243 | embed_file: Path to a Glove formated embedding txt file. 244 | Note: we only need the embeddings whose corresponding words are in the 245 | vocab_file. 246 | """ 247 | vocab, _ = _load_vocab(vocab_file) 248 | 249 | print('# Using pretrained embedding: %s.' 
% embed_file) 250 | emb_dict, emb_size = _load_embed_txt(embed_file) 251 | 252 | emb_mat = np.array( 253 | [emb_dict[token] for token in vocab], dtype = dtype.as_numpy_dtype()) 254 | 255 | # The commented code below is used for creating untrainable embeddings, which means 256 | # the value of each embedding is fixed. 257 | # num_trainable_tokens = 1 # the unk token is trainable 258 | # emb_mat = tf.constant(emb_mat) 259 | # emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) 260 | # with tf.variable_scope("pretrain_embeddings", dtype = dtype) as scope: 261 | # emb_mat_var = tf.get_variable( 262 | # "emb_mat_var", [num_trainable_tokens, emb_size]) 263 | # return tf.concat([emb_mat_var, emb_mat_const], 0) 264 | 265 | # Here, instead, we use the pretrained embedding values only as initial values, 266 | # so the embeddings remain trainable and their values can still be updated. 267 | return tf.Variable(emb_mat, name = "char_embedding") 268 | 269 | 270 | def _char_convolution(inputs, cfilter): 271 | conv1 = tf.nn.conv2d(inputs, cfilter, [1, 1, 1, 1], 272 | padding = 'VALID') 273 | # inputs.shape = [batch_size, 1, 3, 2*num_units] 274 | # namely, in_height = 1, in_width = 3, in_channels = 2*num_units 275 | # filter.shape = [1, 2, 2*num_units, num_tags], strides = [1, 1, 1, 1] 276 | # conv1.shape = [batch_size, 1, 2, num_tags] 277 | conv1 = tf.nn.relu(conv1) 278 | pool1 = tf.nn.max_pool(conv1, 279 | ksize = [1, 1, 2, 1], 280 | strides = [1, 1, 1, 1], 281 | padding = 'VALID') 282 | 283 | # pool1.shape = [batch_size, 1, 1, num_tags] 284 | pool1 = tf.squeeze(pool1, [1, 2]) 285 | # pool1.shape = [batch_size, num_tags] 286 | return pool1 287 | 288 | 289 | def create_cnn_layer(inputs, time_major, batch_size, num_units, cfilter): 290 | if not time_major: 291 | # transpose to time-major 292 | inputs = tf.transpose(inputs, [1, 0, 2]) 293 | 294 | inputs = tf.expand_dims(inputs, 2) 295 | # [max_time, batch_size, 1, 2*num_units] 296 | 297 | left = inputs[1 : ] 298 | right = inputs[ : -1] 299 | left = tf.pad(left, [[1, 0], [0, 0], [0, 0], [0, 0]], "CONSTANT") 300 | right = tf.pad(right, [[0, 1], [0, 0], [0, 0], [0, 0]], "CONSTANT") 301 | 302 | char_blocks = tf.concat([left, inputs, right], 3) 303 | # [max_time, batch_size, 1, 3*2*num_units] 304 | char_blocks = tf.reshape(char_blocks, [-1, batch_size, 3, (2 * num_units)]) 305 | char_blocks = tf.expand_dims(char_blocks, 2) 306 | # [max_time, batch_size, 1, 3, 2*num_units] 307 | 308 | do_char_conv = lambda x : _char_convolution(x, cfilter) 309 | cnn_outputs = tf.map_fn(do_char_conv, char_blocks) 310 | # [max_time, batch_size, num_tags] 311 | 312 | return tf.transpose(cnn_outputs, [1, 0, 2]) 313 | 314 | 315 | def load_model(model, ckpt, session, name, init): 316 | start_time = time.time() 317 | model.saver.restore(session, ckpt) 318 | if init: 319 | session.run(tf.tables_initializer()) 320 | print( 321 | " loaded %s model parameters from %s, time %.2fs" % 322 | (name, ckpt, time.time() - start_time)) 323 | return model 324 | 325 | 326 | def create_or_load_model(model, model_dir, session, name, init): 327 | """Create segmentation model and initialize or load parameters in session.""" 328 | latest_ckpt = tf.train.latest_checkpoint(model_dir) 329 | if latest_ckpt: 330 | model = load_model(model, latest_ckpt, session, name, init) 331 | else: 332 | start_time = time.time() 333 | session.run(tf.global_variables_initializer()) 334 | session.run(tf.tables_initializer()) 335 | print(" created %s model with fresh parameters, time %.2fs" % 336 | (name, time.time() - start_time)) 337
| 338 | global_step = model.global_step.eval(session = session) 339 | return model, global_step -------------------------------------------------------------------------------- /sycws/prf_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 05/20/2018 6 | # 7 | # Description: The PRF scoring script used for evaluation 8 | # 9 | # Last Modified at: 05/20/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import print_function 30 | from __future__ import division 31 | 32 | import codecs 33 | 34 | def get_prf_score(test_list, gold_file): 35 | e = 0 # wrong words number 36 | c = 0 # correct words number 37 | N = 0 # gold words number 38 | TN = 0 # test words number 39 | 40 | assert type(test_list) == list 41 | test_raw = [] 42 | for line in test_list: 43 | sent = line.strip().split() 44 | if sent: 45 | test_raw.append(sent) 46 | 47 | gold_raw = [] 48 | with codecs.open(gold_file, 'r', "utf-8") as inpt2: 49 | for line in inpt2: 50 | sent = line.strip().split() 51 | if sent: 52 | gold_raw.append(sent) 53 | N += len(sent) 54 | 55 | for i, gold_sent in enumerate(gold_raw): 56 | test_sent = test_raw[i] 57 | 58 | ig = 0 59 | it = 0 60 | glen = len(gold_sent) 61 | tlen = len(test_sent) 62 | while True: 63 | if ig >= glen or it >= tlen: 64 | break 65 | 66 | gword = gold_sent[ig] 67 | tword = test_sent[it] 68 | if gword == tword: 69 | c += 1 70 | else: 71 | lg = len(gword) 72 | lt = len(tword) 73 | while lg != lt: 74 | try: 75 | if lg < lt: 76 | ig += 1 77 | gword = gold_sent[ig] 78 | lg += len(gword) 79 | else: 80 | it += 1 81 | tword = test_sent[it] 82 | lt += len(tword) 83 | except Exception as e: 84 | # pdb.set_trace() 85 | print ("Line: %d" % (i + 1)) 86 | print ("\nIt is the user's responsibility that a sentence in must", end = ' ') 87 | print ("have the SAME LENGTH with its corresponding sentence in .\n") 88 | raise e 89 | 90 | ig += 1 91 | it += 1 92 | 93 | TN += len(test_sent) 94 | 95 | e = TN - c 96 | precision = c / TN 97 | recall = c / N 98 | F = 2 * precision * recall / (precision + recall) 99 | error_rate = e / N 100 | 101 | print ("Correct words: %d"%c) 102 | print ("Error words: %d"%e) 103 | print ("Gold words: %d\n"%N) 104 | print ("precision: %f"%precision) 105 | print ("recall: %f"%recall) 106 | print ("F-Value: %f"%F) 107 | print ("error_rate: %f"%error_rate) 108 | 109 | return F -------------------------------------------------------------------------------- /sycws/sycws.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 03/21/2018 6 | 
# 7 | # Description: A neural network tool for Chinese word segmentation 8 | # 9 | # Last Modified at: 05/21/2018, by: Synrey Yee 10 | 11 | ''' 12 | ========================================================================== 13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved. 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 26 | ========================================================================== 27 | ''' 28 | 29 | from __future__ import absolute_import 30 | from __future__ import print_function 31 | 32 | from . import main_body 33 | from . import model as model_r 34 | 35 | import tensorflow as tf 36 | 37 | import argparse 38 | import sys 39 | import os 40 | import codecs 41 | 42 | 43 | # the parsed parameters 44 | FLAGS = None 45 | UNK = u"unk" 46 | 47 | 48 | # build the training parameters 49 | def add_arguments(parser): 50 | parser.register("type", "bool", lambda v: v.lower() == "true") 51 | 52 | # data 53 | parser.add_argument("--train_prefix", type = str, default = None, 54 | help = "Train prefix, expect files with .txt/.lb suffixes.") 55 | parser.add_argument("--eval_prefix", type = str, default = None, 56 | help = "Eval prefix, expect files with .txt/.lb suffixes.") 57 | parser.add_argument("--eval_gold_file", type = str, default = None, 58 | help = "Eval gold file.") 59 | 60 | parser.add_argument("--vocab_file", type = str, default = None, 61 | help = "Vocabulary file.") 62 | parser.add_argument("--embed_file", type = str, default = None, help="""\ 63 | Pretrained embedding files, The expecting files should be Glove formated txt files.\ 64 | """) 65 | parser.add_argument("--index_file", type = str, default = "./sycws/indices.txt", 66 | help = "Indices file.") 67 | parser.add_argument("--out_dir", type = str, default = None, 68 | help = "Store log/model files.") 69 | parser.add_argument("--max_len", type = int, default = 150, 70 | help = "Max length of char sequences during training.") 71 | 72 | # hyperparameters 73 | parser.add_argument("--num_units", type = int, default = 100, help = "Network size.") 74 | parser.add_argument("--model", type = str, default = "CNN-CRF", 75 | help = "2 kind of models: BiLSTM + (CRF | CNN-CRF)") 76 | 77 | parser.add_argument("--learning_rate", type = float, default = 0.001, 78 | help = "Learning rate. 
Adam: 0.001 | 0.0001") 79 | 80 | parser.add_argument("--num_train_steps", 81 | type = int, default = 45000, help = "Num steps to train.") 82 | 83 | parser.add_argument("--init_std", type = float, default = 0.05, 84 | help = "for truncated normal init_op") 85 | parser.add_argument("--filter_init_std", type = float, default = 0.035, 86 | help = "truncated normal initialization for CNN's filter") 87 | 88 | parser.add_argument("--dropout", type = float, default = 0.3, 89 | help = "Dropout rate (not keep_prob)") 90 | parser.add_argument("--max_gradient_norm", type = float, default = 5.0, 91 | help = "Clip gradients to this norm.") 92 | parser.add_argument("--batch_size", type = int, default = 128, help = "Batch size.") 93 | 94 | parser.add_argument("--steps_per_stats", type = int, default = 100, 95 | help = "How many training steps to print loss.") 96 | parser.add_argument("--num_buckets", type = int, default = 5, 97 | help = "Put data into similar-length buckets.") 98 | parser.add_argument("--steps_per_external_eval", type = int, default = None, 99 | help = """\ 100 | How many training steps to do per external evaluation. Automatically set 101 | based on data if None.\ 102 | """) 103 | 104 | parser.add_argument("--num_tags", type = int, default = 4, help = "BMES") 105 | parser.add_argument("--time_major", type = "bool", nargs = "?", const = True, 106 | default = True, 107 | help = "Whether to use time_major mode for dynamic RNN.") 108 | 109 | # Inference 110 | parser.add_argument("--ckpt", type = str, default = None, 111 | help = "Checkpoint file to load a model for inference.") 112 | parser.add_argument("--inference_input_file", type = str, default = None, 113 | help = "Set to the text to decode.") 114 | parser.add_argument("--infer_batch_size", type = int, default = 32, 115 | help = "Batch size for inference mode.") 116 | parser.add_argument("--inference_output_file", type = str, default = None, 117 | help = "Output file to store decoding results.") 118 | 119 | def create_hparams(flags): 120 | """Create training hparams.""" 121 | return tf.contrib.training.HParams( 122 | # data 123 | train_prefix = flags.train_prefix, 124 | eval_prefix = flags.eval_prefix, 125 | eval_gold_file = flags.eval_gold_file, 126 | vocab_file = flags.vocab_file, 127 | embed_file = flags.embed_file, 128 | 129 | index_file = flags.index_file, 130 | out_dir = flags.out_dir, 131 | max_len = flags.max_len, 132 | 133 | # hparams 134 | num_units = flags.num_units, 135 | model = flags.model, 136 | learning_rate = flags.learning_rate, 137 | num_train_steps = flags.num_train_steps, 138 | 139 | init_std = flags.init_std, 140 | filter_init_std = flags.filter_init_std, 141 | dropout = flags.dropout, 142 | max_gradient_norm = flags.max_gradient_norm, 143 | 144 | batch_size = flags.batch_size, 145 | num_buckets = flags.num_buckets, 146 | steps_per_stats = flags.steps_per_stats, 147 | steps_per_external_eval = flags.steps_per_external_eval, 148 | 149 | num_tags = flags.num_tags, 150 | time_major = flags.time_major, 151 | 152 | # inference 153 | ckpt = flags.ckpt, 154 | inference_input_file = flags.inference_input_file, 155 | infer_batch_size = flags.infer_batch_size, 156 | inference_output_file = flags.inference_output_file, 157 | ) 158 | 159 | 160 | def check_corpora(train_prefix, eval_prefix): 161 | train_txt = train_prefix + ".txt" 162 | train_lb = train_prefix + ".lb" 163 | eval_txt = eval_prefix + ".txt" 164 | eval_lb = eval_prefix + ".lb" 165 | 166 | def _inner_check(txt_reader, lb_reader): 167 | for txt_line, lb_line in 
zip(txt_reader, lb_reader): 168 | txt_length = len(txt_line.strip().split()) 169 | lb_length = len(lb_line.strip().split()) 170 | assert txt_length == lb_length 171 | 172 | train_txt_rd = codecs.open(train_txt, 'r', "utf-8") 173 | train_lb_rd = codecs.open(train_lb, 'r', "utf-8") 174 | eval_txt_rd = codecs.open(eval_txt, 'r', "utf-8") 175 | eval_lb_rd = codecs.open(eval_lb, 'r', "utf-8") 176 | 177 | with train_txt_rd, train_lb_rd: 178 | _inner_check(train_txt_rd, train_lb_rd) 179 | 180 | with eval_txt_rd, eval_lb_rd: 181 | _inner_check(eval_txt_rd, eval_lb_rd) 182 | 183 | 184 | def check_vocab(vocab_file): 185 | vocab = [] 186 | with codecs.open(vocab_file, 'r', "utf-8") as vob_inp: 187 | for word in vob_inp: 188 | vocab.append(word.strip()) 189 | 190 | if vocab[0] != UNK: 191 | vocab = [UNK] + vocab 192 | with codecs.open(vocab_file, 'w', "utf-8") as vob_opt: 193 | for word in vocab: 194 | vob_opt.write(word + u'\n') 195 | 196 | return len(vocab) 197 | 198 | 199 | def print_hparams(hparams): 200 | values = hparams.values() 201 | for key in sorted(values.keys()): 202 | print(" %s = %s" % (key, str(values[key]))) 203 | 204 | 205 | def main(unused_argv): 206 | out_dir = FLAGS.out_dir 207 | if not tf.gfile.Exists(out_dir): 208 | tf.gfile.MakeDirs(out_dir) 209 | 210 | hparams = create_hparams(FLAGS) 211 | model = hparams.model.upper() 212 | if model == "CRF": 213 | model_creator = model_r.BasicModel 214 | elif model == "CNN-CRF": 215 | model_creator = model_r.CnnCrfModel 216 | else: 217 | raise ValueError("Unknown model %s" % model) 218 | 219 | assert tf.gfile.Exists(hparams.vocab_file) 220 | vocab_size = check_vocab(hparams.vocab_file) 221 | hparams.add_hparam("vocab_size", vocab_size) 222 | 223 | if FLAGS.inference_input_file: 224 | # Inference 225 | trans_file = FLAGS.inference_output_file 226 | ckpt = FLAGS.ckpt 227 | if not ckpt: 228 | ckpt = tf.train.latest_checkpoint(out_dir) 229 | 230 | main_body.inference(ckpt, FLAGS.inference_input_file, 231 | trans_file, hparams, model_creator) 232 | else: 233 | # Train 234 | check_corpora(FLAGS.train_prefix, FLAGS.eval_prefix) 235 | 236 | # used for evaluation 237 | hparams.add_hparam("best_Fvalue", 0) # larger is better 238 | best_metric_dir = os.path.join(hparams.out_dir, "best_Fvalue") 239 | hparams.add_hparam("best_Fvalue_dir", best_metric_dir) 240 | tf.gfile.MakeDirs(best_metric_dir) 241 | 242 | print_hparams(hparams) 243 | main_body.train(hparams, model_creator) 244 | 245 | 246 | if __name__ == '__main__': 247 | parser = argparse.ArgumentParser() 248 | add_arguments(parser) 249 | FLAGS, unparsed = parser.parse_known_args() 250 | tf.app.run(main = main, argv = [sys.argv[0]] + unparsed) -------------------------------------------------------------------------------- /third_party/compile_w2v.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Author: Synrey Yee 4 | # 5 | # Created at: 05/07/2017 6 | # 7 | # Description: compile word2vec.c 8 | # 9 | # Last Modified at: 05/07/2017, by: Synrey Yee 10 | 11 | gcc word2vec.c -o word2vec -lm -lpthread -------------------------------------------------------------------------------- /third_party/word2vec: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/third_party/word2vec -------------------------------------------------------------------------------- /third_party/word2vec.c: 
-------------------------------------------------------------------------------- 1 | // Copyright 2013 Google Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include <stdio.h> 16 | #include <stdlib.h> 17 | #include <string.h> 18 | #include <math.h> 19 | #include <pthread.h> 20 | 21 | #define MAX_STRING 100 22 | #define EXP_TABLE_SIZE 1000 23 | #define MAX_EXP 6 24 | #define MAX_SENTENCE_LENGTH 1000 25 | #define MAX_CODE_LENGTH 40 26 | 27 | const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary 28 | 29 | typedef float real; // Precision of float numbers 30 | 31 | struct vocab_word { 32 | long long cn; 33 | int *point; 34 | char *word, *code, codelen; 35 | }; 36 | 37 | char train_file[MAX_STRING], output_file[MAX_STRING]; 38 | char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING]; 39 | struct vocab_word *vocab; 40 | int binary = 0, cbow = 1, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1; 41 | int *vocab_hash; 42 | long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100; 43 | long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0; 44 | real alpha = 0.025, starting_alpha, sample = 1e-3; 45 | real *syn0, *syn1, *syn1neg, *expTable; 46 | clock_t start; 47 | 48 | int hs = 0, negative = 5; 49 | const int table_size = 1e8; 50 | int *table; 51 | 52 | void InitUnigramTable() { 53 | int a, i; 54 | double train_words_pow = 0; 55 | double d1, power = 0.75; 56 | table = (int *)malloc(table_size * sizeof(int)); 57 | for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power); 58 | i = 0; 59 | d1 = pow(vocab[i].cn, power) / train_words_pow; 60 | for (a = 0; a < table_size; a++) { 61 | table[a] = i; 62 | if (a / (double)table_size > d1) { 63 | i++; 64 | d1 += pow(vocab[i].cn, power) / train_words_pow; 65 | } 66 | if (i >= vocab_size) i = vocab_size - 1; 67 | } 68 | } 69 | 70 | // Reads a single word from a file, assuming space + tab + EOL to be word boundaries 71 | void ReadWord(char *word, FILE *fin) { 72 | int a = 0, ch; 73 | while (!feof(fin)) { 74 | ch = fgetc(fin); 75 | if (ch == 13) continue; 76 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) { 77 | if (a > 0) { 78 | if (ch == '\n') ungetc(ch, fin); 79 | break; 80 | } 81 | if (ch == '\n') { 82 | strcpy(word, (char *)"</s>"); 83 | return; 84 | } else continue; 85 | } 86 | word[a] = ch; 87 | a++; 88 | if (a >= MAX_STRING - 1) a--; // Truncate too long words 89 | } 90 | word[a] = 0; 91 | } 92 | 93 | // Returns hash value of a word 94 | int GetWordHash(char *word) { 95 | unsigned long long a, hash = 0; 96 | for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a]; 97 | hash = hash % vocab_hash_size; 98 | return hash; 99 | } 100 | 101 | // Returns position of a word in the vocabulary; if the word is not found, returns -1 102 | int SearchVocab(char *word) { 103 | unsigned int hash = GetWordHash(word); 104 | while (1) { 105 | if (vocab_hash[hash] == -1) return -1; 106 | if (!strcmp(word,
vocab[vocab_hash[hash]].word)) return vocab_hash[hash]; 107 | hash = (hash + 1) % vocab_hash_size; 108 | } 109 | return -1; 110 | } 111 | 112 | // Reads a word and returns its index in the vocabulary 113 | int ReadWordIndex(FILE *fin) { 114 | char word[MAX_STRING]; 115 | ReadWord(word, fin); 116 | if (feof(fin)) return -1; 117 | return SearchVocab(word); 118 | } 119 | 120 | // Adds a word to the vocabulary 121 | int AddWordToVocab(char *word) { 122 | unsigned int hash, length = strlen(word) + 1; 123 | if (length > MAX_STRING) length = MAX_STRING; 124 | vocab[vocab_size].word = (char *)calloc(length, sizeof(char)); 125 | strcpy(vocab[vocab_size].word, word); 126 | vocab[vocab_size].cn = 0; 127 | vocab_size++; 128 | // Reallocate memory if needed 129 | if (vocab_size + 2 >= vocab_max_size) { 130 | vocab_max_size += 1000; 131 | vocab = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word)); 132 | } 133 | hash = GetWordHash(word); 134 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 135 | vocab_hash[hash] = vocab_size - 1; 136 | return vocab_size - 1; 137 | } 138 | 139 | // Used later for sorting by word counts 140 | int VocabCompare(const void *a, const void *b) { 141 | return ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn; 142 | } 143 | 144 | // Sorts the vocabulary by frequency using word counts 145 | void SortVocab() { 146 | int a, size; 147 | unsigned int hash; 148 | // Sort the vocabulary and keep at the first position 149 | qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare); 150 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 151 | size = vocab_size; 152 | train_words = 0; 153 | for (a = 0; a < size; a++) { 154 | // Words occuring less than min_count times will be discarded from the vocab 155 | if ((vocab[a].cn < min_count) && (a != 0)) { 156 | vocab_size--; 157 | free(vocab[a].word); 158 | } else { 159 | // Hash will be re-computed, as after the sorting it is not actual 160 | hash=GetWordHash(vocab[a].word); 161 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 162 | vocab_hash[hash] = a; 163 | train_words += vocab[a].cn; 164 | } 165 | } 166 | vocab = (struct vocab_word *)realloc(vocab, (vocab_size + 1) * sizeof(struct vocab_word)); 167 | // Allocate memory for the binary tree construction 168 | for (a = 0; a < vocab_size; a++) { 169 | vocab[a].code = (char *)calloc(MAX_CODE_LENGTH, sizeof(char)); 170 | vocab[a].point = (int *)calloc(MAX_CODE_LENGTH, sizeof(int)); 171 | } 172 | } 173 | 174 | // Reduces the vocabulary by removing infrequent tokens 175 | void ReduceVocab() { 176 | int a, b = 0; 177 | unsigned int hash; 178 | for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) { 179 | vocab[b].cn = vocab[a].cn; 180 | vocab[b].word = vocab[a].word; 181 | b++; 182 | } else free(vocab[a].word); 183 | vocab_size = b; 184 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 185 | for (a = 0; a < vocab_size; a++) { 186 | // Hash will be re-computed, as it is not actual 187 | hash = GetWordHash(vocab[a].word); 188 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size; 189 | vocab_hash[hash] = a; 190 | } 191 | fflush(stdout); 192 | min_reduce++; 193 | } 194 | 195 | // Create binary Huffman tree using the word counts 196 | // Frequent words will have short uniqe binary codes 197 | void CreateBinaryTree() { 198 | long long a, b, i, min1i, min2i, pos1, pos2, point[MAX_CODE_LENGTH]; 199 | char code[MAX_CODE_LENGTH]; 200 | long long *count = (long long 
*)calloc(vocab_size * 2 + 1, sizeof(long long)); 201 | long long *binary = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long)); 202 | long long *parent_node = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long)); 203 | for (a = 0; a < vocab_size; a++) count[a] = vocab[a].cn; 204 | for (a = vocab_size; a < vocab_size * 2; a++) count[a] = 1e15; 205 | pos1 = vocab_size - 1; 206 | pos2 = vocab_size; 207 | // Following algorithm constructs the Huffman tree by adding one node at a time 208 | for (a = 0; a < vocab_size - 1; a++) { 209 | // First, find two smallest nodes 'min1, min2' 210 | if (pos1 >= 0) { 211 | if (count[pos1] < count[pos2]) { 212 | min1i = pos1; 213 | pos1--; 214 | } else { 215 | min1i = pos2; 216 | pos2++; 217 | } 218 | } else { 219 | min1i = pos2; 220 | pos2++; 221 | } 222 | if (pos1 >= 0) { 223 | if (count[pos1] < count[pos2]) { 224 | min2i = pos1; 225 | pos1--; 226 | } else { 227 | min2i = pos2; 228 | pos2++; 229 | } 230 | } else { 231 | min2i = pos2; 232 | pos2++; 233 | } 234 | count[vocab_size + a] = count[min1i] + count[min2i]; 235 | parent_node[min1i] = vocab_size + a; 236 | parent_node[min2i] = vocab_size + a; 237 | binary[min2i] = 1; 238 | } 239 | // Now assign binary code to each vocabulary word 240 | for (a = 0; a < vocab_size; a++) { 241 | b = a; 242 | i = 0; 243 | while (1) { 244 | code[i] = binary[b]; 245 | point[i] = b; 246 | i++; 247 | b = parent_node[b]; 248 | if (b == vocab_size * 2 - 2) break; 249 | } 250 | vocab[a].codelen = i; 251 | vocab[a].point[0] = vocab_size - 2; 252 | for (b = 0; b < i; b++) { 253 | vocab[a].code[i - b - 1] = code[b]; 254 | vocab[a].point[i - b] = point[b] - vocab_size; 255 | } 256 | } 257 | free(count); 258 | free(binary); 259 | free(parent_node); 260 | } 261 | 262 | void LearnVocabFromTrainFile() { 263 | char word[MAX_STRING]; 264 | FILE *fin; 265 | long long a, i; 266 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 267 | fin = fopen(train_file, "rb"); 268 | if (fin == NULL) { 269 | printf("ERROR: training data file not found!\n"); 270 | exit(1); 271 | } 272 | vocab_size = 0; 273 | AddWordToVocab((char *)""); 274 | while (1) { 275 | ReadWord(word, fin); 276 | if (feof(fin)) break; 277 | train_words++; 278 | if ((debug_mode > 1) && (train_words % 100000 == 0)) { 279 | printf("%lldK%c", train_words / 1000, 13); 280 | fflush(stdout); 281 | } 282 | i = SearchVocab(word); 283 | if (i == -1) { 284 | a = AddWordToVocab(word); 285 | vocab[a].cn = 1; 286 | } else vocab[i].cn++; 287 | if (vocab_size > vocab_hash_size * 0.7) ReduceVocab(); 288 | } 289 | SortVocab(); 290 | if (debug_mode > 0) { 291 | printf("Vocab size: %lld\n", vocab_size); 292 | printf("Words in train file: %lld\n", train_words); 293 | } 294 | file_size = ftell(fin); 295 | fclose(fin); 296 | } 297 | 298 | void SaveVocab() { 299 | long long i; 300 | FILE *fo = fopen(save_vocab_file, "wb"); 301 | for (i = 0; i < vocab_size; i++) fprintf(fo, "%s %lld\n", vocab[i].word, vocab[i].cn); 302 | fclose(fo); 303 | } 304 | 305 | void ReadVocab() { 306 | long long a, i = 0; 307 | char c; 308 | char word[MAX_STRING]; 309 | FILE *fin = fopen(read_vocab_file, "rb"); 310 | if (fin == NULL) { 311 | printf("Vocabulary file not found\n"); 312 | exit(1); 313 | } 314 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1; 315 | vocab_size = 0; 316 | while (1) { 317 | ReadWord(word, fin); 318 | if (feof(fin)) break; 319 | a = AddWordToVocab(word); 320 | fscanf(fin, "%lld%c", &vocab[a].cn, &c); 321 | i++; 322 | } 323 | SortVocab(); 324 | if (debug_mode > 0) { 325 | 
printf("Vocab size: %lld\n", vocab_size); 326 | printf("Words in train file: %lld\n", train_words); 327 | } 328 | fin = fopen(train_file, "rb"); 329 | if (fin == NULL) { 330 | printf("ERROR: training data file not found!\n"); 331 | exit(1); 332 | } 333 | fseek(fin, 0, SEEK_END); 334 | file_size = ftell(fin); 335 | fclose(fin); 336 | } 337 | 338 | void InitNet() { 339 | long long a, b; 340 | unsigned long long next_random = 1; 341 | a = posix_memalign((void **)&syn0, 128, (long long)vocab_size * layer1_size * sizeof(real)); 342 | if (syn0 == NULL) {printf("Memory allocation failed\n"); exit(1);} 343 | if (hs) { 344 | a = posix_memalign((void **)&syn1, 128, (long long)vocab_size * layer1_size * sizeof(real)); 345 | if (syn1 == NULL) {printf("Memory allocation failed\n"); exit(1);} 346 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) 347 | syn1[a * layer1_size + b] = 0; 348 | } 349 | if (negative>0) { 350 | a = posix_memalign((void **)&syn1neg, 128, (long long)vocab_size * layer1_size * sizeof(real)); 351 | if (syn1neg == NULL) {printf("Memory allocation failed\n"); exit(1);} 352 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) 353 | syn1neg[a * layer1_size + b] = 0; 354 | } 355 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) { 356 | next_random = next_random * (unsigned long long)25214903917 + 11; 357 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5) / layer1_size; 358 | } 359 | CreateBinaryTree(); 360 | } 361 | 362 | void *TrainModelThread(void *id) { 363 | long long a, b, d, cw, word, last_word, sentence_length = 0, sentence_position = 0; 364 | long long word_count = 0, last_word_count = 0, sen[MAX_SENTENCE_LENGTH + 1]; 365 | long long l1, l2, c, target, label, local_iter = iter; 366 | unsigned long long next_random = (long long)id; 367 | real f, g; 368 | clock_t now; 369 | real *neu1 = (real *)calloc(layer1_size, sizeof(real)); 370 | real *neu1e = (real *)calloc(layer1_size, sizeof(real)); 371 | FILE *fi = fopen(train_file, "rb"); 372 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET); 373 | while (1) { 374 | if (word_count - last_word_count > 10000) { 375 | word_count_actual += word_count - last_word_count; 376 | last_word_count = word_count; 377 | if ((debug_mode > 1)) { 378 | now=clock(); 379 | printf("%cAlpha: %f Progress: %.2f%% Words/thread/sec: %.2fk ", 13, alpha, 380 | word_count_actual / (real)(iter * train_words + 1) * 100, 381 | word_count_actual / ((real)(now - start + 1) / (real)CLOCKS_PER_SEC * 1000)); 382 | fflush(stdout); 383 | } 384 | alpha = starting_alpha * (1 - word_count_actual / (real)(iter * train_words + 1)); 385 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001; 386 | } 387 | if (sentence_length == 0) { 388 | while (1) { 389 | word = ReadWordIndex(fi); 390 | if (feof(fi)) break; 391 | if (word == -1) continue; 392 | word_count++; 393 | if (word == 0) break; 394 | // The subsampling randomly discards frequent words while keeping the ranking same 395 | if (sample > 0) { 396 | real ran = (sqrt(vocab[word].cn / (sample * train_words)) + 1) * (sample * train_words) / vocab[word].cn; 397 | next_random = next_random * (unsigned long long)25214903917 + 11; 398 | if (ran < (next_random & 0xFFFF) / (real)65536) continue; 399 | } 400 | sen[sentence_length] = word; 401 | sentence_length++; 402 | if (sentence_length >= MAX_SENTENCE_LENGTH) break; 403 | } 404 | sentence_position = 0; 405 | } 406 | if (feof(fi) || (word_count > train_words / 
num_threads)) { 407 | word_count_actual += word_count - last_word_count; 408 | local_iter--; 409 | if (local_iter == 0) break; 410 | word_count = 0; 411 | last_word_count = 0; 412 | sentence_length = 0; 413 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET); 414 | continue; 415 | } 416 | word = sen[sentence_position]; 417 | if (word == -1) continue; 418 | for (c = 0; c < layer1_size; c++) neu1[c] = 0; 419 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 420 | next_random = next_random * (unsigned long long)25214903917 + 11; 421 | b = next_random % window; 422 | if (cbow) { //train the cbow architecture 423 | // in -> hidden 424 | cw = 0; 425 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 426 | c = sentence_position - window + a; 427 | if (c < 0) continue; 428 | if (c >= sentence_length) continue; 429 | last_word = sen[c]; 430 | if (last_word == -1) continue; 431 | for (c = 0; c < layer1_size; c++) neu1[c] += syn0[c + last_word * layer1_size]; 432 | cw++; 433 | } 434 | if (cw) { 435 | for (c = 0; c < layer1_size; c++) neu1[c] /= cw; 436 | if (hs) for (d = 0; d < vocab[word].codelen; d++) { 437 | f = 0; 438 | l2 = vocab[word].point[d] * layer1_size; 439 | // Propagate hidden -> output 440 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1[c + l2]; 441 | if (f <= -MAX_EXP) continue; 442 | else if (f >= MAX_EXP) continue; 443 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 444 | // 'g' is the gradient multiplied by the learning rate 445 | g = (1 - vocab[word].code[d] - f) * alpha; 446 | // Propagate errors output -> hidden 447 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2]; 448 | // Learn weights hidden -> output 449 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * neu1[c]; 450 | } 451 | // NEGATIVE SAMPLING 452 | if (negative > 0) for (d = 0; d < negative + 1; d++) { 453 | if (d == 0) { 454 | target = word; 455 | label = 1; 456 | } else { 457 | next_random = next_random * (unsigned long long)25214903917 + 11; 458 | target = table[(next_random >> 16) % table_size]; 459 | if (target == 0) target = next_random % (vocab_size - 1) + 1; 460 | if (target == word) continue; 461 | label = 0; 462 | } 463 | l2 = target * layer1_size; 464 | f = 0; 465 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1neg[c + l2]; 466 | if (f > MAX_EXP) g = (label - 1) * alpha; 467 | else if (f < -MAX_EXP) g = (label - 0) * alpha; 468 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha; 469 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 470 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * neu1[c]; 471 | } 472 | // hidden -> in 473 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 474 | c = sentence_position - window + a; 475 | if (c < 0) continue; 476 | if (c >= sentence_length) continue; 477 | last_word = sen[c]; 478 | if (last_word == -1) continue; 479 | for (c = 0; c < layer1_size; c++) syn0[c + last_word * layer1_size] += neu1e[c]; 480 | } 481 | } 482 | } else { //train skip-gram 483 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) { 484 | c = sentence_position - window + a; 485 | if (c < 0) continue; 486 | if (c >= sentence_length) continue; 487 | last_word = sen[c]; 488 | if (last_word == -1) continue; 489 | l1 = last_word * layer1_size; 490 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0; 491 | // HIERARCHICAL SOFTMAX 492 | if (hs) for (d = 0; d < vocab[word].codelen; d++) { 493 | f = 0; 494 | l2 = vocab[word].point[d] 
* layer1_size; 495 | // Propagate hidden -> output 496 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1[c + l2]; 497 | if (f <= -MAX_EXP) continue; 498 | else if (f >= MAX_EXP) continue; 499 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 500 | // 'g' is the gradient multiplied by the learning rate 501 | g = (1 - vocab[word].code[d] - f) * alpha; 502 | // Propagate errors output -> hidden 503 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2]; 504 | // Learn weights hidden -> output 505 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * syn0[c + l1]; 506 | } 507 | // NEGATIVE SAMPLING 508 | if (negative > 0) for (d = 0; d < negative + 1; d++) { 509 | if (d == 0) { 510 | target = word; 511 | label = 1; 512 | } else { 513 | next_random = next_random * (unsigned long long)25214903917 + 11; 514 | target = table[(next_random >> 16) % table_size]; 515 | if (target == 0) target = next_random % (vocab_size - 1) + 1; 516 | if (target == word) continue; 517 | label = 0; 518 | } 519 | l2 = target * layer1_size; 520 | f = 0; 521 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2]; 522 | if (f > MAX_EXP) g = (label - 1) * alpha; 523 | else if (f < -MAX_EXP) g = (label - 0) * alpha; 524 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha; 525 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2]; 526 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1]; 527 | } 528 | // Learn weights input -> hidden 529 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c]; 530 | } 531 | } 532 | sentence_position++; 533 | if (sentence_position >= sentence_length) { 534 | sentence_length = 0; 535 | continue; 536 | } 537 | } 538 | fclose(fi); 539 | free(neu1); 540 | free(neu1e); 541 | pthread_exit(NULL); 542 | } 543 | 544 | void TrainModel() { 545 | long a, b, c, d; 546 | FILE *fo; 547 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 548 | printf("Starting training using file %s\n", train_file); 549 | starting_alpha = alpha; 550 | if (read_vocab_file[0] != 0) ReadVocab(); else LearnVocabFromTrainFile(); 551 | if (save_vocab_file[0] != 0) SaveVocab(); 552 | if (output_file[0] == 0) return; 553 | InitNet(); 554 | if (negative > 0) InitUnigramTable(); 555 | start = clock(); 556 | for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, TrainModelThread, (void *)a); 557 | for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL); 558 | fo = fopen(output_file, "wb"); 559 | if (classes == 0) { 560 | // Save the word vectors 561 | fprintf(fo, "%lld %lld\n", vocab_size, layer1_size); 562 | for (a = 0; a < vocab_size; a++) { 563 | fprintf(fo, "%s ", vocab[a].word); 564 | if (binary) for (b = 0; b < layer1_size; b++) fwrite(&syn0[a * layer1_size + b], sizeof(real), 1, fo); 565 | else for (b = 0; b < layer1_size; b++) fprintf(fo, "%lf ", syn0[a * layer1_size + b]); 566 | fprintf(fo, "\n"); 567 | } 568 | } else { 569 | // Run K-means on the word vectors 570 | int clcn = classes, iter = 10, closeid; 571 | int *centcn = (int *)malloc(classes * sizeof(int)); 572 | int *cl = (int *)calloc(vocab_size, sizeof(int)); 573 | real closev, x; 574 | real *cent = (real *)calloc(classes * layer1_size, sizeof(real)); 575 | for (a = 0; a < vocab_size; a++) cl[a] = a % clcn; 576 | for (a = 0; a < iter; a++) { 577 | for (b = 0; b < clcn * layer1_size; b++) cent[b] = 0; 578 | for (b = 0; b < clcn; b++) centcn[b] = 1; 579 | for (c = 0; c < vocab_size; c++) { 580 | 
for (d = 0; d < layer1_size; d++) cent[layer1_size * cl[c] + d] += syn0[c * layer1_size + d]; 581 | centcn[cl[c]]++; 582 | } 583 | for (b = 0; b < clcn; b++) { 584 | closev = 0; 585 | for (c = 0; c < layer1_size; c++) { 586 | cent[layer1_size * b + c] /= centcn[b]; 587 | closev += cent[layer1_size * b + c] * cent[layer1_size * b + c]; 588 | } 589 | closev = sqrt(closev); 590 | for (c = 0; c < layer1_size; c++) cent[layer1_size * b + c] /= closev; 591 | } 592 | for (c = 0; c < vocab_size; c++) { 593 | closev = -10; 594 | closeid = 0; 595 | for (d = 0; d < clcn; d++) { 596 | x = 0; 597 | for (b = 0; b < layer1_size; b++) x += cent[layer1_size * d + b] * syn0[c * layer1_size + b]; 598 | if (x > closev) { 599 | closev = x; 600 | closeid = d; 601 | } 602 | } 603 | cl[c] = closeid; 604 | } 605 | } 606 | // Save the K-means classes 607 | for (a = 0; a < vocab_size; a++) fprintf(fo, "%s %d\n", vocab[a].word, cl[a]); 608 | free(centcn); 609 | free(cent); 610 | free(cl); 611 | } 612 | fclose(fo); 613 | } 614 | 615 | int ArgPos(char *str, int argc, char **argv) { 616 | int a; 617 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) { 618 | if (a == argc - 1) { 619 | printf("Argument missing for %s\n", str); 620 | exit(1); 621 | } 622 | return a; 623 | } 624 | return -1; 625 | } 626 | 627 | int main(int argc, char **argv) { 628 | int i; 629 | if (argc == 1) { 630 | printf("WORD VECTOR estimation toolkit v 0.1c\n\n"); 631 | printf("Options:\n"); 632 | printf("Parameters for training:\n"); 633 | printf("\t-train \n"); 634 | printf("\t\tUse text data from to train the model\n"); 635 | printf("\t-output \n"); 636 | printf("\t\tUse to save the resulting word vectors / word clusters\n"); 637 | printf("\t-size \n"); 638 | printf("\t\tSet size of word vectors; default is 100\n"); 639 | printf("\t-window \n"); 640 | printf("\t\tSet max skip length between words; default is 5\n"); 641 | printf("\t-sample \n"); 642 | printf("\t\tSet threshold for occurrence of words. 
Those that appear with higher frequency in the training data\n"); 643 | printf("\t\twill be randomly down-sampled; default is 1e-3, useful range is (0, 1e-5)\n"); 644 | printf("\t-hs \n"); 645 | printf("\t\tUse Hierarchical Softmax; default is 0 (not used)\n"); 646 | printf("\t-negative \n"); 647 | printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10 (0 = not used)\n"); 648 | printf("\t-threads \n"); 649 | printf("\t\tUse threads (default 12)\n"); 650 | printf("\t-iter \n"); 651 | printf("\t\tRun more training iterations (default 5)\n"); 652 | printf("\t-min-count \n"); 653 | printf("\t\tThis will discard words that appear less than times; default is 5\n"); 654 | printf("\t-alpha \n"); 655 | printf("\t\tSet the starting learning rate; default is 0.025 for skip-gram and 0.05 for CBOW\n"); 656 | printf("\t-classes \n"); 657 | printf("\t\tOutput word classes rather than word vectors; default number of classes is 0 (vectors are written)\n"); 658 | printf("\t-debug \n"); 659 | printf("\t\tSet the debug mode (default = 2 = more info during training)\n"); 660 | printf("\t-binary \n"); 661 | printf("\t\tSave the resulting vectors in binary moded; default is 0 (off)\n"); 662 | printf("\t-save-vocab \n"); 663 | printf("\t\tThe vocabulary will be saved to \n"); 664 | printf("\t-read-vocab \n"); 665 | printf("\t\tThe vocabulary will be read from , not constructed from the training data\n"); 666 | printf("\t-cbow \n"); 667 | printf("\t\tUse the continuous bag of words model; default is 1 (use 0 for skip-gram model)\n"); 668 | printf("\nExamples:\n"); 669 | printf("./word2vec -train data.txt -output vec.txt -size 200 -window 5 -sample 1e-4 -negative 5 -hs 0 -binary 0 -cbow 1 -iter 3\n\n"); 670 | return 0; 671 | } 672 | output_file[0] = 0; 673 | save_vocab_file[0] = 0; 674 | read_vocab_file[0] = 0; 675 | if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]); 676 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]); 677 | if ((i = ArgPos((char *)"-save-vocab", argc, argv)) > 0) strcpy(save_vocab_file, argv[i + 1]); 678 | if ((i = ArgPos((char *)"-read-vocab", argc, argv)) > 0) strcpy(read_vocab_file, argv[i + 1]); 679 | if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]); 680 | if ((i = ArgPos((char *)"-binary", argc, argv)) > 0) binary = atoi(argv[i + 1]); 681 | if ((i = ArgPos((char *)"-cbow", argc, argv)) > 0) cbow = atoi(argv[i + 1]); 682 | if (cbow) alpha = 0.05; 683 | if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]); 684 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]); 685 | if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]); 686 | if ((i = ArgPos((char *)"-sample", argc, argv)) > 0) sample = atof(argv[i + 1]); 687 | if ((i = ArgPos((char *)"-hs", argc, argv)) > 0) hs = atoi(argv[i + 1]); 688 | if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]); 689 | if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]); 690 | if ((i = ArgPos((char *)"-iter", argc, argv)) > 0) iter = atoi(argv[i + 1]); 691 | if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]); 692 | if ((i = ArgPos((char *)"-classes", argc, argv)) > 0) classes = atoi(argv[i + 1]); 693 | vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word)); 694 | vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int)); 695 | 
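/* The table below caches sigmoid(z) = exp(z) / (exp(z) + 1) for z in
   [-MAX_EXP, MAX_EXP], sampled at EXP_TABLE_SIZE points, so that
   TrainModelThread can look the value up instead of calling exp() for
   every hierarchical-softmax or negative-sampling update. */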
expTable = (real *)malloc((EXP_TABLE_SIZE + 1) * sizeof(real)); 696 | for (i = 0; i < EXP_TABLE_SIZE; i++) { 697 | expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table 698 | expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1) 699 | } 700 | TrainModel(); 701 | return 0; 702 | } --------------------------------------------------------------------------------
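The text vectors written by third_party/word2vec with -binary 0 use the same layout that sycws/model_helper._load_embed_txt expects from --embed_file: a "count dimension" header line followed by one "word v1 ... vn" line per vocabulary entry. Below is a minimal, illustrative sketch of that parsing contract, kept outside TensorFlow; the script and the file names vec.txt / corpus.txt are hypothetical and not part of the repository.

# check_embed_format.py -- illustrative sketch only, mirrors model_helper._load_embed_txt
import codecs

def load_embed_txt(embed_file):
    emb_dict = {}
    with codecs.open(embed_file, "r", "utf-8") as f:
        # header written by word2vec.c: "%lld %lld\n" % (vocab_size, layer1_size)
        emb_num, emb_size = map(int, f.readline().strip().split())
        for line in f:
            tokens = line.strip().split(" ")
            word, vec = tokens[0], [float(v) for v in tokens[1:]]
            assert len(vec) == emb_size, "All embedding sizes should be the same."
            emb_dict[word] = vec
    return emb_dict, emb_num, emb_size

if __name__ == "__main__":
    # e.g. vec.txt produced by: ./word2vec -train corpus.txt -output vec.txt -size 100 -binary 0
    emb_dict, emb_num, emb_size = load_embed_txt("vec.txt")
    print("%d embeddings declared, %d read, dimension %d"
          % (emb_num, len(emb_dict), emb_size))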