├── Extra
│   ├── Bi-LSTM_CNN.html
│   ├── Bi_LSTM_CNN.jpg
│   └── Bi_LSTM_CNN_CRF.jpg
├── LICENSE.txt
├── PRF_Score.py
├── README.md
├── preprocess.py
├── sycws
│   ├── __init__.py
│   ├── data_iterator.py
│   ├── indices.txt
│   ├── main_body.py
│   ├── model.py
│   ├── model_helper.py
│   ├── prf_script.py
│   └── sycws.py
└── third_party
    ├── compile_w2v.sh
    ├── word2vec
    └── word2vec.c
/Extra/Bi-LSTM_CNN.html:
--------------------------------------------------------------------------------
[HTML markup stripped in this dump; only the page title "Bi-LSTM_CNN_CRF" is recoverable. The file is the HTML page describing the Bi-LSTM-CNN-CRF architecture linked from the README.]
--------------------------------------------------------------------------------
/Extra/Bi_LSTM_CNN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/Extra/Bi_LSTM_CNN.jpg
--------------------------------------------------------------------------------
/Extra/Bi_LSTM_CNN_CRF.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/Extra/Bi_LSTM_CNN_CRF.jpg
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee)
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/PRF_Score.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | # PRF scorer for Chinese word segmentation results
4 | #
5 | # File usage:
6 | # Score segmentation results: precision, recall and F-value (PRF)
7 | #
8 | # Author:
9 | # Synrey Yee
10 |
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import codecs
15 | import sys
16 |
17 | e = 0 # wrong words number
18 | c = 0 # correct words number
19 | N = 0 # gold words number
20 | TN = 0 # test words number
21 |
22 | test_file = sys.argv[1]
23 | gold_file = sys.argv[2]
24 |
25 | test_raw = []
26 | with codecs.open(test_file, 'r', "utf-8") as inpt1:
27 | for line in inpt1:
28 | sent = line.strip().split()
29 | if sent:
30 | test_raw.append(sent)
31 |
32 | gold_raw = []
33 | with codecs.open(gold_file, 'r', "utf-8") as inpt2:
34 | for line in inpt2:
35 | sent = line.strip().split()
36 | if sent:
37 | gold_raw.append(sent)
38 | N += len(sent)
39 |
40 | for i, gold_sent in enumerate(gold_raw):
41 | test_sent = test_raw[i]
42 |
43 | ig = 0
44 | it = 0
45 | glen = len(gold_sent)
46 | tlen = len(test_sent)
47 | while True:
48 | if ig >= glen or it >= tlen:
49 | break
50 |
51 | gword = gold_sent[ig]
52 | tword = test_sent[it]
53 | if gword == tword:
54 | c += 1
55 | else:
56 | lg = len(gword)
57 | lt = len(tword)
58 | while lg != lt:
59 | try:
60 | if lg < lt:
61 | ig += 1
62 | gword = gold_sent[ig]
63 | lg += len(gword)
64 | else:
65 | it += 1
66 | tword = test_sent[it]
67 | lt += len(tword)
68 | except Exception as ex:
69 | # pdb.set_trace()
70 | print ("Line: %d" % (i + 1))
71 | print ("\nIt is the user's responsibility that each sentence in the test file", end = ' ')
72 | print ("must have the SAME LENGTH as its corresponding sentence in the gold file.\n")
73 | raise ex
74 |
75 | ig += 1
76 | it += 1
77 |
78 | TN += len(test_sent)
79 |
80 | e = TN - c
81 | precision = c / TN
82 | recall = c / N
83 | F = 2 * precision * recall / (precision + recall)
84 | error_rate = e / N
85 |
86 | print ("Correct words: %d"%c)
87 | print ("Error words: %d"%e)
88 | print ("Gold words: %d\n"%N)
89 | print ("precision: %f"%precision)
90 | print ("recall: %f"%recall)
91 | print ("F-Value: %f"%F)
92 | print ("error_rate: %f"%error_rate)
93 |
--------------------------------------------------------------------------------
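
A quick worked example of the scoring logic above (hypothetical input, not part of the repo): if a gold line is "他 来到 北京" (3 words) and the corresponding test line is "他 来 到 北京" (4 words), only "他" and "北京" match exactly, so c = 2, N = 3 and TN = 4. A minimal Python sketch of the resulting arithmetic:

# Minimal sketch of the P/R/F arithmetic used by PRF_Score.py (illustrative only).
def prf(correct, gold_total, test_total):
    precision = correct / test_total
    recall = correct / gold_total
    f_value = 2 * precision * recall / (precision + recall)
    return precision, recall, f_value

# gold "他 来到 北京" vs. test "他 来 到 北京" -> 2 exact matches
print(prf(2, 3, 4))  # (0.5, 0.666..., 0.571...)
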
/README.md:
--------------------------------------------------------------------------------
1 | # LSTM-(CNN)-CRF for CWS
2 | [Python](https://www.python.org/)
3 | [TensorFlow](https://www.tensorflow.org/)
4 | Bi-LSTM+CNN+CRF for Chinese word segmentation.
5 | The **new version** has arrived; the old version is still available on another branch.
6 |
7 | ## Usage
8 | ***What's new?***
9 | * The new system is organized **more cleanly**;
10 | * The CNN model has been tweaked;
11 | * The limitation on maximum sentence length has been removed, although you can still set one;
12 | * Gradient clipping has been added;
13 | * Pre-training is optional (you decide whether to use the pretrained embeddings), though I didn't see a non-trivial margin in my experiments;
14 | * The system can save the best model during training, scored by F-value.
15 | ### Command Step by Step
16 | * Preprocessing
17 | Used to generate training files from corpora such as [**People 2014**](http://www.all-terms.com/bbs/thread-7977-1-1.html) and [**icwb2-data**](http://sighan.cs.uchicago.edu/bakeoff2005/). See the source code or run *python preprocess.py -h* for more details.
18 |
19 | For example, for the *People* data, use the default arguments. (The input file is just *--all_corpora*; the others are output files.)
20 |
21 | For the icwb2-data such as PKU: (The input files are *--all_corpora* and *--gold_file*)
22 | *python3 preprocess.py --all_corpora /home/synrey/data/icwb2-data/training/pku_training.utf8 --vob_path /home/synrey/data/icwb2-data/data-pku/vocab.txt --char_file /home/synrey/data/icwb2-data/data-pku/chars.txt --train_file_pre /home/synrey/data/icwb2-data/data-pku/train --eval_file_pre /home/synrey/data/icwb2-data/data-pku/eval --gold_file /home/synrey/data/icwb2-data/gold/pku_test_gold.utf8 --is_people False --word_freq 2*
23 |
24 | * Pretraining
25 | You may need to run the script third_party/compile_w2v.sh to compile word2vec.c first.
26 | For the PKU corpus:
27 | *./third_party/word2vec -train /home/synrey/data/icwb2-data/data-pku/chars.txt -output /home/synrey/data/icwb2-data/data-pku/char_vec.txt -size 100 -sample 1e-4 -negative 0 -hs 1 -min-count 2*
28 |
29 | For the People corpus:
30 | *./third_party/word2vec -train /home/synrey/data/cws-v2-data/chars.txt -output /home/synrey/data/cws-v2-data/char_vec.txt -size 100 -sample 1e-4 -negative 0 -hs 1 -min-count 3*
31 |
32 | * Training
33 | For example:
34 |
35 | *python3 -m sycws.sycws --train_prefix /home/synrey/data/cws-v2-data/train --eval_prefix /home/synrey/data/cws-v2-data/eval --vocab_file /home/synrey/data/cws-v2-data/vocab.txt --out_dir /home/synrey/data/cws-v2-data/model --model CNN-CRF*
36 |
37 | If you want to use the pretrained embeddings, set the argument **--embed_file** to the path of your embeddings, such as *--embed_file /home/synrey/data/cws-v2-data/char_vec.txt*
38 |
39 | See the source code for more argument configurations. It should perform well with the default parameters. Naturally, you may also try out other parameter settings.
40 |
41 | ## About the Models
42 | ### Bi-LSTM-SL-CRF
43 | See [Guillaume Lample, Miguel Ballesteros, Sandeep Subramanian, Kazuya Kawakami, and Chris Dyer. Neural Architectures for Named Entity Recognition. In Proc. NAACL-HLT, 2016.](http://www.aclweb.org/anthology/N16-1030)
44 | Actually, there is a *single layer* (SL) between BiLSTM and CRF.
45 |
46 | ### Bi-LSTM-CNN-CRF
47 | See [Here](http://htmlpreview.github.io/?https://github.com/MeteorYee/LSTM-CNN-CWS/blob/master/Extra/Bi-LSTM_CNN.html).
48 | Namely, the single layer between BiLSTM and CRF is replaced by a CNN layer.
49 |
50 | ### Comparison
51 | Experiments on corpus [**People 2014**](http://www.all-terms.com/bbs/thread-7977-1-1.html).
52 |
53 | | Models | Bi-LSTM-SL-CRF | Bi-LSTM-CNN-CRF |
54 | | :-----------: | :--------------: | :---------------: |
55 | | Precision | 96.25% | 96.30% |
56 | | Recall | 95.34% | 95.70% |
57 | | F-value | 95.79% | **96.00%** |
58 |
59 | ## Segmentation
60 | * Inference
61 | For example, to use the **BiLSTM-CNN-CRF** model for decoding:
62 |
63 | *python3 -m sycws.sycws --vocab_file /home/synrey/data/cws-v2-data/vocab.txt --out_dir /home/synrey/data/cws-v2-data/model/best_Fvalue --inference_input_file /home/synrey/data/cws-v2-data/test.txt --inference_output_file /home/synrey/data/cws-v2-data/result.txt*
64 |
65 | Set *--model CRF* to use model **BiLSTM-SL-CRF** for inference.
66 | Note: even if you use pretrained embeddings, the inference command is still the same.
67 |
68 | * PRF Scoring
69 |
70 | *python3 PRF_Score.py [result_file] [gold_file]*
71 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/22/2018
6 | #
7 | # Description: preprocessing for people 2014 corpora
8 | #
9 | # Last Modified at: 05/20/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import print_function
30 |
31 | from collections import defaultdict
32 |
33 | import codecs
34 | import argparse
35 | import os
36 |
37 |
38 | NE_LEFT = u'['
39 | NE_RIGHT = u']'
40 | DIVIDER = u'/'
41 | SPACE = u' '
42 | UNK = u"unk"
43 |
44 | WORD_S = u'0'
45 | WORD_B = u'1'
46 | WORD_M = u'2'
47 | WORD_E = u'3'
48 |
49 |
50 | def clean_sentence(line_list):
51 | new_line_list = []
52 | for token in line_list:
53 | div_id = token.rfind(DIVIDER)
54 |
55 | if div_id < 1:
56 | # div_id shouldn't be lower than 1;
57 | # if it is, give up the word
58 | continue
59 |
60 | word = token[ : div_id]
61 | tag = token[(div_id + 1) : ]
62 |
63 | if word[0] == NE_LEFT and len(word) > 1:
64 | new_line_list.append(word[1 : ])
65 |
66 | elif word[-1] == NE_RIGHT and len(word) > 1:
67 | div_id = word.rfind(DIVIDER)
68 | if div_id < 1:
69 | new_line_list.append(word[ : (len(word)-1)])
70 | else:
71 | new_line_list.append(word[ : div_id])
72 |
73 | else:
74 | new_line_list.append(word)
75 |
76 | return new_line_list
77 |
78 |
79 | def write_line(line_list, outstream, sep = SPACE):
80 | line = sep.join(line_list)
81 | outstream.write(line + u'\n')
82 |
83 |
84 | def analyze_line(line_list, vob_dict):
85 | char_list = []
86 | label_list = []
87 |
88 | for word in line_list:
89 | length = len(word)
90 | if length == 1:
91 | char_list.append(word)
92 | label_list.append(WORD_S)
93 | vob_dict[word] += 1
94 |
95 | else:
96 | for pos, char in enumerate(word):
97 | if pos == 0:
98 | label_list.append(WORD_B)
99 | elif pos == (length - 1):
100 | label_list.append(WORD_E)
101 | else:
102 | label_list.append(WORD_M)
103 |
104 | char_list.append(char)
105 | vob_dict[char] += 1
106 |
107 | assert len(char_list) == len(label_list)
108 | return char_list, label_list
109 |
110 |
111 | def generate_files(corpora, vob_path, char_file, train_word_file,
112 | train_label_file, eval_word_file, eval_label_file, eval_gold_file,
113 | test_file, gold_file, step, freq, max_len):
114 |
115 | inp = codecs.open(corpora, 'r', "utf-8")
116 |
117 | tr_wd_wr = codecs.open(train_word_file, 'w', "utf-8")
118 | tr_lb_wr = codecs.open(train_label_file, 'w', "utf-8")
119 | ev_wd_wr = codecs.open(eval_word_file, 'w', "utf-8")
120 | ev_lb_wr = codecs.open(eval_label_file, 'w', "utf-8")
121 |
122 | ev_gold_wr = codecs.open(eval_gold_file, 'w', "utf-8")
123 | test_wr = codecs.open(test_file, 'w', "utf-8")
124 | gold_wr = codecs.open(gold_file, 'w', "utf-8")
125 |
126 | dump_cnt = 0
127 | vob_dict = defaultdict(int)
128 | isEval = True
129 |
130 | with inp, tr_wd_wr, tr_lb_wr, ev_wd_wr, ev_lb_wr, ev_gold_wr, test_wr, gold_wr:
131 | for ind, line in enumerate(inp):
132 | line_list = line.strip().split()
133 | if len(line_list) > max_len:
134 | dump_cnt += 1
135 | continue
136 |
137 | cleaned_line = clean_sentence(line_list)
138 | if not cleaned_line:
139 | dump_cnt += 1
140 | continue
141 |
142 | char_list, label_list = analyze_line(cleaned_line, vob_dict)
143 |
144 | if ind % step == 0:
145 | if isEval:
146 | write_line(char_list, ev_wd_wr)
147 | write_line(label_list, ev_lb_wr)
148 | write_line(cleaned_line, ev_gold_wr)
149 | isEval = False
150 | else:
151 | write_line(cleaned_line, test_wr, sep = u'')
152 | write_line(cleaned_line, gold_wr)
153 | isEval = True
154 | else:
155 | write_line(char_list, tr_wd_wr)
156 | write_line(label_list, tr_lb_wr)
157 |
158 | inp = codecs.open(corpora, 'r', "utf-8")
159 | ch_wr = codecs.open(char_file, 'w', "utf-8")
160 | with inp, ch_wr:
161 | for line in inp:
162 | line_list = line.strip().split()
163 | if len(line_list) > max_len:
164 | continue
165 |
166 | cleaned_line = clean_sentence(line_list)
167 | if not cleaned_line:
168 | continue
169 | char_list = []
170 | for phr in cleaned_line:
171 | for ch in phr:
172 | if vob_dict[ch] < freq:
173 | char_list.append(UNK)
174 | else:
175 | char_list.append(ch)
176 |
177 | write_line(char_list, ch_wr)
178 |
179 | word_cnt = 0
180 | with codecs.open(vob_path, 'w', "utf-8") as vob_wr:
181 | vob_wr.write(UNK + u'\n')
182 | for word, fq in vob_dict.items():
183 | if fq >= freq:
184 | vob_wr.write(word + u'\n')
185 | word_cnt += 1
186 |
187 | print("Finished, gave up %d sentences." % dump_cnt)
188 | print("Select %d chars from the original %d chars" % (word_cnt, len(vob_dict)))
189 |
190 |
191 | # used for people corpora
192 | def people_main(args):
193 | corpora = args.all_corpora
194 | assert os.path.exists(corpora)
195 |
196 | total_line = 0
197 | # count the total number of lines
198 | with open(corpora, 'rb') as inp:
199 | for line in inp:
200 | total_line += 1
201 |
202 | base = 2 * args.line_cnt
203 | assert base < total_line
204 | step = total_line // base
205 |
206 | train_word_file = args.train_file_pre + ".txt"
207 | train_label_file = args.train_file_pre + ".lb"
208 | eval_word_file = args.eval_file_pre + ".txt"
209 | eval_label_file = args.eval_file_pre + ".lb"
210 |
211 | generate_files(corpora, args.vob_path, args.char_file,
212 | train_word_file, train_label_file, eval_word_file,
213 | eval_label_file, args.eval_gold_file, args.test_file,
214 | args.gold_file, step, args.word_freq, args.max_len)
215 |
216 |
217 | def analyze_write(inp, word_writer, label_writer,
218 | vob_dict = defaultdict(int)):
219 | with inp, word_writer, label_writer:
220 | for line in inp:
221 | line_list = line.strip().split()
222 | if len(line_list) < 1:
223 | continue
224 |
225 | char_list, label_list = analyze_line(line_list, vob_dict)
226 | write_line(char_list, word_writer)
227 | write_line(label_list, label_writer)
228 |
229 |
230 | # used for icwb2 data
231 | def icwb_main(args):
232 | corpora = args.all_corpora
233 | assert os.path.exists(corpora)
234 | gold_file = args.gold_file
235 | assert os.path.exists(gold_file)
236 | freq = args.word_freq
237 |
238 | train_word_file = args.train_file_pre + ".txt"
239 | train_label_file = args.train_file_pre + ".lb"
240 | eval_word_file = args.eval_file_pre + ".txt"
241 | eval_label_file = args.eval_file_pre + ".lb"
242 |
243 | train_inp = codecs.open(corpora, 'r', "utf-8")
244 | gold_inp = codecs.open(gold_file, 'r', "utf-8")
245 |
246 | ch_wr = codecs.open(args.char_file, 'w', "utf-8")
247 | tr_wd_wr = codecs.open(train_word_file, 'w', "utf-8")
248 | tr_lb_wr = codecs.open(train_label_file, 'w', "utf-8")
249 | ev_wd_wr = codecs.open(eval_word_file, 'w', "utf-8")
250 | ev_lb_wr = codecs.open(eval_label_file, 'w', "utf-8")
251 |
252 | vob_dict = defaultdict(int)
253 | analyze_write(train_inp, tr_wd_wr, tr_lb_wr, vob_dict)
254 | analyze_write(gold_inp, ev_wd_wr, ev_lb_wr)
255 |
256 | train_inp = codecs.open(corpora, 'r', "utf-8")
257 | with train_inp, ch_wr:
258 | for line in train_inp:
259 | phrases = line.strip().split()
260 | char_list = []
261 | for phr in phrases:
262 | for ch in phr:
263 | if vob_dict[ch] < freq:
264 | char_list.append(UNK)
265 | else:
266 | char_list.append(ch)
267 |
268 | write_line(char_list, ch_wr)
269 |
270 | word_cnt = 0
271 | with codecs.open(args.vob_path, 'w', "utf-8") as vob_wr:
272 | vob_wr.write(UNK + u'\n')
273 | for word, fq in vob_dict.items():
274 | if fq >= freq:
275 | vob_wr.write(word + u'\n')
276 | word_cnt += 1
277 |
278 | print("Finished handling icwb2 data.")
279 | print("Select %d chars from the original %d chars" % (word_cnt, len(vob_dict)))
280 |
281 |
282 | if __name__ == '__main__':
283 | parser = argparse.ArgumentParser()
284 | parser.register("type", "bool", lambda v: v.lower() == "true")
285 |
286 | # input
287 | parser.add_argument(
288 | "--all_corpora",
289 | type = str,
290 | default = "/home/synrey/data/people2014All.txt",
291 | help = "all the corpora")
292 |
293 | # output
294 | parser.add_argument(
295 | "--vob_path",
296 | type = str,
297 | default = "/home/synrey/data/cws-v2-data/vocab.txt",
298 | help = "vocabulary's path")
299 |
300 | parser.add_argument(
301 | "--char_file",
302 | type = str,
303 | default = "/home/synrey/data/cws-v2-data/chars.txt",
304 | help = "the file used for word2vec pretraining")
305 |
306 | parser.add_argument(
307 | "--train_file_pre",
308 | type = str,
309 | default = "/home/synrey/data/cws-v2-data/train",
310 | help = "training file's prefix")
311 |
312 | parser.add_argument(
313 | "--eval_file_pre",
314 | type = str,
315 | default = "/home/synrey/data/cws-v2-data/eval",
316 | help = "eval file's prefix")
317 |
318 | parser.add_argument(
319 | "--eval_gold_file",
320 | type = str,
321 | default = "/home/synrey/data/cws-v2-data/eval_gold.txt",
322 | help = """gold file, used for the evaluation during training, \
323 | only generated for the 'people' corpus""")
324 |
325 | parser.add_argument(
326 | "--test_file",
327 | type = str,
328 | default = "/home/synrey/data/cws-v2-data/test.txt",
329 | help = "test file, raw sentences")
330 |
331 | parser.add_argument(
332 | "--gold_file",
333 | type = str,
334 | default = "/home/synrey/data/cws-v2-data/gold.txt",
335 | help = "gold file, segmented sentences")
336 |
337 | # parameters
338 | parser.add_argument(
339 | "--word_freq",
340 | type = int,
341 | default = 3,
342 | help = "word frequency")
343 |
344 | parser.add_argument(
345 | "--line_cnt",
346 | type = int,
347 | default = 8000,
348 | help = "the number of lines in eval or test file")
349 |
350 | # NOTE: It is the max length of word sequence, not char.
351 | parser.add_argument(
352 | "--max_len",
353 | type = int,
354 | default = 120,
355 | help = "discard sentences longer than max_len (counted in words)")
356 |
357 | parser.add_argument(
358 | "--is_people",
359 | type = "bool",
360 | default = True,
361 | help = "Whether we are handling the People corpora")
362 |
363 | args = parser.parse_args()
364 | if args.is_people:
365 | people_main(args)
366 | else:
367 | icwb_main(args)
--------------------------------------------------------------------------------
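
To make the tag scheme emitted by preprocess.py above concrete, here is a minimal sketch (not part of the repo) that calls analyze_line() on a tiny hand-made example; it assumes the repository root is on PYTHONPATH:

# Minimal sketch: how analyze_line() maps a segmented sentence to character labels.
from collections import defaultdict
from preprocess import analyze_line  # assumption: repo root is importable

vob_dict = defaultdict(int)
chars, labels = analyze_line([u"我", u"喜欢", u"北京菜"], vob_dict)
print(chars)   # ['我', '喜', '欢', '北', '京', '菜']
print(labels)  # ['0', '1', '3', '1', '2', '3'], i.e. S, B-E, B-M-E
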
/sycws/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/sycws/__init__.py
--------------------------------------------------------------------------------
/sycws/data_iterator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/24/2018
6 | #
7 | # Description: The iterator of training and inference data
8 | #
9 | # Last Modified at: 04/03/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import print_function
30 |
31 | from collections import namedtuple
32 |
33 | import tensorflow as tf
34 |
35 | __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"]
36 |
37 |
38 | class BatchedInput(
39 | namedtuple("BatchedInput",
40 | ("initializer", "text", "label",
41 | "text_raw", "sequence_length"))):
42 | pass
43 |
44 |
45 | def get_infer_iterator(src_dataset,
46 | vocab_table,
47 | index_table,
48 | batch_size,
49 | max_len = None):
50 |
51 | src_dataset = src_dataset.map(lambda src : tf.string_split([src]).values)
52 |
53 | if max_len:
54 | src_dataset = src_dataset.map(lambda src : src[:max_len])
55 | # Convert the word strings to ids
56 | src_dataset = src_dataset.map(
57 | lambda src : (src, tf.cast(vocab_table.lookup(src), tf.int32)))
58 | # Add in the word counts.
59 | # Add the fake labels, refer to model_helper.py line 160.
60 | src_dataset = src_dataset.map(lambda src_raw, src : (src_raw, src,
61 | tf.cast(index_table.lookup(src_raw), tf.int32), tf.size(src)))
62 |
63 | def batching_func(x):
64 | return x.padded_batch(
65 | batch_size,
66 | # The first three entries are the raw source, source ids and
67 | # fake label ids; these are unknown-length vectors. The last entry is
68 | # the source row size; this is a scalar.
69 | padded_shapes=(
70 | tf.TensorShape([None]), # src_raw
71 | tf.TensorShape([None]), # src_ids
72 | tf.TensorShape([None]), # fake label ids
73 | tf.TensorShape([])), # src_len
74 | # Pad the source sequences (with the default padding value 0).
75 | # (Though notice we don't generally need to do this since
76 | # later on we will be masking out calculations past the true sequence.)
77 |
78 | # padding_values = 0, default value
79 | )
80 |
81 | batched_dataset = batching_func(src_dataset)
82 | batched_iter = batched_dataset.make_initializable_iterator()
83 | (src_raw, src_ids, lb_ids, src_seq_len) = batched_iter.get_next()
84 | return BatchedInput(
85 | initializer = batched_iter.initializer,
86 | text = src_ids,
87 | label = lb_ids,
88 | text_raw = src_raw,
89 | sequence_length = src_seq_len)
90 |
91 |
92 | def get_iterator(txt_dataset,
93 | lb_dataset,
94 | vocab_table,
95 | index_table,
96 | batch_size,
97 | num_buckets,
98 | max_len = None,
99 | output_buffer_size = None,
100 | num_parallel_calls = 4):
101 |
102 | if not output_buffer_size:
103 | output_buffer_size = batch_size * 1000
104 | txt_lb_dataset = tf.data.Dataset.zip((txt_dataset, lb_dataset))
105 | txt_lb_dataset = txt_lb_dataset.shuffle(output_buffer_size)
106 |
107 | txt_lb_dataset = txt_lb_dataset.map(
108 | lambda txt, lb : (
109 | tf.string_split([txt]).values, tf.string_split([lb]).values),
110 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size)
111 |
112 | # Filter zero length input sequences.
113 | txt_lb_dataset = txt_lb_dataset.filter(
114 | lambda txt, lb : tf.logical_and(tf.size(txt) > 0, tf.size(lb) > 0))
115 |
116 | if max_len:
117 | txt_lb_dataset = txt_lb_dataset.map(
118 | lambda txt, lb : (txt[ : max_len], lb[ : max_len]),
119 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size)
120 | # Convert the word strings to ids. Word strings that are not in the
121 | # vocab get the lookup table's default_value integer.
122 |
123 | txt_lb_dataset = txt_lb_dataset.map(
124 | lambda txt, lb : (
125 | tf.cast(vocab_table.lookup(txt), tf.int32),
126 | tf.cast(index_table.lookup(lb), tf.int32)),
127 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size)
128 |
129 | # Add in sequence lengths.
130 | txt_lb_dataset = txt_lb_dataset.map(
131 | lambda txt, lb : (
132 | txt, lb, tf.size(txt)),
133 | num_parallel_calls = num_parallel_calls).prefetch(output_buffer_size)
134 |
135 | # Bucket by sequence length (buckets for lengths 0-9, 10-19, ...)
136 | def batching_func(x):
137 | return x.padded_batch(
138 | batch_size,
139 | # The first two entries are the text and label line rows;
140 | # these have unknown-length vectors. The last entry is
141 | # the row sizes; these are scalars.
142 | padded_shapes = (
143 | tf.TensorShape([None]), # txt
144 | tf.TensorShape([None]), # lb
145 | tf.TensorShape([])), # length
146 | # Pad the text and label sequences (with the default padding value 0).
147 | # (Though notice we don't generally need to do this since
148 | # later on we will be masking out calculations past the true sequence.)
149 |
150 | # padding_values = (0, 0, 0) # default values,
151 | # 0 for length--unused though
152 | )
153 |
154 | if num_buckets > 1:
155 |
156 | def key_func(unused_1, unused_2, seq_len):
157 | # Calculate bucket_width by maximum text sequence length.
158 | # Pairs with length [0, bucket_width) go to bucket 0, length
159 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length
160 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
161 | if max_len:
162 | bucket_width = (max_len + num_buckets - 1) // num_buckets
163 | else:
164 | bucket_width = 25
165 |
166 | # Bucket sentence pairs by the length of their text sentence and label
167 | # sentence.
168 | bucket_id = seq_len // bucket_width
169 | return tf.to_int64(tf.minimum(num_buckets, bucket_id))
170 |
171 | def reduce_func(unused_key, windowed_data):
172 | return batching_func(windowed_data)
173 |
174 | batched_dataset = txt_lb_dataset.apply(
175 | tf.contrib.data.group_by_window(
176 | key_func = key_func, reduce_func = reduce_func, window_size = batch_size))
177 |
178 | # With group_by_window, each window holds batch_size elements from a single
179 | # bucket and is reduced into exactly one batch by batching_func.
180 |
181 | else:
182 | batched_dataset = batching_func(txt_lb_dataset)
183 | batched_iter = batched_dataset.make_initializable_iterator()
184 | (txt_ids, lb_ids, seq_len) = batched_iter.get_next()
185 | return BatchedInput(
186 | initializer = batched_iter.initializer,
187 | text = txt_ids,
188 | label = lb_ids,
189 | text_raw = None,
190 | sequence_length = seq_len)
--------------------------------------------------------------------------------
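
To make the bucketing in get_iterator() above concrete, here is a small pure-Python sketch of the key_func arithmetic with hypothetical numbers (max_len = 120, num_buckets = 5):

# Minimal sketch of the bucket assignment performed by key_func() (illustrative only).
def bucket_id(seq_len, num_buckets, max_len=None):
    width = (max_len + num_buckets - 1) // num_buckets if max_len else 25
    return min(num_buckets, seq_len // width)

print(bucket_id(30, num_buckets=5, max_len=120))   # 1  (width = 24)
print(bucket_id(300, num_buckets=5, max_len=120))  # 5  (capped at num_buckets)
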
/sycws/indices.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 |
--------------------------------------------------------------------------------
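
For reference, these four indices are the character tags used throughout the project (WORD_S/B/M/E in preprocess.py, TAG_S/B/M/E in sycws/main_body.py); a one-line sketch of the implied mapping:

# Index-to-tag mapping implied by preprocess.py and main_body.py (for reference only).
TAGS = {0: "S", 1: "B", 2: "M", 3: "E"}  # Single-char word, Begin, Middle, End
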
/sycws/main_body.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/24/2018
6 | #
7 | # Description: Training, evaluation and inference
8 | #
9 | # Last Modified at: 05/21/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import absolute_import
30 | from __future__ import print_function
31 | from __future__ import division
32 |
33 | from . import model_helper
34 | from . import prf_script
35 |
36 | import tensorflow as tf
37 | import numpy as np
38 |
39 | import time
40 | import os
41 | import codecs
42 |
43 |
44 | TAG_S = 0
45 | TAG_B = 1
46 | TAG_M = 2
47 | TAG_E = 3
48 |
49 |
50 | def train(hparams, model_creator):
51 | out_dir = hparams.out_dir
52 | num_train_steps = hparams.num_train_steps
53 | steps_per_stats = hparams.steps_per_stats
54 | steps_per_external_eval = hparams.steps_per_external_eval
55 | if not steps_per_external_eval:
56 | steps_per_external_eval = 10 * steps_per_stats
57 |
58 | train_model = model_helper.create_train_model(hparams, model_creator)
59 | eval_model = model_helper.create_eval_model(hparams, model_creator)
60 | infer_model = model_helper.create_infer_model(hparams, model_creator)
61 |
62 | eval_txt_file = "%s.%s" % (hparams.eval_prefix, "txt")
63 | eval_lb_file = "%s.%s" % (hparams.eval_prefix, "lb")
64 | eval_iterator_feed_dict = {
65 | eval_model.txt_file_placeholder : eval_txt_file,
66 | eval_model.lb_file_placeholder : eval_lb_file
67 | }
68 |
69 | model_dir = hparams.out_dir
70 |
71 | # TensorFlow model
72 | train_sess = tf.Session(graph = train_model.graph)
73 | eval_sess = tf.Session(graph = eval_model.graph)
74 | infer_sess = tf.Session(graph = infer_model.graph)
75 |
76 | # Read the infer data
77 | with codecs.getreader("utf-8")(
78 | tf.gfile.GFile(eval_txt_file, mode="rb")) as f:
79 | infer_data = f.read().splitlines()
80 |
81 | with train_model.graph.as_default():
82 | loaded_train_model, global_step = model_helper.create_or_load_model(
83 | train_model.model, model_dir, train_sess, "train", init = True)
84 |
85 | print("First evaluation:")
86 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model,
87 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model,
88 | infer_sess, infer_data, global_step, init = True)
89 |
90 | print("# Initialize train iterator...")
91 | train_sess.run(train_model.iterator.initializer)
92 |
93 | process_time = 0.0
94 | while global_step < num_train_steps:
95 | # train a batch
96 | start_time = time.time()
97 | try:
98 | step_result = loaded_train_model.train(train_sess)
99 | process_time += time.time() - start_time
100 | except tf.errors.OutOfRangeError:
101 | # finish one epoch
102 | print(
103 | "# Finished an epoch, step %d. Perform evaluation" %
104 | global_step)
105 |
106 | # Save checkpoint
107 | loaded_train_model.saver.save(
108 | train_sess,
109 | os.path.join(out_dir, "segmentation.ckpt"),
110 | global_step = global_step)
111 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model,
112 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model,
113 | infer_sess, infer_data, global_step, init = False)
114 |
115 | train_sess.run(train_model.iterator.initializer)
116 | continue
117 |
118 | _, train_loss, global_step, batch_size = step_result
119 | if global_step % steps_per_stats == 0:
120 | avg_time = process_time / steps_per_stats
121 | # print loss info
122 | print("[%d][loss]: %f, time per step: %.2fs" % (global_step,
123 | train_loss, avg_time))
124 | process_time = 0.0
125 |
126 | if global_step % steps_per_external_eval == 0:
127 | # Save checkpoint
128 | loaded_train_model.saver.save(
129 | train_sess,
130 | os.path.join(out_dir, "segmentation.ckpt"),
131 | global_step = global_step)
132 |
133 | print("External Evaluation:")
134 | _run_full_eval(hparams, loaded_train_model, train_sess, eval_model,
135 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model,
136 | infer_sess, infer_data, global_step, init = False)
137 |
138 |
139 | def evaluation(eval_model, model_dir, eval_sess,
140 | eval_iterator_feed_dict, init = True):
141 | with eval_model.graph.as_default():
142 | loaded_eval_model, global_step = model_helper.create_or_load_model(
143 | eval_model.model, model_dir, eval_sess, "eval", init)
144 |
145 | eval_sess.run(eval_model.iterator.initializer,
146 | feed_dict = eval_iterator_feed_dict)
147 |
148 | total_char_cnt = 0
149 | total_right_cnt = 0
150 | total_line = 0
151 | while True:
152 | try:
153 | (batch_char_cnt, batch_right_cnt, batch_size,
154 | batch_lens) = loaded_eval_model.eval(eval_sess)
155 |
156 | for right_cnt, length in zip(batch_right_cnt, batch_lens):
157 | total_right_cnt += np.sum(right_cnt[ : length])
158 |
159 | total_char_cnt += batch_char_cnt
160 | total_line += batch_size
161 | except tf.errors.OutOfRangeError:
162 | # finish the evaluation
163 | break
164 |
165 | precision = total_right_cnt / total_char_cnt
166 | print("Tagging precision: %.3f, of total %d lines" % (precision, total_line))
167 |
168 |
169 | def _eval_inference(infer_model, infer_sess, infer_data, model_dir, hparams, init):
170 | with infer_model.graph.as_default():
171 | loaded_infer_model, global_step = model_helper.create_or_load_model(
172 | infer_model.model, model_dir, infer_sess, "infer", init)
173 |
174 | infer_sess.run(
175 | infer_model.iterator.initializer,
176 | feed_dict = {
177 | infer_model.txt_placeholder: infer_data,
178 | infer_model.batch_size_placeholder: hparams.infer_batch_size
179 | })
180 |
181 | test_list = []
182 | while True:
183 | try:
184 | text_raw, decoded_tags, seq_lens = loaded_infer_model.infer(infer_sess)
185 | except tf.errors.OutOfRangeError:
186 | # finish the evaluation
187 | break
188 | _decode_by_function(lambda x : test_list.append(x), text_raw, decoded_tags, seq_lens)
189 |
190 | gold_file = hparams.eval_gold_file
191 | score = prf_script.get_prf_score(test_list, gold_file)
192 | return score
193 |
194 |
195 | def _run_full_eval(hparams, loaded_train_model, train_sess, eval_model,
196 | model_dir, eval_sess, eval_iterator_feed_dict, infer_model,
197 | infer_sess, infer_data, global_step, init):
198 |
199 | evaluation(eval_model, model_dir, eval_sess,
200 | eval_iterator_feed_dict)
201 | score = _eval_inference(infer_model, infer_sess, infer_data,
202 | model_dir, hparams, init)
203 | # save the best model
204 | if score > getattr(hparams, "best_Fvalue"):
205 | setattr(hparams, "best_Fvalue", score)
206 | loaded_train_model.saver.save(
207 | train_sess,
208 | os.path.join(
209 | getattr(hparams, "best_Fvalue_dir"), "segmentation.ckpt"),
210 | global_step = global_step)
211 |
212 |
213 | def load_data(inference_input_file):
214 | # Load inference data.
215 | inference_data = []
216 | with codecs.getreader("utf-8")(
217 | tf.gfile.GFile(inference_input_file, mode="rb")) as f:
218 | for line in f:
219 | line = line.strip()
220 | if line:
221 | inference_data.append(u' '.join(list(line)))
222 |
223 | return inference_data
224 |
225 |
226 | def _decode_by_function(writer_function, text_raw, decoded_tags, seq_lens):
227 | assert len(text_raw) == len(decoded_tags)
228 | assert len(seq_lens) == len(decoded_tags)
229 |
230 | for text_line, tags_line, length in zip(text_raw, decoded_tags, seq_lens):
231 | text_line = text_line[ : length]
232 | tags_line = tags_line[ : length]
233 | newline = u""
234 |
235 | for char, tag in zip(text_line, tags_line):
236 | char = char.decode("utf-8")
237 | if tag == TAG_S or tag == TAG_B:
238 | newline += u' ' + char
239 | else:
240 | newline += char
241 |
242 | newline = newline.strip()
243 | writer_function(newline + u'\n')
244 |
245 |
246 | def inference(ckpt, input_file, trans_file, hparams, model_creator):
247 | infer_model = model_helper.create_infer_model(hparams, model_creator)
248 |
249 | infer_sess = tf.Session(graph = infer_model.graph)
250 | with infer_model.graph.as_default():
251 | loaded_infer_model = model_helper.load_model(infer_model.model,
252 | ckpt, infer_sess, "infer", init = True)
253 |
254 | # Read data
255 | infer_data = load_data(input_file)
256 | infer_sess.run(
257 | infer_model.iterator.initializer,
258 | feed_dict = {
259 | infer_model.txt_placeholder: infer_data,
260 | infer_model.batch_size_placeholder: hparams.infer_batch_size
261 | })
262 |
263 | with codecs.getwriter("utf-8")(
264 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
265 | while True:
266 | try:
267 | text_raw, decoded_tags, seq_lens = loaded_infer_model.infer(infer_sess)
268 | except tf.errors.OutOfRangeError:
269 | # finish the inference
270 | break
271 | _decode_by_function(lambda x : trans_f.write(x), text_raw, decoded_tags, seq_lens)
--------------------------------------------------------------------------------
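
As a sanity check on _decode_by_function() above, here is a minimal pure-Python sketch (not part of the repo) of how a tag sequence is turned back into a space-separated sentence:

# Minimal sketch mirroring _decode_by_function(): rebuild a segmented sentence
# from characters and their S/B/M/E tags (illustrative only).
TAG_S, TAG_B = 0, 1

def tags_to_sentence(chars, tags):
    out = u""
    for char, tag in zip(chars, tags):
        # a new word starts on S or B; otherwise the char extends the current word
        out += (u' ' + char) if tag in (TAG_S, TAG_B) else char
    return out.strip()

print(tags_to_sentence(u"我喜欢北京菜", [0, 1, 3, 1, 2, 3]))  # 我 喜欢 北京菜
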
/sycws/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/24/2018
6 | #
7 | # Description: Segmentation models
8 | #
9 | # Last Modified at: 05/21/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import absolute_import
30 | from __future__ import print_function
31 |
32 | from . import model_helper
33 | from . import data_iterator
34 |
35 | import tensorflow as tf
36 |
37 |
38 | class BasicModel(object):
39 | """Bi-LSTM + single layer + CRF"""
40 |
41 | def __init__(self, hparams, mode, iterator, vocab_table):
42 | assert isinstance(iterator, data_iterator.BatchedInput)
43 |
44 | self.iterator = iterator
45 | self.mode = mode
46 | self.vocab_table = vocab_table
47 | self.vocab_size = hparams.vocab_size
48 | self.time_major = hparams.time_major
49 |
50 | # Initializer
51 | self.initializer = tf.truncated_normal_initializer
52 |
53 | # Embeddings
54 | self.init_embeddings(hparams)
55 | # Because usually the last batch may be less than the batch
56 | # size we set, we should get batch_size dynamically.
57 | self.batch_size = tf.size(self.iterator.sequence_length)
58 |
59 | ## Train graph
60 | loss = self.build_graph(hparams)
61 |
62 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
63 | self.train_loss = loss
64 |
65 | elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
66 | self.char_count = tf.reduce_sum(self.iterator.sequence_length)
67 | self.right_count = self._calculate_right()
68 |
69 | elif self.mode == tf.contrib.learn.ModeKeys.INFER:
70 | self.decode_tags = self._decode()
71 |
72 | self.global_step = tf.Variable(0, trainable = False)
73 | params = tf.trainable_variables()
74 |
75 | # Gradients and update operation for training the model.
76 | # Arrange for the embedding vars to appear at the beginning.
77 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
78 | self.learning_rate = tf.constant(hparams.learning_rate)
79 |
80 | # Optimizer: Adam
81 | opt = tf.train.AdamOptimizer(self.learning_rate)
82 |
83 | # Gradients
84 | gradients = tf.gradients(self.train_loss, params)
85 |
86 | clipped_grads, grad_norm = tf.clip_by_global_norm(
87 | gradients, hparams.max_gradient_norm)
88 |
89 | # self.grad_norm = grad_norm
90 | self.update = opt.apply_gradients(
91 | zip(clipped_grads, params), global_step = self.global_step)
92 |
93 | # Saver, saves 5 checkpoints by default
94 | self.saver = tf.train.Saver(
95 | tf.global_variables(), max_to_keep = 5)
96 |
97 | # Print trainable variables
98 | print("# Trainable variables")
99 | for param in params:
100 | print(" %s, %s" % (param.name, str(param.get_shape())))
101 |
102 |
103 | def init_embeddings(self, hparams, dtype = tf.float32):
104 | with tf.variable_scope("embeddings", dtype = dtype) as scope:
105 | if hparams.embed_file:
106 | self.char_embedding = model_helper.create_pretrained_emb_from_txt(
107 | hparams.vocab_file, hparams.embed_file)
108 | else:
109 | self.char_embedding = tf.get_variable(
110 | "char_embedding", [self.vocab_size, hparams.num_units], dtype,
111 | initializer = self.initializer(stddev = hparams.init_std))
112 |
113 |
114 | def build_graph(self, hparams):
115 | print("# creating %s graph ..." % self.mode)
116 | dtype = tf.float32
117 |
118 | with tf.variable_scope("model_body", dtype = dtype):
119 | # build Bi-LSTM
120 | encoder_outputs, encoder_state = self._encode_layer(hparams)
121 |
122 | # middle layer
123 | middle_outputs = self._middle_layer(encoder_outputs, hparams)
124 | self.middle_outputs = middle_outputs
125 |
126 | # Decoder layer
127 | xentropy = self._decode_layer(middle_outputs)
128 |
129 | # Loss
130 | if self.mode != tf.contrib.learn.ModeKeys.INFER:
131 | # get the regularization loss
132 | reg_loss = tf.losses.get_regularization_loss()
133 | loss = tf.reduce_mean(xentropy) + reg_loss
134 | else:
135 | loss = None
136 | return loss
137 |
138 |
139 | def _encode_layer(self, hparams, dtype = tf.float32):
140 | # Bi-LSTM
141 | iterator = self.iterator
142 |
143 | text = iterator.text
144 | # [batch_size, txt_ids]
145 | if self.time_major:
146 | text = tf.transpose(text)
147 | # [txt_ids, batch_size]
148 |
149 | with tf.variable_scope("encoder", dtype = dtype) as scope:
150 | # Look up embedding, encoder_emb_inp: [max_time, batch_size, num_units]
151 | encoder_emb_inp = tf.nn.embedding_lookup(
152 | self.char_embedding, text)
153 |
154 | encoder_outputs, encoder_state = (
155 | self._build_bidirectional_rnn(
156 | inputs = encoder_emb_inp,
157 | sequence_length = iterator.sequence_length,
158 | dtype = dtype,
159 | hparams = hparams))
160 |
161 | # encoder_outputs (time_major): [max_time, batch_size, 2*num_units]
162 | return encoder_outputs, encoder_state
163 |
164 |
165 | def _build_bidirectional_rnn(self, inputs, sequence_length,
166 | dtype, hparams):
167 |
168 | num_units = hparams.num_units
169 | mode = self.mode
170 | dropout = hparams.dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
171 |
172 | fw_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
173 | bw_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
174 |
175 | if dropout > 0.0:
176 | fw_cell = tf.contrib.rnn.DropoutWrapper(
177 | cell = fw_cell, input_keep_prob = (1.0 - dropout))
178 |
179 | bw_cell = tf.contrib.rnn.DropoutWrapper(
180 | cell = bw_cell, input_keep_prob = (1.0 - dropout))
181 |
182 | bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
183 | fw_cell,
184 | bw_cell,
185 | inputs,
186 | dtype = dtype,
187 | sequence_length = sequence_length,
188 | time_major = self.time_major)
189 |
190 | # concatenate the fw and bw outputs
191 | return tf.concat(bi_outputs, -1), bi_state
192 |
193 |
194 | def _single_layer(self, encoder_outputs, hparams, dtype):
195 | num_units = hparams.num_units
196 | num_tags = hparams.num_tags
197 |
198 | with tf.variable_scope('middle', dtype = dtype) as scope:
199 | hidden_W = tf.get_variable(
200 | shape = [num_units * 2, hparams.num_tags],
201 | initializer = self.initializer(stddev = hparams.init_std),
202 | name = "weights",
203 | regularizer = tf.contrib.layers.l2_regularizer(0.001))
204 |
205 | hidden_b = tf.Variable(tf.zeros([num_tags]), name = "bias")
206 |
207 | encoder_outputs = tf.reshape(encoder_outputs, [-1, (2 * num_units)])
208 | middle_outputs = tf.add(tf.matmul(encoder_outputs, hidden_W), hidden_b)
209 | if self.time_major:
210 | middle_outputs = tf.reshape(middle_outputs,
211 | [-1, self.batch_size, num_tags])
212 | middle_outputs = tf.transpose(middle_outputs, [1, 0, 2])
213 | else:
214 | middle_outputs = tf.reshape(middle_outputs,
215 | [self.batch_size, -1, num_tags])
216 | # [batch_size, max_time, num_tags]
217 | return middle_outputs
218 |
219 |
220 | def _cnn_layer(self, encoder_outputs, hparams, dtype):
221 | num_units = hparams.num_units
222 | # CNN
223 | with tf.variable_scope('middle', dtype = dtype) as scope:
224 | cfilter = tf.get_variable(
225 | "cfilter",
226 | shape = [1, 2, 2 * num_units, hparams.num_tags],
227 | regularizer = tf.contrib.layers.l2_regularizer(0.0001),
228 | initializer = self.initializer(stddev = hparams.filter_init_std),
229 | dtype = tf.float32)
230 |
231 | return model_helper.create_cnn_layer(encoder_outputs, self.time_major,
232 | self.batch_size, num_units, cfilter)
233 |
234 |
235 | def _middle_layer(self, encoder_outputs, hparams, dtype = tf.float32):
236 | # single layer
237 | return self._single_layer(encoder_outputs, hparams, dtype)
238 |
239 |
240 | def _decode_layer(self, middle_outputs, dtype = tf.float32):
241 | # CRF
242 | with tf.variable_scope('decoder', dtype = dtype) as scope:
243 | log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood(
244 | middle_outputs, self.iterator.label, self.iterator.sequence_length)
245 |
246 | self.trans_params = trans_params
247 | return -log_likelihood
248 |
249 |
250 | def _decode(self):
251 | decode_tags, _ = tf.contrib.crf.crf_decode(self.middle_outputs,
252 | self.trans_params, self.iterator.sequence_length)
253 | # [batch_size, max_time]
254 |
255 | return decode_tags
256 |
257 |
258 | def _calculate_right(self):
259 | decode_tags = self._decode()
260 | sign_tensor = tf.equal(decode_tags, self.iterator.label)
261 | right_count = tf.cast(sign_tensor, tf.int32)
262 |
263 | return right_count
264 |
265 |
266 | def train(self, sess):
267 | assert self.mode == tf.contrib.learn.ModeKeys.TRAIN
268 | return sess.run([self.update,
269 | self.train_loss,
270 | self.global_step,
271 | self.batch_size])
272 |
273 |
274 | def eval(self, sess):
275 | assert self.mode == tf.contrib.learn.ModeKeys.EVAL
276 | return sess.run([self.char_count,
277 | self.right_count,
278 | self.batch_size,
279 | self.iterator.sequence_length])
280 |
281 |
282 | def infer(self, sess):
283 | assert self.mode == tf.contrib.learn.ModeKeys.INFER
284 | return sess.run([self.iterator.text_raw,
285 | self.decode_tags,
286 | self.iterator.sequence_length])
287 |
288 |
289 | class CnnCrfModel(BasicModel):
290 | """Bi-LSTM + CNN + CRF"""
291 |
292 | def _middle_layer(self, encoder_outputs, hparams, dtype = tf.float32):
293 | # CNN
294 | return self._cnn_layer(encoder_outputs, hparams, dtype)
--------------------------------------------------------------------------------
/sycws/model_helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/24/2018
6 | #
7 | # Description: Help functions for the implementation of the models
8 | #
9 | # Last Modified at: 05/20/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import absolute_import
30 | from __future__ import print_function
31 |
32 | from collections import namedtuple
33 | from tensorflow.python.ops import lookup_ops
34 |
35 | from . import data_iterator
36 |
37 | import tensorflow as tf
38 | import numpy as np
39 | import time
40 | import codecs
41 |
42 | __all__ = [
43 | "create_train_model", "create_eval_model",
44 | "create_infer_model", "create_cnn_layer",
45 | "create_or_load_model", "load_model",
46 | "create_pretrained_emb_from_txt"
47 | ]
48 |
49 |
50 | UNK_ID = 0
51 |
52 |
53 | # subclass to purify the __dict__
54 | class TrainModel(
55 | namedtuple("TrainModel", ("graph", "model",
56 | "iterator"))):
57 | pass
58 |
59 |
60 | def create_train_model(hparams, model_creator):
61 | txt_file = "%s.%s" % (hparams.train_prefix, "txt")
62 | lb_file = "%s.%s" % (hparams.train_prefix, "lb")
63 | vocab_file = hparams.vocab_file
64 | index_file = hparams.index_file
65 |
66 | graph = tf.Graph()
67 |
68 | with graph.as_default(), tf.container("train"):
69 | vocab_table = lookup_ops.index_table_from_file(
70 | vocab_file, default_value = UNK_ID)
71 | # for the labels
72 | index_table = lookup_ops.index_table_from_file(
73 | index_file, default_value = 0)
74 |
75 | txt_dataset = tf.data.TextLineDataset(txt_file)
76 | lb_dataset = tf.data.TextLineDataset(lb_file)
77 |
78 | iterator = data_iterator.get_iterator(
79 | txt_dataset,
80 | lb_dataset,
81 | vocab_table,
82 | index_table,
83 | batch_size = hparams.batch_size,
84 | num_buckets = hparams.num_buckets,
85 | max_len = hparams.max_len)
86 |
87 | model = model_creator(
88 | hparams,
89 | iterator = iterator,
90 | mode = tf.contrib.learn.ModeKeys.TRAIN,
91 | vocab_table = vocab_table)
92 |
93 | return TrainModel(
94 | graph = graph,
95 | model = model,
96 | iterator = iterator)
97 |
98 |
99 | class EvalModel(
100 | namedtuple("EvalModel",
101 | ("graph", "model", "txt_file_placeholder",
102 | "lb_file_placeholder", "iterator"))):
103 | pass
104 |
105 |
106 | def create_eval_model(hparams, model_creator):
107 | vocab_file = hparams.vocab_file
108 | index_file = hparams.index_file
109 | graph = tf.Graph()
110 |
111 | with graph.as_default(), tf.container("eval"):
112 | vocab_table = lookup_ops.index_table_from_file(
113 | vocab_file, default_value = UNK_ID)
114 | # for the labels
115 | index_table = lookup_ops.index_table_from_file(
116 | index_file, default_value = 0)
117 |
118 | # the file's name
119 | txt_file_placeholder = tf.placeholder(shape = (), dtype = tf.string)
120 | lb_file_placeholder = tf.placeholder(shape = (), dtype = tf.string)
121 | txt_dataset = tf.data.TextLineDataset(txt_file_placeholder)
122 | lb_dataset = tf.data.TextLineDataset(lb_file_placeholder)
123 |
124 | iterator = data_iterator.get_iterator(
125 | txt_dataset,
126 | lb_dataset,
127 | vocab_table,
128 | index_table,
129 | batch_size = hparams.batch_size,
130 | num_buckets = hparams.num_buckets,
131 | max_len = hparams.max_len)
132 |
133 | model = model_creator(
134 | hparams,
135 | iterator = iterator,
136 | mode = tf.contrib.learn.ModeKeys.EVAL,
137 | vocab_table = vocab_table)
138 |
139 | return EvalModel(
140 | graph = graph,
141 | model = model,
142 | txt_file_placeholder = txt_file_placeholder,
143 | lb_file_placeholder = lb_file_placeholder,
144 | iterator = iterator)
145 |
146 |
147 | class InferModel(
148 | namedtuple("InferModel",
149 | ("graph", "model", "txt_placeholder",
150 | "batch_size_placeholder", "iterator"))):
151 | pass
152 |
153 |
154 | def create_infer_model(hparams, model_creator):
155 | """Create inference model."""
156 | graph = tf.Graph()
157 | vocab_file = hparams.vocab_file
158 |
159 | with graph.as_default(), tf.container("infer"):
160 | vocab_table = lookup_ops.index_table_from_file(
161 | vocab_file, default_value = UNK_ID)
162 | # for the labels
163 | '''
164 | Although the labels are meaningless during inference, this dummy table ensures
165 | they are not None when the model graph is built.
166 | (refer to model.BasicModel._decode_layer)
167 | '''
168 | mapping_strings = tf.constant(['0'])
169 | index_table = tf.contrib.lookup.index_table_from_tensor(
170 | mapping = mapping_strings, default_value = 0)
171 |
172 | txt_placeholder = tf.placeholder(shape=[None], dtype = tf.string)
173 | batch_size_placeholder = tf.placeholder(shape = [], dtype = tf.int64)
174 |
175 | txt_dataset = tf.data.Dataset.from_tensor_slices(
176 | txt_placeholder)
177 | iterator = data_iterator.get_infer_iterator(
178 | txt_dataset,
179 | vocab_table,
180 | index_table,
181 | batch_size = batch_size_placeholder)
182 |
183 | model = model_creator(
184 | hparams,
185 | iterator = iterator,
186 | mode = tf.contrib.learn.ModeKeys.INFER,
187 | vocab_table = vocab_table)
188 |
189 | return InferModel(
190 | graph = graph,
191 | model = model,
192 | txt_placeholder = txt_placeholder,
193 | batch_size_placeholder = batch_size_placeholder,
194 | iterator = iterator)
195 |
196 |
197 | def _load_vocab(vocab_file):
198 | vocab = []
199 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
200 | vocab_size = 0
201 | for word in f:
202 | vocab_size += 1
203 | vocab.append(word.strip())
204 | return vocab, vocab_size
205 |
206 |
207 | def _load_embed_txt(embed_file):
208 | """Load embed_file into a python dictionary.
209 |
210 | Note: the embed_file should be a GloVe-formatted txt file. Assuming
211 | embed_size=5, for example:
212 |
213 | the -0.071549 0.093459 0.023738 -0.090339 0.056123
214 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037
215 | and 0.20327 0.47348 0.050877 0.002103 0.060547
216 |
217 | Note: The first line stores the number of embeddings and the size of
218 | each embedding.
219 |
220 | Args:
221 | embed_file: file path to the embedding file.
222 | Returns:
223 | a dictionary that maps each word to its vector, and the embedding size.
224 | """
225 | emb_dict = dict()
226 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f:
227 | emb_num, emb_size = f.readline().strip().split()
228 | emb_num = int(emb_num)
229 | emb_size = int(emb_size)
230 | for line in f:
231 | tokens = line.strip().split(" ")
232 | word = tokens[0]
233 | vec = list(map(float, tokens[1:]))
234 | emb_dict[word] = vec
235 | assert emb_size == len(vec), "All embedding sizes should be the same."
236 | return emb_dict, emb_size
237 |
238 |
239 | def create_pretrained_emb_from_txt(vocab_file, embed_file, dtype = tf.float32):
240 | """Load pretrained embeddings from embed_file, and return an embedding matrix.
241 |
242 | Args:
243 | embed_file: Path to a GloVe-formatted embedding txt file.
244 | Note: we only need the embeddings whose corresponding words are in the
245 | vocab_file.
246 | """
247 | vocab, _ = _load_vocab(vocab_file)
248 |
249 | print('# Using pretrained embedding: %s.' % embed_file)
250 | emb_dict, emb_size = _load_embed_txt(embed_file)
251 |
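  | # Every token in the vocabulary must have a vector in the embedding file,
  | # otherwise the lookup below raises a KeyError.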
252 | emb_mat = np.array(
253 | [emb_dict[token] for token in vocab], dtype = dtype.as_numpy_dtype())
254 |
255 | # The commented code below creates untrainable embeddings, which means
256 | # the value of each embedding stays fixed after loading.
257 | # num_trainable_tokens = 1 # the unk token is trainable
258 | # emb_mat = tf.constant(emb_mat)
259 | # emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
260 | # with tf.variable_scope("pretrain_embeddings", dtype = dtype) as scope:
261 | # emb_mat_var = tf.get_variable(
262 | # "emb_mat_var", [num_trainable_tokens, emb_size])
263 | # return tf.concat([emb_mat_var, emb_mat_const], 0)
264 |
265 | # Here, instead, the pretrained embedding values are only used as initial values,
266 | # so the embeddings remain trainable and can still be updated.
267 | return tf.Variable(emb_mat, name = "char_embedding")
268 |
269 |
270 | def _char_convolution(inputs, cfilter):
271 | conv1 = tf.nn.conv2d(inputs, cfilter, [1, 1, 1, 1],
272 | padding = 'VALID')
273 | # inputs.shape = [batch_size, 1, 3, 2*num_units]
274 | # namely, in_height = 1, in_width = 3, in_channels = 2*num_units
275 | # filter.shape = [1, 2, 2*num_units, num_tags], strides = [1, 1, 1, 1]
276 | # conv1.shape = [batch_size, 1, 2, num_tags]
277 | conv1 = tf.nn.relu(conv1)
278 | pool1 = tf.nn.max_pool(conv1,
279 | ksize = [1, 1, 2, 1],
280 | strides = [1, 1, 1, 1],
281 | padding = 'VALID')
282 |
283 | # pool1.shape = [batch_size, 1, 1, num_tags]
284 | pool1 = tf.squeeze(pool1, [1, 2])
285 | # pool1.shape = [batch_size, num_tags]
286 | return pool1
287 |
288 |
289 | def create_cnn_layer(inputs, time_major, batch_size, num_units, cfilter):
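  | # Stack each position's BiLSTM output with its left and right neighbours
  | # (zero-padded at the sentence boundaries), so the convolution sees a
  | # 3-character window for every character.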
290 | if not time_major:
291 | # transpose to time-major: [max_time, batch_size, 2*num_units]
292 | inputs = tf.transpose(inputs, [1, 0, 2])
293 |
294 | inputs = tf.expand_dims(inputs, 2)
295 | # [max_time, batch_size, 1, 2*num_units]
296 |
297 | left = inputs[1 : ]
298 | right = inputs[ : -1]
299 | left = tf.pad(left, [[1, 0], [0, 0], [0, 0], [0, 0]], "CONSTANT")
300 | right = tf.pad(right, [[0, 1], [0, 0], [0, 0], [0, 0]], "CONSTANT")
301 |
302 | char_blocks = tf.concat([left, inputs, right], 3)
303 | # [max_time, batch_size, 1, 3*2*num_units]
304 | char_blocks = tf.reshape(char_blocks, [-1, batch_size, 3, (2 * num_units)])
305 | char_blocks = tf.expand_dims(char_blocks, 2)
306 | # [max_time, batch_size, 1, 3, 2*num_units]
307 |
308 | do_char_conv = lambda x : _char_convolution(x, cfilter)
309 | cnn_outputs = tf.map_fn(do_char_conv, char_blocks)
310 | # [max_time, batch_size, num_tags]
311 |
312 | return tf.transpose(cnn_outputs, [1, 0, 2])
313 |
314 |
315 | def load_model(model, ckpt, session, name, init):
316 | start_time = time.time()
317 | model.saver.restore(session, ckpt)
318 | if init:
319 | session.run(tf.tables_initializer())
320 | print(
321 | " loaded %s model parameters from %s, time %.2fs" %
322 | (name, ckpt, time.time() - start_time))
323 | return model
324 |
325 |
326 | def create_or_load_model(model, model_dir, session, name, init):
327 | """Create segmentation model and initialize or load parameters in session."""
328 | latest_ckpt = tf.train.latest_checkpoint(model_dir)
329 | if latest_ckpt:
330 | model = load_model(model, latest_ckpt, session, name, init)
331 | else:
332 | start_time = time.time()
333 | session.run(tf.global_variables_initializer())
334 | session.run(tf.tables_initializer())
335 | print(" created %s model with fresh parameters, time %.2fs" %
336 | (name, time.time() - start_time))
337 |
338 | global_step = model.global_step.eval(session = session)
339 | return model, global_step
--------------------------------------------------------------------------------
/sycws/prf_script.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 05/20/2018
6 | #
7 | # Description: The PRF scoring script used for evaluation
8 | #
9 | # Last Modified at: 05/20/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import print_function
30 | from __future__ import division
31 |
32 | import codecs
33 |
34 | def get_prf_score(test_list, gold_file):
35 | e = 0 # wrong words number
36 | c = 0 # correct words number
37 | N = 0 # gold words number
38 | TN = 0 # test words number
39 |
40 | assert type(test_list) == list
41 | test_raw = []
42 | for line in test_list:
43 | sent = line.strip().split()
44 | if sent:
45 | test_raw.append(sent)
46 |
47 | gold_raw = []
48 | with codecs.open(gold_file, 'r', "utf-8") as inpt2:
49 | for line in inpt2:
50 | sent = line.strip().split()
51 | if sent:
52 | gold_raw.append(sent)
53 | N += len(sent)
54 |
55 | for i, gold_sent in enumerate(gold_raw):
56 | test_sent = test_raw[i]
57 |
58 | ig = 0
59 | it = 0
60 | glen = len(gold_sent)
61 | tlen = len(test_sent)
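  | # Walk both segmentations in parallel. When the current gold and test words
  | # differ, advance whichever side covers fewer characters until the two sides
  | # are aligned again; only exact word matches increment the correct count c.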
62 | while True:
63 | if ig >= glen or it >= tlen:
64 | break
65 |
66 | gword = gold_sent[ig]
67 | tword = test_sent[it]
68 | if gword == tword:
69 | c += 1
70 | else:
71 | lg = len(gword)
72 | lt = len(tword)
73 | while lg != lt:
74 | try:
75 | if lg < lt:
76 | ig += 1
77 | gword = gold_sent[ig]
78 | lg += len(gword)
79 | else:
80 | it += 1
81 | tword = test_sent[it]
82 | lt += len(tword)
83 | except Exception as err:
84 | # pdb.set_trace()
85 | print ("Line: %d" % (i + 1))
86 | print ("\nIt is the user's responsibility to ensure that each sentence in the test", end = ' ')
87 | print ("output has the SAME LENGTH as its corresponding sentence in the gold file.\n")
88 | raise err
89 |
90 | ig += 1
91 | it += 1
92 |
93 | TN += len(test_sent)
94 |
95 | e = TN - c
96 | precision = c / TN
97 | recall = c / N
98 | F = 2 * precision * recall / (precision + recall)
99 | error_rate = e / N
100 |
101 | print ("Correct words: %d"%c)
102 | print ("Error words: %d"%e)
103 | print ("Gold words: %d\n"%N)
104 | print ("precision: %f"%precision)
105 | print ("recall: %f"%recall)
106 | print ("F-Value: %f"%F)
107 | print ("error_rate: %f"%error_rate)
108 |
109 | return F
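  |
  | # Example usage (hypothetical file names):
  | #   with codecs.open("test_out.txt", 'r', "utf-8") as f:
  | #       segmented_lines = f.readlines()
  | #   get_prf_score(segmented_lines, "gold_standard.txt")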
--------------------------------------------------------------------------------
/sycws/sycws.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 03/21/2018
6 | #
7 | # Description: A neural network tool for Chinese word segmentation
8 | #
9 | # Last Modified at: 05/21/2018, by: Synrey Yee
10 |
11 | '''
12 | ==========================================================================
13 | Copyright 2018 Xingyu Yi (Alias: Synrey Yee) All Rights Reserved.
14 |
15 | Licensed under the Apache License, Version 2.0 (the "License");
16 | you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software
22 | distributed under the License is distributed on an "AS IS" BASIS,
23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | See the License for the specific language governing permissions and
25 | limitations under the License.
26 | ==========================================================================
27 | '''
28 |
29 | from __future__ import absolute_import
30 | from __future__ import print_function
31 |
32 | from . import main_body
33 | from . import model as model_r
34 |
35 | import tensorflow as tf
36 |
37 | import argparse
38 | import sys
39 | import os
40 | import codecs
41 |
42 |
43 | # the parsed parameters
44 | FLAGS = None
45 | UNK = u"unk"
46 |
47 |
48 | # build the training parameters
49 | def add_arguments(parser):
50 | parser.register("type", "bool", lambda v: v.lower() == "true")
51 |
52 | # data
53 | parser.add_argument("--train_prefix", type = str, default = None,
54 | help = "Train prefix, expect files with .txt/.lb suffixes.")
55 | parser.add_argument("--eval_prefix", type = str, default = None,
56 | help = "Eval prefix, expect files with .txt/.lb suffixes.")
57 | parser.add_argument("--eval_gold_file", type = str, default = None,
58 | help = "Eval gold file.")
59 |
60 | parser.add_argument("--vocab_file", type = str, default = None,
61 | help = "Vocabulary file.")
62 | parser.add_argument("--embed_file", type = str, default = None, help="""\
63 | Pretrained embedding file. The expected file should be a GloVe-formatted txt file.\
64 | """)
65 | parser.add_argument("--index_file", type = str, default = "./sycws/indices.txt",
66 | help = "Indices file.")
67 | parser.add_argument("--out_dir", type = str, default = None,
68 | help = "Store log/model files.")
69 | parser.add_argument("--max_len", type = int, default = 150,
70 | help = "Max length of char sequences during training.")
71 |
72 | # hyperparameters
73 | parser.add_argument("--num_units", type = int, default = 100, help = "Network size.")
74 | parser.add_argument("--model", type = str, default = "CNN-CRF",
75 | help = "Two kinds of models: BiLSTM + (CRF | CNN-CRF)")
76 |
77 | parser.add_argument("--learning_rate", type = float, default = 0.001,
78 | help = "Learning rate. Adam: 0.001 | 0.0001")
79 |
80 | parser.add_argument("--num_train_steps",
81 | type = int, default = 45000, help = "Num steps to train.")
82 |
83 | parser.add_argument("--init_std", type = float, default = 0.05,
84 | help = "Std of the truncated normal initializer.")
85 | parser.add_argument("--filter_init_std", type = float, default = 0.035,
86 | help = "Std of the truncated normal initializer for the CNN filter.")
87 |
88 | parser.add_argument("--dropout", type = float, default = 0.3,
89 | help = "Dropout rate (not keep_prob)")
90 | parser.add_argument("--max_gradient_norm", type = float, default = 5.0,
91 | help = "Clip gradients to this norm.")
92 | parser.add_argument("--batch_size", type = int, default = 128, help = "Batch size.")
93 |
94 | parser.add_argument("--steps_per_stats", type = int, default = 100,
95 | help = "How often (in training steps) to print the training loss.")
96 | parser.add_argument("--num_buckets", type = int, default = 5,
97 | help = "Put data into similar-length buckets.")
98 | parser.add_argument("--steps_per_external_eval", type = int, default = None,
99 | help = """\
100 | How many training steps to do per external evaluation. Automatically set
101 | based on data if None.\
102 | """)
103 |
104 | parser.add_argument("--num_tags", type = int, default = 4, help = "BMES")
105 | parser.add_argument("--time_major", type = "bool", nargs = "?", const = True,
106 | default = True,
107 | help = "Whether to use time_major mode for dynamic RNN.")
108 |
109 | # Inference
110 | parser.add_argument("--ckpt", type = str, default = None,
111 | help = "Checkpoint file to load a model for inference.")
112 | parser.add_argument("--inference_input_file", type = str, default = None,
113 | help = "Set to the text to decode.")
114 | parser.add_argument("--infer_batch_size", type = int, default = 32,
115 | help = "Batch size for inference mode.")
116 | parser.add_argument("--inference_output_file", type = str, default = None,
117 | help = "Output file to store decoding results.")
118 |
119 | def create_hparams(flags):
120 | """Create training hparams."""
121 | return tf.contrib.training.HParams(
122 | # data
123 | train_prefix = flags.train_prefix,
124 | eval_prefix = flags.eval_prefix,
125 | eval_gold_file = flags.eval_gold_file,
126 | vocab_file = flags.vocab_file,
127 | embed_file = flags.embed_file,
128 |
129 | index_file = flags.index_file,
130 | out_dir = flags.out_dir,
131 | max_len = flags.max_len,
132 |
133 | # hparams
134 | num_units = flags.num_units,
135 | model = flags.model,
136 | learning_rate = flags.learning_rate,
137 | num_train_steps = flags.num_train_steps,
138 |
139 | init_std = flags.init_std,
140 | filter_init_std = flags.filter_init_std,
141 | dropout = flags.dropout,
142 | max_gradient_norm = flags.max_gradient_norm,
143 |
144 | batch_size = flags.batch_size,
145 | num_buckets = flags.num_buckets,
146 | steps_per_stats = flags.steps_per_stats,
147 | steps_per_external_eval = flags.steps_per_external_eval,
148 |
149 | num_tags = flags.num_tags,
150 | time_major = flags.time_major,
151 |
152 | # inference
153 | ckpt = flags.ckpt,
154 | inference_input_file = flags.inference_input_file,
155 | infer_batch_size = flags.infer_batch_size,
156 | inference_output_file = flags.inference_output_file,
157 | )
158 |
159 |
160 | def check_corpora(train_prefix, eval_prefix):
161 | train_txt = train_prefix + ".txt"
162 | train_lb = train_prefix + ".lb"
163 | eval_txt = eval_prefix + ".txt"
164 | eval_lb = eval_prefix + ".lb"
165 |
166 | def _inner_check(txt_reader, lb_reader):
167 | for txt_line, lb_line in zip(txt_reader, lb_reader):
168 | txt_length = len(txt_line.strip().split())
169 | lb_length = len(lb_line.strip().split())
170 | assert txt_length == lb_length
171 |
172 | train_txt_rd = codecs.open(train_txt, 'r', "utf-8")
173 | train_lb_rd = codecs.open(train_lb, 'r', "utf-8")
174 | eval_txt_rd = codecs.open(eval_txt, 'r', "utf-8")
175 | eval_lb_rd = codecs.open(eval_lb, 'r', "utf-8")
176 |
177 | with train_txt_rd, train_lb_rd:
178 | _inner_check(train_txt_rd, train_lb_rd)
179 |
180 | with eval_txt_rd, eval_lb_rd:
181 | _inner_check(eval_txt_rd, eval_lb_rd)
182 |
183 |
184 | def check_vocab(vocab_file):
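  | # Make sure the unk token sits at index 0, so that the lookup tables built with
  | # default_value = UNK_ID (0) in model_helper map OOV characters to it.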
185 | vocab = []
186 | with codecs.open(vocab_file, 'r', "utf-8") as vob_inp:
187 | for word in vob_inp:
188 | vocab.append(word.strip())
189 |
190 | if vocab[0] != UNK:
191 | vocab = [UNK] + vocab
192 | with codecs.open(vocab_file, 'w', "utf-8") as vob_opt:
193 | for word in vocab:
194 | vob_opt.write(word + u'\n')
195 |
196 | return len(vocab)
197 |
198 |
199 | def print_hparams(hparams):
200 | values = hparams.values()
201 | for key in sorted(values.keys()):
202 | print(" %s = %s" % (key, str(values[key])))
203 |
204 |
205 | def main(unused_argv):
206 | out_dir = FLAGS.out_dir
207 | if not tf.gfile.Exists(out_dir):
208 | tf.gfile.MakeDirs(out_dir)
209 |
210 | hparams = create_hparams(FLAGS)
211 | model = hparams.model.upper()
212 | if model == "CRF":
213 | model_creator = model_r.BasicModel
214 | elif model == "CNN-CRF":
215 | model_creator = model_r.CnnCrfModel
216 | else:
217 | raise ValueError("Unknown model %s" % model)
218 |
219 | assert tf.gfile.Exists(hparams.vocab_file)
220 | vocab_size = check_vocab(hparams.vocab_file)
221 | hparams.add_hparam("vocab_size", vocab_size)
222 |
223 | if FLAGS.inference_input_file:
224 | # Inference
225 | trans_file = FLAGS.inference_output_file
226 | ckpt = FLAGS.ckpt
227 | if not ckpt:
228 | ckpt = tf.train.latest_checkpoint(out_dir)
229 |
230 | main_body.inference(ckpt, FLAGS.inference_input_file,
231 | trans_file, hparams, model_creator)
232 | else:
233 | # Train
234 | check_corpora(FLAGS.train_prefix, FLAGS.eval_prefix)
235 |
236 | # used for evaluation
237 | hparams.add_hparam("best_Fvalue", 0) # larger is better
238 | best_metric_dir = os.path.join(hparams.out_dir, "best_Fvalue")
239 | hparams.add_hparam("best_Fvalue_dir", best_metric_dir)
240 | tf.gfile.MakeDirs(best_metric_dir)
241 |
242 | print_hparams(hparams)
243 | main_body.train(hparams, model_creator)
244 |
245 |
246 | if __name__ == '__main__':
247 | parser = argparse.ArgumentParser()
248 | add_arguments(parser)
249 | FLAGS, unparsed = parser.parse_known_args()
250 | tf.app.run(main = main, argv = [sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------
/third_party/compile_w2v.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Author: Synrey Yee
4 | #
5 | # Created at: 05/07/2017
6 | #
7 | # Description: compile word2vec.c
8 | #
9 | # Last Modified at: 05/07/2017, by: Synrey Yee
10 |
11 | gcc word2vec.c -o word2vec -lm -lpthread
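  |
  | # Example (hypothetical file names): train 100-dimensional character embeddings
  | # ./word2vec -train chars.txt -output char_vec.txt -size 100 -cbow 0 -iter 5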
--------------------------------------------------------------------------------
/third_party/word2vec:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MeteorYee/LSTM-CNN-CWS/836451513587f38d054eac6b0ff3d4e39a142ae6/third_party/word2vec
--------------------------------------------------------------------------------
/third_party/word2vec.c:
--------------------------------------------------------------------------------
1 | // Copyright 2013 Google Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include <stdio.h>
16 | #include <stdlib.h>
17 | #include <string.h>
18 | #include <math.h>
19 | #include <pthread.h>
20 |
21 | #define MAX_STRING 100
22 | #define EXP_TABLE_SIZE 1000
23 | #define MAX_EXP 6
24 | #define MAX_SENTENCE_LENGTH 1000
25 | #define MAX_CODE_LENGTH 40
26 |
27 | const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary
28 |
29 | typedef float real; // Precision of float numbers
30 |
31 | struct vocab_word {
32 | long long cn;
33 | int *point;
34 | char *word, *code, codelen;
35 | };
36 |
37 | char train_file[MAX_STRING], output_file[MAX_STRING];
38 | char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING];
39 | struct vocab_word *vocab;
40 | int binary = 0, cbow = 1, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1;
41 | int *vocab_hash;
42 | long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100;
43 | long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0;
44 | real alpha = 0.025, starting_alpha, sample = 1e-3;
45 | real *syn0, *syn1, *syn1neg, *expTable;
46 | clock_t start;
47 |
48 | int hs = 0, negative = 5;
49 | const int table_size = 1e8;
50 | int *table;
51 |
52 | void InitUnigramTable() {
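  | // Build the table used to draw negative samples: each word gets a share of
  | // the table proportional to its count raised to the 0.75 power.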
53 | int a, i;
54 | double train_words_pow = 0;
55 | double d1, power = 0.75;
56 | table = (int *)malloc(table_size * sizeof(int));
57 | for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power);
58 | i = 0;
59 | d1 = pow(vocab[i].cn, power) / train_words_pow;
60 | for (a = 0; a < table_size; a++) {
61 | table[a] = i;
62 | if (a / (double)table_size > d1) {
63 | i++;
64 | d1 += pow(vocab[i].cn, power) / train_words_pow;
65 | }
66 | if (i >= vocab_size) i = vocab_size - 1;
67 | }
68 | }
69 |
70 | // Reads a single word from a file, assuming space + tab + EOL to be word boundaries
71 | void ReadWord(char *word, FILE *fin) {
72 | int a = 0, ch;
73 | while (!feof(fin)) {
74 | ch = fgetc(fin);
75 | if (ch == 13) continue;
76 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
77 | if (a > 0) {
78 | if (ch == '\n') ungetc(ch, fin);
79 | break;
80 | }
81 | if (ch == '\n') {
82 | strcpy(word, (char *)"</s>");
83 | return;
84 | } else continue;
85 | }
86 | word[a] = ch;
87 | a++;
88 | if (a >= MAX_STRING - 1) a--; // Truncate too long words
89 | }
90 | word[a] = 0;
91 | }
92 |
93 | // Returns hash value of a word
94 | int GetWordHash(char *word) {
95 | unsigned long long a, hash = 0;
96 | for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a];
97 | hash = hash % vocab_hash_size;
98 | return hash;
99 | }
100 |
101 | // Returns position of a word in the vocabulary; if the word is not found, returns -1
102 | int SearchVocab(char *word) {
103 | unsigned int hash = GetWordHash(word);
104 | while (1) {
105 | if (vocab_hash[hash] == -1) return -1;
106 | if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash];
107 | hash = (hash + 1) % vocab_hash_size;
108 | }
109 | return -1;
110 | }
111 |
112 | // Reads a word and returns its index in the vocabulary
113 | int ReadWordIndex(FILE *fin) {
114 | char word[MAX_STRING];
115 | ReadWord(word, fin);
116 | if (feof(fin)) return -1;
117 | return SearchVocab(word);
118 | }
119 |
120 | // Adds a word to the vocabulary
121 | int AddWordToVocab(char *word) {
122 | unsigned int hash, length = strlen(word) + 1;
123 | if (length > MAX_STRING) length = MAX_STRING;
124 | vocab[vocab_size].word = (char *)calloc(length, sizeof(char));
125 | strcpy(vocab[vocab_size].word, word);
126 | vocab[vocab_size].cn = 0;
127 | vocab_size++;
128 | // Reallocate memory if needed
129 | if (vocab_size + 2 >= vocab_max_size) {
130 | vocab_max_size += 1000;
131 | vocab = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word));
132 | }
133 | hash = GetWordHash(word);
134 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
135 | vocab_hash[hash] = vocab_size - 1;
136 | return vocab_size - 1;
137 | }
138 |
139 | // Used later for sorting by word counts
140 | int VocabCompare(const void *a, const void *b) {
141 | return ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn;
142 | }
143 |
144 | // Sorts the vocabulary by frequency using word counts
145 | void SortVocab() {
146 | int a, size;
147 | unsigned int hash;
148 | // Sort the vocabulary and keep </s> at the first position
149 | qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare);
150 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
151 | size = vocab_size;
152 | train_words = 0;
153 | for (a = 0; a < size; a++) {
154 | // Words occurring less than min_count times will be discarded from the vocab
155 | if ((vocab[a].cn < min_count) && (a != 0)) {
156 | vocab_size--;
157 | free(vocab[a].word);
158 | } else {
159 | // Hash will be re-computed, as it is no longer valid after the sorting
160 | hash=GetWordHash(vocab[a].word);
161 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
162 | vocab_hash[hash] = a;
163 | train_words += vocab[a].cn;
164 | }
165 | }
166 | vocab = (struct vocab_word *)realloc(vocab, (vocab_size + 1) * sizeof(struct vocab_word));
167 | // Allocate memory for the binary tree construction
168 | for (a = 0; a < vocab_size; a++) {
169 | vocab[a].code = (char *)calloc(MAX_CODE_LENGTH, sizeof(char));
170 | vocab[a].point = (int *)calloc(MAX_CODE_LENGTH, sizeof(int));
171 | }
172 | }
173 |
174 | // Reduces the vocabulary by removing infrequent tokens
175 | void ReduceVocab() {
176 | int a, b = 0;
177 | unsigned int hash;
178 | for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) {
179 | vocab[b].cn = vocab[a].cn;
180 | vocab[b].word = vocab[a].word;
181 | b++;
182 | } else free(vocab[a].word);
183 | vocab_size = b;
184 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
185 | for (a = 0; a < vocab_size; a++) {
186 | // Hash will be re-computed, as it is no longer valid
187 | hash = GetWordHash(vocab[a].word);
188 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
189 | vocab_hash[hash] = a;
190 | }
191 | fflush(stdout);
192 | min_reduce++;
193 | }
194 |
195 | // Create binary Huffman tree using the word counts
196 | // Frequent words will have short unique binary codes
197 | void CreateBinaryTree() {
198 | long long a, b, i, min1i, min2i, pos1, pos2, point[MAX_CODE_LENGTH];
199 | char code[MAX_CODE_LENGTH];
200 | long long *count = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
201 | long long *binary = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
202 | long long *parent_node = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
203 | for (a = 0; a < vocab_size; a++) count[a] = vocab[a].cn;
204 | for (a = vocab_size; a < vocab_size * 2; a++) count[a] = 1e15;
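  | // Entries beyond vocab_size are the internal tree nodes; their counts start at
  | // a huge value so that the leaves with the smallest counts are merged first.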
205 | pos1 = vocab_size - 1;
206 | pos2 = vocab_size;
207 | // Following algorithm constructs the Huffman tree by adding one node at a time
208 | for (a = 0; a < vocab_size - 1; a++) {
209 | // First, find two smallest nodes 'min1, min2'
210 | if (pos1 >= 0) {
211 | if (count[pos1] < count[pos2]) {
212 | min1i = pos1;
213 | pos1--;
214 | } else {
215 | min1i = pos2;
216 | pos2++;
217 | }
218 | } else {
219 | min1i = pos2;
220 | pos2++;
221 | }
222 | if (pos1 >= 0) {
223 | if (count[pos1] < count[pos2]) {
224 | min2i = pos1;
225 | pos1--;
226 | } else {
227 | min2i = pos2;
228 | pos2++;
229 | }
230 | } else {
231 | min2i = pos2;
232 | pos2++;
233 | }
234 | count[vocab_size + a] = count[min1i] + count[min2i];
235 | parent_node[min1i] = vocab_size + a;
236 | parent_node[min2i] = vocab_size + a;
237 | binary[min2i] = 1;
238 | }
239 | // Now assign binary code to each vocabulary word
240 | for (a = 0; a < vocab_size; a++) {
241 | b = a;
242 | i = 0;
243 | while (1) {
244 | code[i] = binary[b];
245 | point[i] = b;
246 | i++;
247 | b = parent_node[b];
248 | if (b == vocab_size * 2 - 2) break;
249 | }
250 | vocab[a].codelen = i;
251 | vocab[a].point[0] = vocab_size - 2;
252 | for (b = 0; b < i; b++) {
253 | vocab[a].code[i - b - 1] = code[b];
254 | vocab[a].point[i - b] = point[b] - vocab_size;
255 | }
256 | }
257 | free(count);
258 | free(binary);
259 | free(parent_node);
260 | }
261 |
262 | void LearnVocabFromTrainFile() {
263 | char word[MAX_STRING];
264 | FILE *fin;
265 | long long a, i;
266 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
267 | fin = fopen(train_file, "rb");
268 | if (fin == NULL) {
269 | printf("ERROR: training data file not found!\n");
270 | exit(1);
271 | }
272 | vocab_size = 0;
273 | AddWordToVocab((char *)"</s>");
274 | while (1) {
275 | ReadWord(word, fin);
276 | if (feof(fin)) break;
277 | train_words++;
278 | if ((debug_mode > 1) && (train_words % 100000 == 0)) {
279 | printf("%lldK%c", train_words / 1000, 13);
280 | fflush(stdout);
281 | }
282 | i = SearchVocab(word);
283 | if (i == -1) {
284 | a = AddWordToVocab(word);
285 | vocab[a].cn = 1;
286 | } else vocab[i].cn++;
287 | if (vocab_size > vocab_hash_size * 0.7) ReduceVocab();
288 | }
289 | SortVocab();
290 | if (debug_mode > 0) {
291 | printf("Vocab size: %lld\n", vocab_size);
292 | printf("Words in train file: %lld\n", train_words);
293 | }
294 | file_size = ftell(fin);
295 | fclose(fin);
296 | }
297 |
298 | void SaveVocab() {
299 | long long i;
300 | FILE *fo = fopen(save_vocab_file, "wb");
301 | for (i = 0; i < vocab_size; i++) fprintf(fo, "%s %lld\n", vocab[i].word, vocab[i].cn);
302 | fclose(fo);
303 | }
304 |
305 | void ReadVocab() {
306 | long long a, i = 0;
307 | char c;
308 | char word[MAX_STRING];
309 | FILE *fin = fopen(read_vocab_file, "rb");
310 | if (fin == NULL) {
311 | printf("Vocabulary file not found\n");
312 | exit(1);
313 | }
314 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
315 | vocab_size = 0;
316 | while (1) {
317 | ReadWord(word, fin);
318 | if (feof(fin)) break;
319 | a = AddWordToVocab(word);
320 | fscanf(fin, "%lld%c", &vocab[a].cn, &c);
321 | i++;
322 | }
323 | SortVocab();
324 | if (debug_mode > 0) {
325 | printf("Vocab size: %lld\n", vocab_size);
326 | printf("Words in train file: %lld\n", train_words);
327 | }
328 | fin = fopen(train_file, "rb");
329 | if (fin == NULL) {
330 | printf("ERROR: training data file not found!\n");
331 | exit(1);
332 | }
333 | fseek(fin, 0, SEEK_END);
334 | file_size = ftell(fin);
335 | fclose(fin);
336 | }
337 |
338 | void InitNet() {
339 | long long a, b;
340 | unsigned long long next_random = 1;
341 | a = posix_memalign((void **)&syn0, 128, (long long)vocab_size * layer1_size * sizeof(real));
342 | if (syn0 == NULL) {printf("Memory allocation failed\n"); exit(1);}
343 | if (hs) {
344 | a = posix_memalign((void **)&syn1, 128, (long long)vocab_size * layer1_size * sizeof(real));
345 | if (syn1 == NULL) {printf("Memory allocation failed\n"); exit(1);}
346 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++)
347 | syn1[a * layer1_size + b] = 0;
348 | }
349 | if (negative>0) {
350 | a = posix_memalign((void **)&syn1neg, 128, (long long)vocab_size * layer1_size * sizeof(real));
351 | if (syn1neg == NULL) {printf("Memory allocation failed\n"); exit(1);}
352 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++)
353 | syn1neg[a * layer1_size + b] = 0;
354 | }
355 | for (a = 0; a < vocab_size; a++) for (b = 0; b < layer1_size; b++) {
356 | next_random = next_random * (unsigned long long)25214903917 + 11;
357 | syn0[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5) / layer1_size;
358 | }
359 | CreateBinaryTree();
360 | }
361 |
362 | void *TrainModelThread(void *id) {
363 | long long a, b, d, cw, word, last_word, sentence_length = 0, sentence_position = 0;
364 | long long word_count = 0, last_word_count = 0, sen[MAX_SENTENCE_LENGTH + 1];
365 | long long l1, l2, c, target, label, local_iter = iter;
366 | unsigned long long next_random = (long long)id;
367 | real f, g;
368 | clock_t now;
369 | real *neu1 = (real *)calloc(layer1_size, sizeof(real));
370 | real *neu1e = (real *)calloc(layer1_size, sizeof(real));
371 | FILE *fi = fopen(train_file, "rb");
372 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
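  | // Each thread seeks to its own byte offset, so the work is split across
  | // roughly equal portions of the training file.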
373 | while (1) {
374 | if (word_count - last_word_count > 10000) {
375 | word_count_actual += word_count - last_word_count;
376 | last_word_count = word_count;
377 | if ((debug_mode > 1)) {
378 | now=clock();
379 | printf("%cAlpha: %f Progress: %.2f%% Words/thread/sec: %.2fk ", 13, alpha,
380 | word_count_actual / (real)(iter * train_words + 1) * 100,
381 | word_count_actual / ((real)(now - start + 1) / (real)CLOCKS_PER_SEC * 1000));
382 | fflush(stdout);
383 | }
384 | alpha = starting_alpha * (1 - word_count_actual / (real)(iter * train_words + 1));
385 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001;
386 | }
387 | if (sentence_length == 0) {
388 | while (1) {
389 | word = ReadWordIndex(fi);
390 | if (feof(fi)) break;
391 | if (word == -1) continue;
392 | word_count++;
393 | if (word == 0) break;
394 | // The subsampling randomly discards frequent words while keeping the ranking same
395 | if (sample > 0) {
396 | real ran = (sqrt(vocab[word].cn / (sample * train_words)) + 1) * (sample * train_words) / vocab[word].cn;
397 | next_random = next_random * (unsigned long long)25214903917 + 11;
398 | if (ran < (next_random & 0xFFFF) / (real)65536) continue;
399 | }
400 | sen[sentence_length] = word;
401 | sentence_length++;
402 | if (sentence_length >= MAX_SENTENCE_LENGTH) break;
403 | }
404 | sentence_position = 0;
405 | }
406 | if (feof(fi) || (word_count > train_words / num_threads)) {
407 | word_count_actual += word_count - last_word_count;
408 | local_iter--;
409 | if (local_iter == 0) break;
410 | word_count = 0;
411 | last_word_count = 0;
412 | sentence_length = 0;
413 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
414 | continue;
415 | }
416 | word = sen[sentence_position];
417 | if (word == -1) continue;
418 | for (c = 0; c < layer1_size; c++) neu1[c] = 0;
419 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0;
420 | next_random = next_random * (unsigned long long)25214903917 + 11;
421 | b = next_random % window;
422 | if (cbow) { //train the cbow architecture
423 | // in -> hidden
424 | cw = 0;
425 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
426 | c = sentence_position - window + a;
427 | if (c < 0) continue;
428 | if (c >= sentence_length) continue;
429 | last_word = sen[c];
430 | if (last_word == -1) continue;
431 | for (c = 0; c < layer1_size; c++) neu1[c] += syn0[c + last_word * layer1_size];
432 | cw++;
433 | }
434 | if (cw) {
435 | for (c = 0; c < layer1_size; c++) neu1[c] /= cw;
436 | if (hs) for (d = 0; d < vocab[word].codelen; d++) {
437 | f = 0;
438 | l2 = vocab[word].point[d] * layer1_size;
439 | // Propagate hidden -> output
440 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1[c + l2];
441 | if (f <= -MAX_EXP) continue;
442 | else if (f >= MAX_EXP) continue;
443 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
444 | // 'g' is the gradient multiplied by the learning rate
445 | g = (1 - vocab[word].code[d] - f) * alpha;
446 | // Propagate errors output -> hidden
447 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2];
448 | // Learn weights hidden -> output
449 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * neu1[c];
450 | }
451 | // NEGATIVE SAMPLING
452 | if (negative > 0) for (d = 0; d < negative + 1; d++) {
453 | if (d == 0) {
454 | target = word;
455 | label = 1;
456 | } else {
457 | next_random = next_random * (unsigned long long)25214903917 + 11;
458 | target = table[(next_random >> 16) % table_size];
459 | if (target == 0) target = next_random % (vocab_size - 1) + 1;
460 | if (target == word) continue;
461 | label = 0;
462 | }
463 | l2 = target * layer1_size;
464 | f = 0;
465 | for (c = 0; c < layer1_size; c++) f += neu1[c] * syn1neg[c + l2];
466 | if (f > MAX_EXP) g = (label - 1) * alpha;
467 | else if (f < -MAX_EXP) g = (label - 0) * alpha;
468 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
469 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2];
470 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * neu1[c];
471 | }
472 | // hidden -> in
473 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
474 | c = sentence_position - window + a;
475 | if (c < 0) continue;
476 | if (c >= sentence_length) continue;
477 | last_word = sen[c];
478 | if (last_word == -1) continue;
479 | for (c = 0; c < layer1_size; c++) syn0[c + last_word * layer1_size] += neu1e[c];
480 | }
481 | }
482 | } else { //train skip-gram
483 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
484 | c = sentence_position - window + a;
485 | if (c < 0) continue;
486 | if (c >= sentence_length) continue;
487 | last_word = sen[c];
488 | if (last_word == -1) continue;
489 | l1 = last_word * layer1_size;
490 | for (c = 0; c < layer1_size; c++) neu1e[c] = 0;
491 | // HIERARCHICAL SOFTMAX
492 | if (hs) for (d = 0; d < vocab[word].codelen; d++) {
493 | f = 0;
494 | l2 = vocab[word].point[d] * layer1_size;
495 | // Propagate hidden -> output
496 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1[c + l2];
497 | if (f <= -MAX_EXP) continue;
498 | else if (f >= MAX_EXP) continue;
499 | else f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
500 | // 'g' is the gradient multiplied by the learning rate
501 | g = (1 - vocab[word].code[d] - f) * alpha;
502 | // Propagate errors output -> hidden
503 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1[c + l2];
504 | // Learn weights hidden -> output
505 | for (c = 0; c < layer1_size; c++) syn1[c + l2] += g * syn0[c + l1];
506 | }
507 | // NEGATIVE SAMPLING
508 | if (negative > 0) for (d = 0; d < negative + 1; d++) {
509 | if (d == 0) {
510 | target = word;
511 | label = 1;
512 | } else {
513 | next_random = next_random * (unsigned long long)25214903917 + 11;
514 | target = table[(next_random >> 16) % table_size];
515 | if (target == 0) target = next_random % (vocab_size - 1) + 1;
516 | if (target == word) continue;
517 | label = 0;
518 | }
519 | l2 = target * layer1_size;
520 | f = 0;
521 | for (c = 0; c < layer1_size; c++) f += syn0[c + l1] * syn1neg[c + l2];
522 | if (f > MAX_EXP) g = (label - 1) * alpha;
523 | else if (f < -MAX_EXP) g = (label - 0) * alpha;
524 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
525 | for (c = 0; c < layer1_size; c++) neu1e[c] += g * syn1neg[c + l2];
526 | for (c = 0; c < layer1_size; c++) syn1neg[c + l2] += g * syn0[c + l1];
527 | }
528 | // Learn weights input -> hidden
529 | for (c = 0; c < layer1_size; c++) syn0[c + l1] += neu1e[c];
530 | }
531 | }
532 | sentence_position++;
533 | if (sentence_position >= sentence_length) {
534 | sentence_length = 0;
535 | continue;
536 | }
537 | }
538 | fclose(fi);
539 | free(neu1);
540 | free(neu1e);
541 | pthread_exit(NULL);
542 | }
543 |
544 | void TrainModel() {
545 | long a, b, c, d;
546 | FILE *fo;
547 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
548 | printf("Starting training using file %s\n", train_file);
549 | starting_alpha = alpha;
550 | if (read_vocab_file[0] != 0) ReadVocab(); else LearnVocabFromTrainFile();
551 | if (save_vocab_file[0] != 0) SaveVocab();
552 | if (output_file[0] == 0) return;
553 | InitNet();
554 | if (negative > 0) InitUnigramTable();
555 | start = clock();
556 | for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, TrainModelThread, (void *)a);
557 | for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL);
558 | fo = fopen(output_file, "wb");
559 | if (classes == 0) {
560 | // Save the word vectors
561 | fprintf(fo, "%lld %lld\n", vocab_size, layer1_size);
562 | for (a = 0; a < vocab_size; a++) {
563 | fprintf(fo, "%s ", vocab[a].word);
564 | if (binary) for (b = 0; b < layer1_size; b++) fwrite(&syn0[a * layer1_size + b], sizeof(real), 1, fo);
565 | else for (b = 0; b < layer1_size; b++) fprintf(fo, "%lf ", syn0[a * layer1_size + b]);
566 | fprintf(fo, "\n");
567 | }
568 | } else {
569 | // Run K-means on the word vectors
570 | int clcn = classes, iter = 10, closeid;
571 | int *centcn = (int *)malloc(classes * sizeof(int));
572 | int *cl = (int *)calloc(vocab_size, sizeof(int));
573 | real closev, x;
574 | real *cent = (real *)calloc(classes * layer1_size, sizeof(real));
575 | for (a = 0; a < vocab_size; a++) cl[a] = a % clcn;
576 | for (a = 0; a < iter; a++) {
577 | for (b = 0; b < clcn * layer1_size; b++) cent[b] = 0;
578 | for (b = 0; b < clcn; b++) centcn[b] = 1;
579 | for (c = 0; c < vocab_size; c++) {
580 | for (d = 0; d < layer1_size; d++) cent[layer1_size * cl[c] + d] += syn0[c * layer1_size + d];
581 | centcn[cl[c]]++;
582 | }
583 | for (b = 0; b < clcn; b++) {
584 | closev = 0;
585 | for (c = 0; c < layer1_size; c++) {
586 | cent[layer1_size * b + c] /= centcn[b];
587 | closev += cent[layer1_size * b + c] * cent[layer1_size * b + c];
588 | }
589 | closev = sqrt(closev);
590 | for (c = 0; c < layer1_size; c++) cent[layer1_size * b + c] /= closev;
591 | }
592 | for (c = 0; c < vocab_size; c++) {
593 | closev = -10;
594 | closeid = 0;
595 | for (d = 0; d < clcn; d++) {
596 | x = 0;
597 | for (b = 0; b < layer1_size; b++) x += cent[layer1_size * d + b] * syn0[c * layer1_size + b];
598 | if (x > closev) {
599 | closev = x;
600 | closeid = d;
601 | }
602 | }
603 | cl[c] = closeid;
604 | }
605 | }
606 | // Save the K-means classes
607 | for (a = 0; a < vocab_size; a++) fprintf(fo, "%s %d\n", vocab[a].word, cl[a]);
608 | free(centcn);
609 | free(cent);
610 | free(cl);
611 | }
612 | fclose(fo);
613 | }
614 |
615 | int ArgPos(char *str, int argc, char **argv) {
616 | int a;
617 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) {
618 | if (a == argc - 1) {
619 | printf("Argument missing for %s\n", str);
620 | exit(1);
621 | }
622 | return a;
623 | }
624 | return -1;
625 | }
626 |
627 | int main(int argc, char **argv) {
628 | int i;
629 | if (argc == 1) {
630 | printf("WORD VECTOR estimation toolkit v 0.1c\n\n");
631 | printf("Options:\n");
632 | printf("Parameters for training:\n");
633 | printf("\t-train <file>\n");
634 | printf("\t\tUse text data from <file> to train the model\n");
635 | printf("\t-output <file>\n");
636 | printf("\t\tUse <file> to save the resulting word vectors / word clusters\n");
637 | printf("\t-size <int>\n");
638 | printf("\t\tSet size of word vectors; default is 100\n");
639 | printf("\t-window <int>\n");
640 | printf("\t\tSet max skip length between words; default is 5\n");
641 | printf("\t-sample <float>\n");
642 | printf("\t\tSet threshold for occurrence of words. Those that appear with higher frequency in the training data\n");
643 | printf("\t\twill be randomly down-sampled; default is 1e-3, useful range is (0, 1e-5)\n");
644 | printf("\t-hs <int>\n");
645 | printf("\t\tUse Hierarchical Softmax; default is 0 (not used)\n");
646 | printf("\t-negative <int>\n");
647 | printf("\t\tNumber of negative examples; default is 5, common values are 3 - 10 (0 = not used)\n");
648 | printf("\t-threads <int>\n");
649 | printf("\t\tUse <int> threads (default 12)\n");
650 | printf("\t-iter <int>\n");
651 | printf("\t\tRun more training iterations (default 5)\n");
652 | printf("\t-min-count <int>\n");
653 | printf("\t\tThis will discard words that appear less than <int> times; default is 5\n");
654 | printf("\t-alpha <float>\n");
655 | printf("\t\tSet the starting learning rate; default is 0.025 for skip-gram and 0.05 for CBOW\n");
656 | printf("\t-classes <int>\n");
657 | printf("\t\tOutput word classes rather than word vectors; default number of classes is 0 (vectors are written)\n");
658 | printf("\t-debug <int>\n");
659 | printf("\t\tSet the debug mode (default = 2 = more info during training)\n");
660 | printf("\t-binary <int>\n");
661 | printf("\t\tSave the resulting vectors in binary mode; default is 0 (off)\n");
662 | printf("\t-save-vocab <file>\n");
663 | printf("\t\tThe vocabulary will be saved to <file>\n");
664 | printf("\t-read-vocab <file>\n");
665 | printf("\t\tThe vocabulary will be read from <file>, not constructed from the training data\n");
666 | printf("\t-cbow <int>\n");
667 | printf("\t\tUse the continuous bag of words model; default is 1 (use 0 for skip-gram model)\n");
668 | printf("\nExamples:\n");
669 | printf("./word2vec -train data.txt -output vec.txt -size 200 -window 5 -sample 1e-4 -negative 5 -hs 0 -binary 0 -cbow 1 -iter 3\n\n");
670 | return 0;
671 | }
672 | output_file[0] = 0;
673 | save_vocab_file[0] = 0;
674 | read_vocab_file[0] = 0;
675 | if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
676 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
677 | if ((i = ArgPos((char *)"-save-vocab", argc, argv)) > 0) strcpy(save_vocab_file, argv[i + 1]);
678 | if ((i = ArgPos((char *)"-read-vocab", argc, argv)) > 0) strcpy(read_vocab_file, argv[i + 1]);
679 | if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]);
680 | if ((i = ArgPos((char *)"-binary", argc, argv)) > 0) binary = atoi(argv[i + 1]);
681 | if ((i = ArgPos((char *)"-cbow", argc, argv)) > 0) cbow = atoi(argv[i + 1]);
682 | if (cbow) alpha = 0.05;
683 | if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
684 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
685 | if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
686 | if ((i = ArgPos((char *)"-sample", argc, argv)) > 0) sample = atof(argv[i + 1]);
687 | if ((i = ArgPos((char *)"-hs", argc, argv)) > 0) hs = atoi(argv[i + 1]);
688 | if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
689 | if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]);
690 | if ((i = ArgPos((char *)"-iter", argc, argv)) > 0) iter = atoi(argv[i + 1]);
691 | if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]);
692 | if ((i = ArgPos((char *)"-classes", argc, argv)) > 0) classes = atoi(argv[i + 1]);
693 | vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word));
694 | vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int));
695 | expTable = (real *)malloc((EXP_TABLE_SIZE + 1) * sizeof(real));
696 | for (i = 0; i < EXP_TABLE_SIZE; i++) {
697 | expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
698 | expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1)
699 | }
700 | TrainModel();
701 | return 0;
702 | }
--------------------------------------------------------------------------------