├── .idea └── vcs.xml ├── LICENSE ├── README.md ├── batch.py ├── evaluation.py ├── layers.py ├── losses.py ├── model.py ├── reader.py ├── segmenter.py ├── toolbox.py └── transducer_model.py /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Segmenter 2 | Universal segmenter, written by Y. Shao, Uppsala University 3 | 4 | ### News 5 | 6 | We published a TACL paper with more information on, and analysis of, the segmenter. (https://transacl.org/ojs/index.php/tacl/article/viewFile/1446/315) 7 | 8 | The segmenter achieved the best overall word segmentation accuracy and the second-best overall sentence segmentation accuracy at the CoNLL 2018 shared task. (http://universaldependencies.org/conll18/results-words.html) 9 | 10 | The segmenter was applied to the MLP 2017 shared tasks (http://mlp.computing.dcu.ie/mlp2017_Shared_Task.html) and achieved outstanding results on all the datasets.
(2017.8.13) 11 | 12 | ## Universal Dependencies 13 | 14 | ### Training 15 | 16 | #### Segmentation: 17 | 18 | python segmenter.py train -p ud-treebanks-conll2017/UD_English -m seg_Eng 19 | 20 | #### Joint sentence segmentation: 21 | 22 | python segmenter.py train -p ud-treebanks-conll2017/UD_English -ss -m ss_seg_Eng 23 | 24 | ### Decoding 25 | 26 | python segmenter.py tag -p ud-treebanks-conll2017/UD_English -m ss_seg_Eng -r ud-raw/en_pud.txt -opth tokenized/en_pud.txt 27 | 28 | ## MLP 2017 29 | 30 | ### (Single) 31 | 32 | ### Training 33 | 34 | #### For Basque Finnish Kazakh Marathi Uyghur and Farsi 35 | 36 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -m basque 37 | 38 | (The training and development sets of Basque are in directory mlp/basque) 39 | 40 | #### For Vietnamese 41 | 42 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -sea 43 | 44 | #### For Chinese and Japanese 45 | 46 | python segmenter.py train -p mlp/tchinese -f mlp2 -ng 3 -m tchinese 47 | 48 | ### Decoding 49 | 50 | #### For Basque Finnish Kazakh Marathi Uyghur Farsi and Vietnamese 51 | 52 | python segmenter.py tag -p mlp/basque -f mlp1 -m basque -r testset/basque_raw.txt -opth segmented_mlp/basque_single_out.txt 53 | 54 | #### For Chinese and Japanese 55 | 56 | python segmenter.py tag -p mlp/tchinese -f mlp2 -m tchinese -r testset/tchinese_raw.txt -opth segmented_mlp/tchinese_single_out.txt 57 | 58 | ### (Ensemble) 59 | 60 | ### Training 61 | 62 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -m basque_1 63 | 64 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -m basque_2 65 | 66 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -m basque_3 67 | 68 | python segmenter.py train -p mlp/basque -f mlp1 -ng 3 -m basque_4 69 | 70 | ### Decoding 71 | 72 | python segmenter.py tag -ens -p mlp/basque -f mlp1 -m basque -r testset/basque_raw.txt -opth segmented_mlp/basque_ensemble_out.txt 73 | 74 | ## Reference 75 | 76 | Yan Shao, Christian Hardmeier, and Joakim Nivre. 
2018. Universal word segmentation: Implementation and interpretation. Transactions of the Association for Computational Linguistics 6:421–435. 77 | 78 | https://transacl.org/ojs/index.php/tacl/article/viewFile/1446/315 79 | 80 | 81 | Yan Shao. "Cross-lingual Word Segmentation and Morpheme Segmentation as Sequence Labelling" arXiv preprint arXiv:1709.03756 (2017). 82 | 83 | https://arxiv.org/pdf/1709.03756.pdf 84 | 85 | -------------------------------------------------------------------------------- /batch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | import toolbox 4 | import numpy as np 5 | 6 | 7 | def train(sess, model, batch_size, config, lr, lrv, data, dr=None, drv=None, verbose=False): 8 | assert len(data) == len(model) 9 | num_items = len(data) 10 | samples = zip(*data) 11 | random.shuffle(samples) 12 | start_idx = 0 13 | n_samples = len(samples) 14 | model.append(lr) 15 | if dr is not None: 16 | model.append(dr) 17 | while start_idx < len(samples): 18 | if verbose: 19 | print '%d' % (start_idx * 100 / n_samples) + '%' 20 | next_batch_samples = samples[start_idx:start_idx + batch_size] 21 | real_batch_size = len(next_batch_samples) 22 | if real_batch_size < batch_size: 23 | next_batch_samples.extend(samples[:batch_size - real_batch_size]) 24 | holders = [] 25 | for item in range(num_items): 26 | holders.append([s[item] for s in next_batch_samples]) 27 | holders.append(lrv) 28 | if dr is not None: 29 | holders.append(drv) 30 | sess.run(config, feed_dict={m: h for m, h in zip(model, holders)}) 31 | start_idx += batch_size 32 | 33 | 34 | def softmax(x): 35 | dim = len(list(x.shape)) - 1 36 | anp = np.exp(x - np.max(x, axis=dim, keepdims=True)) 37 | return anp / np.sum(anp, axis=dim, keepdims=True) 38 | 39 | 40 | def predict(sess, model, data, dr=None, transitions=None, crf=True, decode_sess=None, scores=None, decode_holders=None, 41 | argmax=True, batch_size=100, 
ensemble=False, verbose=False): 42 | en_num = None 43 | if ensemble: 44 | en_num = len(sess) 45 | num_items = len(data) 46 | input_v = model[:num_items] 47 | if dr is not None: 48 | input_v.append(dr) 49 | predictions = model[num_items:] 50 | output = [[] for _ in range(len(predictions))] 51 | samples = zip(*data) 52 | start_idx = 0 53 | n_samples = len(samples) 54 | if crf > 0: 55 | trans = [] 56 | for i in range(len(predictions)): 57 | if ensemble: 58 | en_trans = 0 59 | for en_sess in sess: 60 | en_trans += en_sess.run(transitions[i]) 61 | trans.append(en_trans/en_num) 62 | else: 63 | trans.append(sess.run(transitions[i])) 64 | while start_idx < n_samples: 65 | if verbose: 66 | print '%d' % (start_idx*100/n_samples) + '%' 67 | next_batch_input = samples[start_idx:start_idx + batch_size] 68 | batch_size = len(next_batch_input) 69 | holders= [] 70 | for item in range(num_items): 71 | holders.append([s[item] for s in next_batch_input]) 72 | if dr is not None: 73 | holders.append(0.0) 74 | length = np.sum(np.sign(holders[0]), axis=1) 75 | if crf > 0: 76 | assert transitions is not None and len(transitions) == len(predictions) and len(scores) == len(decode_holders) 77 | for i in range(len(predictions)): 78 | if ensemble: 79 | en_obs = 0 80 | for en_sess in sess: 81 | en_obs += en_sess.run(predictions[i], feed_dict={i: h for i, h in zip(input_v, holders)}) 82 | ob = en_obs/en_num 83 | else: 84 | ob = sess.run(predictions[i], feed_dict={i: h for i, h in zip(input_v, holders)}) 85 | pre_values = [ob, trans[i], length, batch_size] 86 | assert len(pre_values) == len(decode_holders[i]) 87 | max_scores, max_scores_pre = decode_sess.run(scores[i], feed_dict={i: h for i, h in zip(decode_holders[i], pre_values)}) 88 | output[i].extend(toolbox.viterbi(max_scores, max_scores_pre, length, batch_size)) 89 | elif argmax: 90 | for i in range(len(predictions)): 91 | pre = sess.run(predictions[i], feed_dict={i: h for i, h in zip(input_v, holders)}) 92 | dim_axis = len(list(pre.shape)) 
- 1 93 | if argmax is True: 94 | pre = np.argmax(pre, axis= dim_axis) 95 | else: 96 | pre = softmax(pre) 97 | pre[:, :, 0][pre[:, :, 0] > argmax] = 1 98 | pre[:, :, 0][pre[:, :, 0] <= argmax] = 0 99 | pre = np.argmax(pre, axis=dim_axis) 100 | pre = pre.tolist() 101 | if dim_axis > 1: 102 | pre = toolbox.trim_output(pre, length) 103 | output[i].extend(pre) 104 | else: 105 | for i in range(len(predictions)): 106 | pre = sess.run(predictions[i], feed_dict={i: h for i, h in zip(input_v, holders)}) 107 | #pre = softmax(pre) 108 | dim_axis = len(list(pre.shape)) - 1 109 | if dim_axis > 1: 110 | pre = toolbox.trim_output(pre, length) 111 | output[i].extend(pre) 112 | start_idx += batch_size 113 | return output 114 | 115 | 116 | def train_seq2seq(sess, model, decoding, batch_size, config, lr, lrv, data, dr=None, drv=None, verbose=False): 117 | #assert len(data) == len(model) 118 | samples = zip(*data) 119 | random.shuffle(samples) 120 | start_idx = 0 121 | n_samples = len(samples) 122 | model.append(lr) 123 | model.append(decoding) 124 | if dr is not None: 125 | model.append(dr) 126 | while start_idx < len(samples): 127 | if verbose: 128 | print '%d' % (start_idx * 100 / n_samples) + '%' 129 | next_batch_samples = samples[start_idx:start_idx + batch_size] 130 | real_batch_size = len(next_batch_samples) 131 | if real_batch_size < batch_size: 132 | next_batch_samples.extend(samples[:batch_size - real_batch_size]) 133 | holders = [] 134 | next_batch_samples = zip(*next_batch_samples) 135 | for n_batch in next_batch_samples: 136 | n_batch = np.asarray(n_batch).T 137 | for b in n_batch: 138 | holders.append(b) 139 | holders.append(lrv) 140 | holders.append(False) 141 | if dr is not None: 142 | holders.append(drv) 143 | sess.run(config, feed_dict={m: h for m, h in zip(model, holders)}) 144 | start_idx += batch_size 145 | 146 | 147 | def predict_seq2seq(sess, model, decoding, data, decode_len, dr=None, argmax=True, batch_size=100, ensemble=False, verbose=False): 148 | num_items = 
len(data) 149 | in_len = len(data[0][0]) 150 | input_v = model[:num_items*in_len + decode_len] 151 | input_v.append(decoding) 152 | if dr is not None: 153 | input_v.append(dr) 154 | predictions = model[num_items*in_len + decode_len:] 155 | output = [] 156 | samples = zip(*data) 157 | start_idx = 0 158 | n_samples = len(samples) 159 | while start_idx < n_samples: 160 | if verbose: 161 | print '%d' % (start_idx * 100 / n_samples) + '%' 162 | next_batch_input = samples[start_idx:start_idx + batch_size] 163 | batch_size = len(next_batch_input) 164 | holders = [] 165 | next_batch_input = zip(*next_batch_input) 166 | for n_batch in next_batch_input: 167 | n_batch = np.asarray(n_batch).T 168 | for b in n_batch: 169 | holders.append(b) 170 | for i in range(decode_len): 171 | holders.append(np.zeros(batch_size, dtype='int32')) 172 | holders.append(True) 173 | if dr is not None: 174 | holders.append(0.0) 175 | if argmax: 176 | pre = sess.run(predictions, feed_dict={i: h for i, h in zip(input_v, holders)}) 177 | pre = [np.argmax(pre_t, axis=1) for pre_t in pre] 178 | pre = np.asarray(pre).T.tolist() 179 | pre = [np.trim_zeros(pre_t) for pre_t in pre] 180 | output += pre 181 | else: 182 | pre = sess.run(predictions, feed_dict={i: h for i, h in zip(input_v, holders)}) 183 | output += pre 184 | start_idx += batch_size 185 | return output -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import copy 3 | 4 | 5 | def eq_tokens(gt, st): 6 | if gt.strip() in ["``", "''"] and st.strip() == '"': 7 | return True 8 | else: 9 | return gt == st 10 | 11 | lcs = {} 12 | 13 | 14 | def score(gtokens, stokens): 15 | for i in range(0, len(gtokens) + 1): 16 | lcs[(i, 0)] = 0 17 | for j in range(0, len(stokens) + 1): 18 | lcs[(0, j)] = 0 19 | for i in range(1, len(gtokens) + 1): 20 | for j in range(1, len(stokens) + 1): 21 | if 
eq_tokens(gtokens[i - 1], stokens[j - 1]): 22 | lcs[(i, j)] = lcs[(i - 1, j - 1)] + 1 23 | else: 24 | if lcs[(i - 1, j)] >= lcs[(i, j - 1)]: 25 | lcs[(i, j)] = lcs[(i - 1, j)] 26 | else: 27 | lcs[(i, j)] = lcs[(i, j - 1)] 28 | 29 | tp = lcs[(len(gtokens), len(stokens))] 30 | g = len(gtokens) 31 | s = len(stokens) 32 | 33 | #precision = float(tp)/float(s) 34 | #recall = float(tp)/float(g) 35 | #f1score=(2*precision*recall)/(precision + recall) 36 | 37 | return tp, g, s 38 | 39 | 40 | def exact_match(g, pre): 41 | acc = 0 42 | for g_t, p_t in zip(g, pre): 43 | if g_t == p_t: 44 | acc += 1 45 | return acc 46 | 47 | 48 | def evaluator(prediction, gold, prediction_raw=None, gold_raw=None, verbose=False): 49 | if prediction_raw is None or gold_raw is None: 50 | prediction = prediction[0] 51 | assert len(prediction) == len(gold) 52 | tp, s, g, pp, rp, tw = 0, 0, 0, 0, 0, 0 53 | for pre, gd in zip(prediction, gold): 54 | pre_t = pre.split(' ') 55 | gd_t = gd.split(' ') 56 | tt, t_g, t_s = score(gd_t, pre_t) 57 | tp += tt 58 | g += t_g 59 | s += t_s 60 | if verbose: 61 | pp += len(pre_t) 62 | rp += len(gd_t) 63 | sl = len(''.join(pre_t)) 64 | tw += (1 + sl) * sl / 2 65 | precision = float(tp) / float(s) 66 | recall = float(tp) / float(g) 67 | if precision == 0 and recall == 0: 68 | f1score = 0 69 | else: 70 | f1score = (2 * precision * recall) / (precision + recall) 71 | 72 | if verbose: 73 | tnr = 1 - float(pp - tp) / float(tw - rp) 74 | return precision, recall, f1score, tnr 75 | else: 76 | return precision, recall, f1score 77 | else: 78 | prediction = copy.copy(prediction[0]) 79 | prediction_raw = copy.copy(prediction_raw[0]) 80 | gold = copy.copy(gold) 81 | gold_raw = copy.copy(gold_raw) 82 | prediction_raw = ["".join(pre.split()) for pre in prediction_raw] 83 | gold_raw = ["".join(gd.split()) for gd in gold_raw] 84 | assert len(prediction) == len(prediction_raw) 85 | assert len(gold) == len(gold_raw) 86 | n_prediction = len(prediction) 87 | n_gold = len(gold) 88 | 
correct = 0 89 | l_prediction = 0 90 | l_gold = 0 91 | pre_tokens = [] 92 | gd_tokens = [] 93 | tp, s, g = 0, 0, 0 94 | last_correct = True 95 | while prediction_raw and gold_raw and prediction and gold: 96 | if prediction_raw[0] == gold_raw[0] or (last_correct and len(prediction_raw[0]) == len(gold_raw[0])): # words right 97 | correct += 1 98 | l_prediction += len(prediction_raw[0]) # move 99 | l_gold += len(gold_raw[0]) 100 | pre_tokens = prediction[0].split(' ') 101 | gd_tokens = gold[0].split(' ') 102 | tt, t_g, t_s = score(gd_tokens, pre_tokens) 103 | tp += tt 104 | g += t_g 105 | s += t_s 106 | pre_tokens = [] 107 | gd_tokens = [] 108 | prediction_raw.pop(0) 109 | gold_raw.pop(0) 110 | prediction.pop(0) 111 | gold.pop(0) 112 | last_correct = True 113 | else: 114 | if l_prediction == l_gold: 115 | if len(gd_tokens) < 1000: 116 | tt, t_g, t_s = score(gd_tokens, pre_tokens) 117 | tp += tt 118 | g += t_g 119 | s += t_s 120 | else: 121 | g += len(gd_tokens) 122 | s += len(pre_tokens) 123 | l_prediction += len(prediction_raw[0]) # move 124 | l_gold += len(gold_raw[0]) 125 | pre_tokens = prediction[0].split(' ') 126 | gd_tokens = gold[0].split(' ') 127 | prediction_raw.pop(0) 128 | gold_raw.pop(0) 129 | prediction.pop(0) 130 | gold.pop(0) 131 | last_correct = False 132 | elif l_prediction < l_gold: 133 | l_prediction += len(prediction_raw[0]) 134 | pre_tokens += prediction[0].split(' ') 135 | prediction_raw.pop(0) 136 | prediction.pop(0) 137 | last_correct = False 138 | elif l_prediction > l_gold: 139 | gd_tokens += gold[0].split(' ') 140 | l_gold += len(gold_raw[0]) # move 141 | gold_raw.pop(0) 142 | gold.pop(0) 143 | last_correct = False 144 | 145 | if correct > 0: 146 | sent_precision = float(correct) / float(n_prediction) 147 | sent_recall = float(correct) / float(n_gold) 148 | sent_f1score = (2 * sent_precision * sent_recall) / (sent_precision + sent_recall) 149 | else: 150 | sent_precision = 0 151 | sent_recall = 0 152 | sent_f1score = 0 153 | 154 | if tp > 0: 
155 | precision = float(tp) / float(s) 156 | recall = float(tp) / float(g) 157 | f1score = (2 * precision * recall) / (precision + recall) 158 | else: 159 | precision = 0 160 | recall = 0 161 | f1score = 0 162 | 163 | return precision, recall, f1score, sent_precision, sent_recall, sent_f1score 164 | 165 | 166 | def sent_evaluator(prediction, gold): 167 | prediction = ["".join(pre.split()) for pre in prediction] 168 | gold = ["".join(gd.split()) for gd in gold] 169 | n_prediction = len(prediction) 170 | n_gold = len(gold) 171 | correct = 0 172 | l_prediction = 0 173 | l_gold = 0 174 | while prediction and gold: 175 | if prediction[0] == gold[0]: # words right 176 | correct += 1 177 | l_prediction = 0 # move 178 | l_gold = 0 179 | prediction.pop(0) 180 | gold.pop(0) 181 | else: 182 | if l_prediction == l_gold: 183 | l_prediction = 0 # move 184 | l_gold = 0 185 | prediction.pop(0) 186 | gold.pop(0) 187 | elif l_prediction < l_gold: 188 | l_prediction += len(prediction[0]) 189 | prediction.pop(0) 190 | elif l_prediction > l_gold: 191 | l_gold += len(gold[0]) # move 192 | gold.pop(0) 193 | if correct > 0: 194 | precision = float(correct) / float(n_prediction) 195 | recall = float(correct) / float(n_gold) 196 | f1score = (2 * precision * recall) / (precision + recall) 197 | else: 198 | precision = 0 199 | recall = 0 200 | f1score = 0 201 | return precision, recall, f1score 202 | 203 | 204 | def trans_evaluator(prediction, gold): 205 | assert len(prediction) == len(gold) 206 | acc = float(exact_match(prediction, gold))/len(prediction) 207 | tp, s, g = 0, 0, 0 208 | for pre, gd in zip(prediction, gold): 209 | pre_t = pre.split(' ') 210 | gd_t = gd.split(' ') 211 | tt, t_g, t_s = score(gd_t, pre_t) 212 | tp += tt 213 | g += t_g 214 | s += t_s 215 | precision = float(tp) / float(s) 216 | recall = float(tp) / float(g) 217 | if precision == 0 and recall == 0: 218 | f1score = 0 219 | else: 220 | f1score = (2 * precision * recall) / (precision + recall) 221 | return acc, f1score 
222 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | 6 | 7 | class HiddenLayer(object): 8 | """ 9 | Hidden layer with or without bias. 10 | Input: tensor of dimension (dims*, input_dim) 11 | Output: tensor of dimension (dims*, output_dim) 12 | """ 13 | def __init__(self, input_dim, output_dim, bias=True, activation='tanh', name='hidden_layer'): 14 | """ 15 | :param input_dim: 16 | :param output_dim: 17 | :param bias: 18 | :param activation: 19 | :param name: 20 | :return: 21 | """ 22 | 23 | self.input_dim = input_dim 24 | self.output_dim = output_dim 25 | self.is_bias = bias 26 | self.name = name 27 | if activation == 'linear': 28 | self.activation = None 29 | elif activation == 'tanh': 30 | self.activation = tf.tanh 31 | elif activation == 'sigmoid': 32 | self.activation = tf.sigmoid 33 | elif activation == 'softmax': 34 | self.activation = tf.nn.softmax 35 | elif activation == 'relu': 36 | self.activation = tf.nn.relu 37 | elif activation is not None: 38 | raise Exception('Unknown activation function: ' % activation) 39 | 40 | #Initialise weights and bias 41 | rand_uniform_init = tf.contrib.layers.xavier_initializer() 42 | self.weights = tf.get_variable(name + '_weights', [input_dim, output_dim], initializer=rand_uniform_init) 43 | self.bias = tf.get_variable(name + '_bias', [output_dim], initializer=tf.constant_initializer(0.0)) 44 | 45 | #define parameters 46 | if self.is_bias: 47 | self.params = [self.weights, self.bias] 48 | else: 49 | self.params = [self.weights] 50 | 51 | def __call__(self, input_t): 52 | """ 53 | :param input_t: 54 | :return: 55 | """ 56 | input_shape = input_t.get_shape().as_list() 57 | input_t = tf.reshape(input_t, [-1, input_shape[-1]]) 58 | linear = tf.matmul(input_t, self.weights) 59 | self.linear = 
tf.reshape(linear, [-1] + input_shape[1:-1] + [self.output_dim]) 60 | if self.is_bias: 61 | self.linear += self.bias 62 | if self.activation is None: 63 | self.output = self.linear 64 | else: 65 | self.output = self.activation(self.linear) 66 | return self.output 67 | 68 | 69 | class EmbeddingLayer(object): 70 | """ 71 | Embedding layer to map input into word representations 72 | Input: tensor of dimension (dim*) with values in range(0, input_dim) 73 | Output: tensor of dimension (dim*, output_dim) 74 | """ 75 | def __init__(self, input_dim, output_dim, weights=None, is_variable=False, trainable=True, name='embedding_layer'): 76 | """ 77 | :param input_dim: 78 | :param output_dim: 79 | :param name: 80 | """ 81 | self.input_dim = input_dim 82 | self.output_dim = output_dim 83 | self.name = name 84 | self.trainable = trainable 85 | self.weights = weights 86 | 87 | # Generate random embeddings or read pre-trained embeddings 88 | rand_uniform_init = tf.contrib.layers.xavier_initializer() 89 | if self.weights is None: 90 | self.embeddings = tf.get_variable(self.name + '_emb', [self.input_dim, self.output_dim], 91 | initializer=rand_uniform_init, trainable=self.trainable) 92 | elif is_variable: 93 | self.embeddings = weights 94 | else: 95 | emb_count = len(weights) 96 | if emb_count < input_dim: 97 | pad_weights = np.zeros([self.input_dim - emb_count, self.output_dim], dtype='float32') 98 | self.weights = np.concatenate((self.weights, pad_weights), axis=0) 99 | self.embeddings = tf.get_variable(self.name + '_emb', initializer=self.weights, trainable=self.trainable) 100 | #Define Parameters 101 | self.params = [self.embeddings] 102 | self.weight_name = self.name + '_emb' 103 | 104 | def __call__(self, input_t): 105 | """ 106 | return the embeddings of the given indexes 107 | :param input: 108 | :return: 109 | """ 110 | self.input = input_t 111 | self.output = tf.gather(self.embeddings, self.input) 112 | return self.output 113 | 114 | 115 | class Convolution(object): 116 | 
''' 117 | Regular convolutional layer 118 | ''' 119 | @staticmethod 120 | def weight_variable(shape, name): 121 | initial = tf.truncated_normal(shape, stddev=0.1) 122 | return tf.Variable(initial, name=name) 123 | 124 | @staticmethod 125 | def bias_variable(shape, name): 126 | initial = tf.constant(0.1, shape=shape, name=name) 127 | return tf.Variable(initial) 128 | 129 | def __init__(self, conv_width, in_channels, out_channels, stride=1, dim=2, padding='SAME', 130 | name='convolutional_layer'): 131 | self.in_channels = in_channels 132 | self.out_channels = out_channels 133 | self.dim = dim 134 | if dim == 1: 135 | self.strides = [1, stride, 1] 136 | else: 137 | self.strides = [1, stride, stride, 1] 138 | self.padding = padding 139 | self.name = name 140 | self.conv_width = conv_width 141 | if dim == 1: 142 | self.w_conv = self.weight_variable([self.conv_width, self.in_channels, self.out_channels], 143 | name=self.name + '_w') 144 | else: 145 | self.w_conv = self.weight_variable([self.conv_width, self.conv_width, self.in_channels, self.out_channels], 146 | name=self.name + '_w') 147 | self.b_conv = self.bias_variable([self.out_channels], name=self.name + '_b') 148 | 149 | def conv2d(self, x, W): 150 | return tf.nn.conv2d(x, W, strides=self.strides, padding=self.padding) 151 | 152 | def conv1d(self, x, W): 153 | return tf.nn.conv1d(x, W, stride=self.strides, padding=self.padding) 154 | 155 | def __call__(self, input_t): 156 | if self.dim == 1: 157 | return tf.nn.relu(self.conv1d(input_t, self.w_conv) + self.b_conv) 158 | else: 159 | return tf.nn.relu(self.conv2d(input_t, self.w_conv) + self.b_conv) 160 | 161 | 162 | 163 | class Maxpooling(object): 164 | ''' 165 | Maxpooling layer 166 | ''' 167 | def __init__(self, pooling_size, stride=1, padding='SAME', name='pooling_layer'): 168 | self.padding = padding 169 | self.name = name 170 | self.ksize = [1, pooling_size, pooling_size, 1] 171 | 172 | def __call__(self, input_v): 173 | return tf.nn.max_pool(input_v, 
ksize=self.ksize, strides=self.ksize, padding='SAME') 174 | 175 | 176 | class DropoutLayer(object): 177 | """ 178 | Dropout layer 179 | """ 180 | def __init__(self, p=0.5, name='dropout_layer'): 181 | """ 182 | :param p: dropout rate 183 | :param name: 184 | """ 185 | #assert 0. <= p < 1 186 | self.p = p 187 | self.name = name 188 | 189 | def __call__(self, input_t): 190 | self.input = input_t 191 | return tf.nn.dropout(self.input, keep_prob=1 - self.p, name=self.name) 192 | 193 | 194 | class BiLSTM(object): 195 | """ 196 | Bidirectional LSTM 197 | """ 198 | def __init__(self, cell_dim, nums_layers=1, p=0.5, fw_cell=None, bw_cell=None, state=False, name='biLSTM', 199 | scope=None): 200 | """ 201 | :param cell_dim: 202 | :param nums_steps: 203 | :param nums_layers: 204 | :param p: 205 | :param name: 206 | """ 207 | self.cell_dim = cell_dim 208 | self.nums_layers = nums_layers 209 | self.p = p 210 | self.state = state 211 | self.name = name 212 | self.scope = scope 213 | if fw_cell is None: 214 | self.lstm_cell_fw = tf.nn.rnn_cell.LSTMCell(self.cell_dim, state_is_tuple=True) 215 | else: 216 | self.lstm_cell_fw = fw_cell 217 | if bw_cell is None: 218 | self.lstm_cell_bw = tf.nn.rnn_cell.LSTMCell(self.cell_dim, state_is_tuple=True) 219 | else: 220 | self.lstm_cell_bw = bw_cell 221 | #assert 0. 
<= p < 1 222 | 223 | def __call__(self, input_t, input_ids): 224 | self.input = input_t 225 | self.input_ids = input_ids 226 | #if self.p > 0.: 227 | self.lstm_cell_fw = tf.nn.rnn_cell.DropoutWrapper(self.lstm_cell_fw, output_keep_prob=(1 - self.p)) 228 | self.lstm_cell_bw = tf.nn.rnn_cell.DropoutWrapper(self.lstm_cell_bw, output_keep_prob=(1 - self.p)) 229 | if self.nums_layers > 1: 230 | self.lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([self.lstm_cell_fw] * self.nums_layers) 231 | self.lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([self.lstm_cell_bw] * self.nums_layers) 232 | self.length = tf.reduce_sum(tf.sign(self.input_ids), axis=1) 233 | self.length = tf.cast(self.length, dtype=tf.int32) 234 | int_states, final_states = tf.nn.bidirectional_dynamic_rnn(self.lstm_cell_fw, self.lstm_cell_bw, self.input, 235 | sequence_length=self.length, dtype=tf.float32, 236 | scope=self.scope) 237 | self.output = tf.concat(values=int_states, axis=2) 238 | if self.state: 239 | return self.output, final_states 240 | else: 241 | return self.output 242 | 243 | 244 | 245 | class TimeDistributed(object): 246 | """ 247 | Time-distributed wrapper for layers 248 | """ 249 | def __init__(self, layer, name='Time-distributed Wrapper'): 250 | self.layer = layer 251 | self.name = name 252 | 253 | def __call__(self, input_t, input_ids=None, pad=None): 254 | self.input = tf.unstack(input_t, axis=1) 255 | if input_ids is None: 256 | self.out = [self.layer(splits) for splits in self.input] 257 | else: 258 | self.out = [] 259 | pad = self.layer(self.input[0])*0 260 | masks = tf.reduce_sum(input_ids, axis=0) 261 | length = len(self.input) 262 | for i in range(length): 263 | r = tf.cond(tf.greater(masks[i], 0), lambda: self.layer(input_t[i]), lambda: pad) 264 | self.out.append(r) 265 | self.out = tf.stack(self.out, axis=1) 266 | return self.out 267 | 268 | 269 | class Forward(object): 270 | """ 271 | forward algorithm for the CRF loss 272 | """ 273 | def __init__(self, observations, transitions, 
nums_tags, length, batch_size, viterbi=True): 274 | self.observations = observations 275 | self.transitions = transitions 276 | self.viterbi = viterbi 277 | self.length = length 278 | self.batch_size = batch_size 279 | self.nums_tags = nums_tags 280 | self.nums_steps = observations.get_shape().as_list()[1] 281 | 282 | @staticmethod 283 | def log_sum_exp(x, axis=None): 284 | """ 285 | Sum probabilities in the log-space 286 | :param x: 287 | :param axis: 288 | :return: 289 | """ 290 | x_max = tf.reduce_max(x, axis=axis, keepdims=True) 291 | x_max_ = tf.reduce_max(x, axis=axis) 292 | return x_max_ + tf.log(tf.reduce_sum(tf.exp(x - x_max), axis=axis)) 293 | 294 | def __call__(self): 295 | small = -1000 296 | class_pad = tf.stack(small * tf.ones([self.batch_size, self.nums_steps, 1])) 297 | self.observations = tf.concat(axis=2, values=[self.observations, class_pad]) 298 | b_vec = tf.cast(tf.stack(([small] * self.nums_tags + [0]) * self.batch_size), tf.float32) 299 | b_vec = tf.reshape(b_vec, [self.batch_size, 1, -1]) 300 | #e_vec = tf.cast(tf.pack(([0] + [small] * self.nums_tags) * self.batch_size), tf.float32) 301 | #e_vec = tf.reshape(e_vec, [self.batch_size, 1, -1]) 302 | self.observations = tf.concat(axis=1, values=[b_vec, self.observations]) 303 | self.transitions = tf.reshape(tf.tile(self.transitions, [self.batch_size, 1]), 304 | [self.batch_size, self.nums_tags + 1, self.nums_tags + 1]) 305 | self.observations = tf.reshape(self.observations, [-1, self.nums_steps + 1, self.nums_tags + 1, 1]) 306 | self.observations = tf.transpose(self.observations, [1, 0, 2, 3]) 307 | previous = self.observations[0, :, :, :] 308 | max_scores = [] 309 | max_scores_pre = [] 310 | alphas = [previous] 311 | for t in range(1, self.nums_steps + 1): 312 | previous = tf.reshape(previous, [-1, self.nums_tags + 1, 1]) 313 | current = tf.reshape(self.observations[t,:, :, :], [-1, 1, self.nums_tags + 1]) 314 | alpha_t = previous + current + self.transitions 315 | if self.viterbi: 316 | 
max_scores.append(tf.reduce_max(alpha_t, axis=1)) 317 | max_scores_pre.append(tf.argmax(alpha_t, axis=1)) 318 | alpha_t = tf.reshape(self.log_sum_exp(alpha_t, axis=1), [-1, self.nums_tags + 1, 1]) 319 | alphas.append(alpha_t) 320 | previous = alpha_t 321 | alphas = tf.stack(alphas, axis=1) 322 | alphas = tf.reshape(alphas, [-1, self.nums_tags + 1, 1]) 323 | last_alphas = tf.gather(alphas, tf.range(0, self.batch_size) * (self.nums_steps + 1) + self.length) 324 | last_alphas = tf.reshape(last_alphas, [self.batch_size, self.nums_tags + 1, 1]) 325 | max_scores = tf.stack(max_scores, axis=1) 326 | max_scores_pre = tf.stack(max_scores_pre, axis=1) 327 | return tf.reduce_sum(self.log_sum_exp(last_alphas, axis=1)), max_scores, max_scores_pre 328 | 329 | 330 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | 4 | from layers import Forward 5 | 6 | 7 | def softmax_cross_entropy(y, y_, nums_tags): 8 | one_hot_y_ = tf.contrib.layers.one_hot_encoding(y_, nums_tags) 9 | one_hot_y_ = tf.reshape(one_hot_y_, [-1, nums_tags]) 10 | y = tf.reshape(y, [-1, nums_tags]) 11 | return tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=one_hot_y_) 12 | #return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 13 | 14 | 15 | def cross_entropy(y, y_): 16 | return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=[1])) 17 | 18 | 19 | def mean_square(y, y_): 20 | return tf.reduce_mean(tf.square(y_ - y)) 21 | 22 | 23 | def sparse_cross_entropy(y, y_): 24 | return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_) 25 | 26 | 27 | def sparse_cross_entropy_with_weights(y, y_, weights= None, average_cross_steps=True): 28 | if weights is None: 29 | weights = tf.cast(tf.sign(y_), tf.float32) 30 | out = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_) 31 | if 
average_cross_steps: 32 | weights_sum = tf.reduce_sum(weights, axis=0) 33 | return out*weights/(weights_sum + 1e-12) 34 | else: 35 | return out*weights 36 | 37 | 38 | def sequence_loss_by_example(logits, targets, weights=None, average_across_timesteps=True, softmax_loss_function=None, name=None): 39 | """Weighted cross-entropy loss for a sequence of logits (per example). 40 | Args: 41 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 42 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 43 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 44 | average_across_timesteps: If set, divide the returned cost by the total label weight. 45 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 46 | to be used instead of the standard softmax (the default if this is None). 47 | name: Optional name for this operation, default: "sequence_loss_by_example". 48 | Returns: 49 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 50 | Raises: 51 | ValueError: If len(logits) is different from len(targets) or len(weights). 52 | """ 53 | if len(targets) != len(logits) or len(weights) != len(logits): 54 | raise ValueError("Lengths of logits, weights, and targets must be the same " "%d, %d, %d." % (len(logits), len(weights), len(targets))) 55 | with tf.name_scope(name + "sequence_loss_by_example"): 56 | log_perp_list = [] 57 | for logit, target, weight in zip(logits, targets, weights): 58 | if softmax_loss_function is None: 59 | target = tf.reshape(target, [-1]) 60 | crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=target) 61 | else: 62 | crossent = softmax_loss_function(logit, target) 63 | log_perp_list.append(crossent * weight) 64 | log_perps = tf.add_n(log_perp_list) 65 | if average_across_timesteps: 66 | total_size = tf.add_n(weights) 67 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 
68 | log_perps /= total_size 69 | return log_perps 70 | 71 | 72 | def crf_loss(y, y_, transitions, nums_tags, batch_size): 73 | tag_scores = y 74 | nums_steps = len(tf.unstack(tag_scores, axis=1)) 75 | masks = tf.cast(tf.sign(y_), dtype=tf.float32) 76 | lengths = tf.reduce_sum(tf.sign(y_), axis=1) 77 | tag_ids = y_ 78 | b_id = tf.stack([[nums_tags]] * batch_size) 79 | #e_id = tf.pack([[0]] * batch_size) 80 | padded_tag_ids = tf.concat(axis=1, values=[b_id, tag_ids]) 81 | idx_tag_ids = tf.stack([tf.slice(padded_tag_ids, [0, i], [-1, 2]) for i in range(nums_steps)], axis=1) 82 | tag_ids = tf.contrib.layers.one_hot_encoding(tag_ids, nums_tags) 83 | point_score = tf.reduce_sum(tag_scores * tag_ids, axis=2) 84 | point_score *= masks 85 | trans_score = tf.gather_nd(transitions, idx_tag_ids) 86 | extend_mask = masks 87 | trans_score *= extend_mask 88 | target_path_score = tf.reduce_sum(point_score) + tf.reduce_sum(trans_score) 89 | total_path_score, _, _ = Forward(tag_scores, transitions, nums_tags, lengths, batch_size)() 90 | return - (target_path_score - total_path_score) 91 | 92 | 93 | def loss_wrapper(y, y_, loss_function, transitions=None, nums_tags=None, batch_size=None, weights=None, average_cross_steps=True): 94 | assert len(y) == len(y_) 95 | total_loss = [] 96 | if loss_function is crf_loss: 97 | #print len(y), len(transitions), len(nums_tags) 98 | assert len(y) == len(transitions) and len(transitions) == len(nums_tags) and batch_size is not None 99 | for sy, sy_, stranstion, snums_tags in zip(y, y_, transitions, nums_tags): 100 | total_loss.append(loss_function(sy, sy_, stranstion, snums_tags, batch_size)) 101 | elif loss_function is cross_entropy: 102 | assert len(y) == len(nums_tags) 103 | for sy, sy_, snums_tags in zip(y, y_, nums_tags): 104 | total_loss.append(loss_function(sy, sy_, snums_tags)) 105 | elif loss_function is sparse_cross_entropy: 106 | for sy, sy_ in zip(y, y_): 107 | total_loss.append(loss_function(sy, sy_)) 108 | elif loss_function is 
sparse_cross_entropy_with_weights: 109 | assert len(y) == len(nums_tags) 110 | for sy, sy_, snums_tags in zip(y, y_): 111 | total_loss.append(tf.reshape(loss_function(sy, sy_, weights=weights, average_cross_steps=average_cross_steps), [-1])) 112 | else: 113 | for sy, sy_ in zip(y, y_): 114 | total_loss.append(tf.reshape(loss_function(sy, sy_), [-1])) 115 | return tf.stack(total_loss) 116 | 117 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | from layers import EmbeddingLayer, BiLSTM, HiddenLayer, DropoutLayer, Convolution, Maxpooling, Forward 4 | from time import time 5 | import losses 6 | import toolbox 7 | import batch as Batch 8 | import random 9 | import cPickle as pickle 10 | import codecs 11 | import evaluation 12 | 13 | 14 | class Model(object): 15 | 16 | def __init__(self, nums_chars, nums_tags, buckets_char, counts=None, batch_size=10, crf=1, ngram=None, 17 | sent_seg=False, is_space=True, emb_path=None, tag_scheme='BIES'): 18 | self.nums_chars = nums_chars 19 | self.nums_tags = nums_tags 20 | self.buckets_char = buckets_char 21 | self.counts = counts 22 | self.crf = crf 23 | self.ngram = ngram 24 | self.emb_path = emb_path 25 | self.emb_layer = None 26 | self.tag_scheme = tag_scheme 27 | self.gram_layers = [] 28 | self.batch_size = batch_size 29 | self.l_rate = None 30 | self.decay = None 31 | self.train_step = None 32 | self.saver = None 33 | self.decode_holders = None 34 | self.scores = None 35 | self.params = None 36 | self.pixels = None 37 | self.is_space = is_space 38 | self.sent_seg = sent_seg 39 | self.updates = [] 40 | self.bucket_dit = {} 41 | self.input_v = [] 42 | self.input_w = [] 43 | self.input_p = None 44 | self.output = [] 45 | self.output_ = [] 46 | self.output_p = [] 47 | 48 | if self.crf > 0: 49 | self.transition_char = 
tf.get_variable('transitions_char', [self.nums_tags + 1, self.nums_tags + 1]) 50 | else: 51 | self.transition_char = None 52 | 53 | while len(self.buckets_char) > len(self.counts): 54 | self.counts.append(1) 55 | 56 | self.real_batches = toolbox.get_real_batch(self.counts, self.batch_size) 57 | 58 | def main_graph(self, trained_model, scope, emb_dim, cell, rnn_dim, rnn_num, drop_out=0.5, emb=None): 59 | if trained_model is not None: 60 | param_dic = {'nums_chars': self.nums_chars, 'nums_tags': self.nums_tags, 'crf': self.crf, 'emb_dim':emb_dim, 61 | 'cell': cell, 'rnn_dim': rnn_dim, 'rnn_num': rnn_num, 'drop_out': drop_out, 62 | 'buckets_char': self.buckets_char, 'ngram': self.ngram, 'is_space': self.is_space, 63 | 'sent_seg': self.sent_seg, 'emb_path': self.emb_path, 'tag_scheme': self.tag_scheme} 64 | #print param_dic 65 | f_model = open(trained_model, 'w') 66 | pickle.dump(param_dic, f_model) 67 | f_model.close() 68 | 69 | # define shared weights and variables 70 | 71 | dr = tf.placeholder(tf.float32, [], name='drop_out_holder') 72 | self.drop_out = dr 73 | self.drop_out_v = drop_out 74 | 75 | self.emb_layer = EmbeddingLayer(self.nums_chars + 20, emb_dim, weights=emb, name='emb_layer') 76 | 77 | if self.ngram is not None: 78 | ng_embs = [None for _ in range(len(self.ngram))] 79 | for i, n_gram in enumerate(self.ngram): 80 | self.gram_layers.append(EmbeddingLayer(n_gram + 5000 * (i + 2), emb_dim, weights=ng_embs[i], 81 | name= str(i + 2) + 'gram_layer')) 82 | 83 | with tf.variable_scope('BiRNN'): 84 | 85 | if cell == 'gru': 86 | fw_rnn_cell = tf.nn.rnn_cell.GRUCell(rnn_dim) 87 | bw_rnn_cell = tf.nn.rnn_cell.GRUCell(rnn_dim) 88 | else: 89 | fw_rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_dim, state_is_tuple=True) 90 | bw_rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_dim, state_is_tuple=True) 91 | 92 | if rnn_num > 1: 93 | fw_rnn_cell = tf.nn.rnn_cell.MultiRNNCell([fw_rnn_cell]*rnn_num, state_is_tuple=True) 94 | bw_rnn_cell = tf.nn.rnn_cell.MultiRNNCell([bw_rnn_cell]*rnn_num, 
state_is_tuple=True) 95 | 96 | output_wrapper = HiddenLayer(rnn_dim * 2, self.nums_tags, activation='linear', name='hidden') 97 | 98 | #define model for each bucket 99 | for idx, bucket in enumerate(self.buckets_char): 100 | if idx == 1: 101 | scope.reuse_variables() 102 | t1 = time() 103 | 104 | input_v = tf.placeholder(tf.int32, [None, bucket], name='input_' + str(bucket)) 105 | 106 | self.input_v.append([input_v]) 107 | 108 | emb_set = [] 109 | 110 | word_out = self.emb_layer(input_v) 111 | emb_set.append(word_out) 112 | 113 | if self.ngram is not None: 114 | for i in range(len(self.ngram)): 115 | input_g = tf.placeholder(tf.int32, [None, bucket], name='input_g' + str(i) + str(bucket)) 116 | self.input_v[-1].append(input_g) 117 | gram_out = self.gram_layers[i](input_g) 118 | emb_set.append(gram_out) 119 | 120 | if len(emb_set) > 1: 121 | emb_out = tf.concat(axis=2, values=emb_set) 122 | 123 | else: 124 | emb_out = emb_set[0] 125 | 126 | emb_out = DropoutLayer(dr)(emb_out) 127 | 128 | rnn_out = BiLSTM(rnn_dim, fw_cell=fw_rnn_cell, bw_cell=bw_rnn_cell, p=dr, name='BiLSTM' + str(bucket), 129 | scope='BiRNN')(emb_out, input_v) 130 | 131 | output = output_wrapper(rnn_out) 132 | 133 | self.output.append([output]) 134 | 135 | self.output_.append([tf.placeholder(tf.int32, [None, bucket], name='tags' + str(bucket))]) 136 | self.bucket_dit[bucket] = idx 137 | 138 | print 'Bucket %d, %f seconds' % (idx + 1, time() - t1) 139 | 140 | assert len(self.input_v) == len(self.output) 141 | 142 | self.params = tf.trainable_variables() 143 | 144 | self.saver = tf.train.Saver() 145 | 146 | def config(self, optimizer, decay, lr_v=None, momentum=None, clipping=True, max_gradient_norm=5.0): 147 | 148 | self.decay = decay 149 | print 'Training preparation...' 150 | 151 | print 'Defining loss...' 
152 | loss = [] 153 | if self.crf > 0: 154 | loss_function = losses.crf_loss 155 | for i in range(len(self.input_v)): 156 | bucket_loss = losses.loss_wrapper(self.output[i], self.output_[i], loss_function, 157 | transitions=[self.transition_char], nums_tags=[self.nums_tags], 158 | batch_size=self.real_batches[i]) 159 | loss.append(bucket_loss) 160 | else: 161 | loss_function = losses.sparse_cross_entropy 162 | for output, output_ in zip(self.output, self.output_): 163 | bucket_loss = losses.loss_wrapper(output, output_, loss_function) 164 | loss.append(bucket_loss) 165 | 166 | l_rate = tf.placeholder(tf.float32, [], name='learning_rate_holder') 167 | self.l_rate = l_rate 168 | 169 | if optimizer == 'sgd': 170 | if momentum is None: 171 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=l_rate) 172 | else: 173 | optimizer = tf.train.MomentumOptimizer(learning_rate=l_rate, momentum=momentum) 174 | elif optimizer == 'adagrad': 175 | assert lr_v is not None 176 | optimizer = tf.train.AdagradOptimizer(learning_rate=l_rate) 177 | elif optimizer == 'adam': 178 | optimizer = tf.train.AdamOptimizer() 179 | else: 180 | raise Exception('optimiser error') 181 | 182 | self.train_step = [] 183 | 184 | print 'Computing gradients...' 
185 | 186 | for idx, l in enumerate(loss): 187 | t2 = time() 188 | if clipping: 189 | gradients = tf.gradients(l, self.params) 190 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, max_gradient_norm) 191 | train_step = optimizer.apply_gradients(zip(clipped_gradients, self.params)) 192 | else: 193 | train_step = optimizer.minimize(l) 194 | print 'Bucket %d, %f seconds' % (idx + 1, time() - t2) 195 | self.train_step.append(train_step) 196 | 197 | def decode_graph(self): 198 | self.decode_holders = [] 199 | self.scores = [] 200 | for bucket in self.buckets_char: 201 | decode_holders = [] 202 | scores = [] 203 | nt = self.nums_tags 204 | ob = tf.placeholder(tf.float32, [None, bucket, nt]) 205 | trans = tf.placeholder(tf.float32, [nt + 1, nt + 1]) 206 | nums_steps = ob.get_shape().as_list()[1] 207 | length = tf.placeholder(tf.int32, [None]) 208 | b_size = tf.placeholder(tf.int32, []) 209 | small = -1000 210 | class_pad = tf.stack(small * tf.ones([b_size, nums_steps, 1])) 211 | observations = tf.concat(axis=2, values=[ob, class_pad]) 212 | b_vec = tf.tile(([small] * nt + [0]), [b_size]) 213 | b_vec = tf.cast(b_vec, tf.float32) 214 | b_vec = tf.reshape(b_vec, [b_size, 1, -1]) 215 | observations = tf.concat(axis=1, values=[b_vec, observations]) 216 | transitions = tf.reshape(tf.tile(trans, [b_size, 1]), [b_size, nt + 1, nt + 1]) 217 | observations = tf.reshape(observations, [-1, nums_steps + 1, nt + 1, 1]) 218 | observations = tf.transpose(observations, [1, 0, 2, 3]) 219 | previous = observations[0, :, :, :] 220 | max_scores = [] 221 | max_scores_pre = [] 222 | alphas = [previous] 223 | for t in range(1, nums_steps + 1): 224 | previous = tf.reshape(previous, [-1, nt + 1, 1]) 225 | current = tf.reshape(observations[t, :, :, :], [-1, 1, nt + 1]) 226 | alpha_t = previous + current + transitions 227 | max_scores.append(tf.reduce_max(alpha_t, axis=1)) 228 | max_scores_pre.append(tf.argmax(alpha_t, axis=1)) 229 | alpha_t = tf.reshape(Forward.log_sum_exp(alpha_t, 
axis=1), [-1, nt + 1, 1]) 230 | alphas.append(alpha_t) 231 | previous = alpha_t 232 | max_scores = tf.stack(max_scores, axis=1) 233 | max_scores_pre = tf.stack(max_scores_pre, axis=1) 234 | decode_holders.append([ob, trans, length, b_size]) 235 | scores.append((max_scores, max_scores_pre)) 236 | self.decode_holders.append(decode_holders) 237 | self.scores.append(scores) 238 | 239 | def define_updates(self, new_chars, emb_path, char2idx): 240 | 241 | self.nums_chars += len(new_chars) 242 | 243 | if emb_path is not None: 244 | 245 | old_emb_weights = self.emb_layer.embeddings 246 | emb_dim = old_emb_weights.get_shape().as_list()[1] 247 | emb_len = old_emb_weights.get_shape().as_list()[0] 248 | new_emb = tf.stack(toolbox.get_new_embeddings(new_chars, emb_dim, emb_path)) 249 | n_emb_sh = new_emb.get_shape().as_list() 250 | if len(n_emb_sh) > 1: 251 | new_emb_weights = tf.concat(axis=0, values=[old_emb_weights[:len(char2idx) - len(new_chars)], new_emb, 252 | old_emb_weights[len(char2idx):]]) 253 | if new_emb_weights.get_shape().as_list()[0] > emb_len: 254 | new_emb_weights = new_emb_weights[:emb_len] 255 | assign_op = old_emb_weights.assign(new_emb_weights) 256 | self.updates.append(assign_op) 257 | 258 | def run_updates(self, sess, weight_path): 259 | weight_path = weight_path.replace('//', '/') 260 | self.saver.restore(sess, weight_path) 261 | for op in self.updates: 262 | sess.run(op) 263 | 264 | print 'Loaded.' 
265 | 266 | def define_transducer_dict(self, trans_str, char2idx, sess, transducer): 267 | indices = [] 268 | for ch in trans_str: 269 | if ch == ' ': 270 | indices.append(3) 271 | elif ch in char2idx: 272 | indices.append(char2idx[ch]) 273 | else: 274 | indices.append(char2idx['']) 275 | indices += [2] 276 | out = transducer.tag([indices], char2idx, sess, batch_size=1) 277 | out = out[0].replace(' ', ' ') 278 | return out 279 | 280 | def train(self, t_x, t_y, v_x, v_y_raw, v_y_gold, idx2tag, idx2char, unk_chars, trans_dict, sess, epochs, 281 | trained_model, transducer=None, lr=0.05, decay=0.05, decay_step=1, sent_seg=False, outpath=None): 282 | lr_r = lr 283 | 284 | best_epoch = 0 285 | best_score = [0] * 6 286 | 287 | chars = toolbox.decode_chars(v_x[0], idx2char) 288 | 289 | for i in range(len(v_x[0])): 290 | for j, n in enumerate(v_x[0][i]): 291 | if n in unk_chars: 292 | v_x[0][i][j] = 1 293 | 294 | for i in range(len(t_x[0])): 295 | for k in range(len(t_x[0][i])): 296 | for j, n in enumerate(t_x[0][i][k]): 297 | if n in unk_chars: 298 | t_x[0][i][k][j] = 1 299 | 300 | transducer_dict = None 301 | if transducer is not None: 302 | char2idx = {k:v for v, k in idx2char.items()} 303 | 304 | def transducer_dict(trans_str): 305 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer) 306 | 307 | for epoch in range(epochs): 308 | print 'epoch: %d' % (epoch + 1) 309 | t = time() 310 | if epoch % decay_step == 0 and decay > 0: 311 | lr_r = lr/(1 + decay*(epoch/decay_step)) 312 | 313 | data_list = t_x + t_y 314 | 315 | samples = zip(*data_list) 316 | 317 | random.shuffle(samples) 318 | 319 | for sample in samples: 320 | c_len = len(sample[0][0]) 321 | idx = self.bucket_dit[c_len] 322 | real_batch_size = self.real_batches[idx] 323 | model = self.input_v[idx] + self.output_[idx] 324 | Batch.train(sess=sess[0], model=model, batch_size=real_batch_size, config=self.train_step[idx], 325 | lr=self.l_rate, lrv=lr_r, dr=self.drop_out, drv=self.drop_out_v, 
data=list(sample), 326 | verbose=False) 327 | 328 | predictions = [] 329 | 330 | #for v_b_x in zip(*v_x): 331 | c_len = len(v_x[0][0]) 332 | idx = self.bucket_dit[c_len] 333 | b_prediction = self.predict(data=v_x, sess=sess, model=self.input_v[idx] + self.output[idx], index=idx, 334 | argmax=True, batch_size=200) 335 | b_prediction = toolbox.decode_tags(b_prediction, idx2tag) 336 | predictions.append(b_prediction) 337 | 338 | predictions = zip(*predictions) 339 | predictions = toolbox.merge_bucket(predictions) 340 | 341 | if self.is_space == 'sea': 342 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions) 343 | else: 344 | prediction_out, raw_out = toolbox.generate_output(chars, predictions, trans_dict, transducer_dict) 345 | 346 | if sent_seg: 347 | scores = evaluation.evaluator(prediction_out, v_y_gold, raw_out, v_y_raw) 348 | else: 349 | scores = evaluation.evaluator(prediction_out, v_y_gold) 350 | if sent_seg: 351 | c_score = scores[2] * scores[5] 352 | c_best_score = best_score[2] * best_score[5] 353 | else: 354 | c_score = scores[2] 355 | c_best_score = best_score[2] 356 | 357 | if c_score > c_best_score: 358 | best_epoch = epoch + 1 359 | best_score = scores 360 | self.saver.save(sess[0], trained_model, write_meta_graph=False) 361 | 362 | if outpath is not None: 363 | wt = codecs.open(outpath, 'w', encoding='utf-8') 364 | for pre in prediction_out[0]: 365 | wt.write(pre + '\n') 366 | wt.close() 367 | 368 | 369 | if sent_seg: 370 | print 'Sentence segmentation:' 371 | print 'F score: %f\n' % scores[5] 372 | print 'Word segmentation:' 373 | print 'F score: %f' % scores[2] 374 | else: 375 | print 'F score: %f' % c_score 376 | print 'Time consumed: %d seconds' % int(time() - t) 377 | print 'Training is finished!' 
378 | if sent_seg: 379 | print 'Sentence segmentation:' 380 | print 'Best F score: %f' % best_score[5] 381 | print 'Best Precision: %f' % best_score[3] 382 | print 'Best Recall: %f\n' % best_score[4] 383 | print 'Word segmentation:' 384 | print 'Best F score: %f' % best_score[2] 385 | print 'Best Precision: %f' % best_score[0] 386 | print 'Best Recall: %f\n' % best_score[1] 387 | else: 388 | print 'Best F score: %f' % best_score[2] 389 | print 'Best Precision: %f' % best_score[0] 390 | print 'Best Recall: %f\n' % best_score[1] 391 | print 'Best epoch: %d' % best_epoch 392 | 393 | def test(self, t_x, t_y_raw, t_y_gold, idx2tag, idx2char, unk_chars, sub_dict, trans_dict, sess, transducer, 394 | ensemble=None, batch_size=100, sent_seg=False, bias=-1, outpath=None, trans_type='mix'): 395 | 396 | chars = toolbox.decode_chars(t_x[0], idx2char) 397 | gold_out = t_y_gold 398 | 399 | for i in range(len(t_x[0])): 400 | for j, n in enumerate(t_x[0][i]): 401 | if n in sub_dict: 402 | t_x[0][i][j] = sub_dict[n] 403 | elif n in unk_chars: 404 | t_x[0][i][j] = 1 405 | 406 | transducer_dict = None 407 | if transducer is not None: 408 | char2idx = {v: k for k, v in idx2char.items()} 409 | 410 | def transducer_dict(trans_str): 411 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer) 412 | 413 | if bias < 0: 414 | argmax = True 415 | else: 416 | argmax = False 417 | 418 | prediction = self.predict(data=t_x, sess=sess, model=self.input_v[0] + self.output[0], index=0, 419 | argmax=argmax, batch_size=batch_size, ensemble=ensemble) 420 | 421 | if bias >= 0 and self.crf == 0: 422 | prediction = [toolbox.biased_out(prediction[0], bias)] 423 | 424 | predictions = toolbox.decode_tags(prediction, idx2tag) 425 | 426 | if self.is_space == 'sea': 427 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions) 428 | else: 429 | prediction_out, raw_out = toolbox.generate_output(chars, predictions, trans_dict, transducer_dict, 430 | trans_type=trans_type) 
431 | 432 | if sent_seg: 433 | scores = evaluation.evaluator(prediction_out, gold_out, raw_out, t_y_raw) 434 | else: 435 | scores = evaluation.evaluator(prediction_out, gold_out, verbose=True) 436 | 437 | if outpath is not None: 438 | wt = codecs.open(outpath, 'w', encoding='utf-8') 439 | for pre in prediction_out[0]: 440 | wt.write(pre + '\n') 441 | wt.close() 442 | 443 | print 'Evaluation scores:' 444 | if sent_seg: 445 | print 'Sentence segmentation:' 446 | print 'F score: %f' % scores[5] 447 | print 'Precision: %f' % scores[3] 448 | print 'Recall: %f\n' % scores[4] 449 | print 'Word segmentation:' 450 | print 'F score: %f' % scores[2] 451 | print 'Precision: %f' % scores[0] 452 | print 'Recall: %f\n' % scores[1] 453 | else: 454 | print 'Precision: %f' % scores[0] 455 | print 'Recall: %f' % scores[1] 456 | print 'F score: %f' % scores[2] 457 | print 'True negative rate: %f' % scores[3] 458 | 459 | def tag(self, r_x, r_x_raw, idx2tag, idx2char, unk_chars, sub_dict, trans_dict, sess, transducer, ensemble=None, 460 | batch_size=100, outpath=None, sent_seg=False, seg_large=False, form='conll'): 461 | 462 | chars = toolbox.decode_chars(r_x[0], idx2char) 463 | 464 | for i in range(len(r_x[0])): 465 | for j, n in enumerate(r_x[0][i]): 466 | if n in sub_dict: 467 | r_x[0][i][j] = sub_dict[n] 468 | elif n in unk_chars: 469 | r_x[0][i][j] = 1 470 | 471 | c_len = len(r_x[0][0]) 472 | idx = self.bucket_dit[c_len] 473 | 474 | real_batch = batch_size * 300 / c_len 475 | 476 | transducer_dict = None 477 | if transducer is not None: 478 | char2idx = {v: k for k, v in idx2char.items()} 479 | 480 | def transducer_dict(trans_str): 481 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer) 482 | 483 | prediction = self.predict(data=r_x, sess=sess, model=self.input_v[idx] + self.output[idx], index=idx, 484 | argmax=True, batch_size=real_batch, ensemble=ensemble) 485 | 486 | predictions = toolbox.decode_tags(prediction, idx2tag) 487 | 488 | if self.is_space 
== 'sea': 489 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions) 490 | multi_out = prediction_out 491 | else: 492 | prediction_out, raw_out, multi_out = toolbox.generate_output(chars, predictions, trans_dict, 493 | transducer_dict, multi_tok=True) 494 | 495 | pre_out = [] 496 | mut_out = [] 497 | for pre in prediction_out: 498 | pre_out += pre 499 | for mul in multi_out: 500 | mut_out += mul 501 | prediction_out = pre_out 502 | multi_out = mut_out 503 | 504 | if form == 'mlp1' or form == 'mlp2': 505 | prediction_out = toolbox.mlp_post(r_x_raw, prediction_out, self.is_space, form) 506 | 507 | if not seg_large: 508 | toolbox.printer(r_x_raw, prediction_out, multi_out, outpath, sent_seg, form) 509 | 510 | else: 511 | return prediction_out, multi_out 512 | 513 | def predict(self, data, sess, model, index=None, argmax=True, batch_size=100, ensemble=None, verbose=False): 514 | if self.crf: 515 | assert index is not None 516 | predictions = Batch.predict(sess=sess[0], decode_sess=sess[1], model=model, 517 | transitions=[self.transition_char], crf=self.crf, scores=self.scores[index], 518 | decode_holders=self.decode_holders[index], batch_size=batch_size, 519 | data=data, dr=self.drop_out, ensemble=ensemble, verbose=verbose) 520 | else: 521 | predictions = Batch.predict(sess=sess[0], model=model, crf=self.crf, argmax=argmax, batch_size=batch_size, 522 | data=data, dr=self.drop_out, ensemble=ensemble, verbose=verbose) 523 | return predictions 524 | 525 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Yan Shao, yan.shao@lingfil.uu.se 4 | """ 5 | import codecs 6 | 7 | 8 | def gold(path, is_dev=True, form='conll', is_space=False): 9 | sents = [] 10 | sent = [] 11 | cter = 0 12 | sents_dev = None 13 | if not is_dev: 14 | sents_dev = [] 15 | for line in codecs.open(path, 'rb', 
encoding='utf8'): 16 | line = line.strip() 17 | if form == 'conll': 18 | segs = line.split('\t') 19 | if len(segs) == 10: 20 | if '.' not in segs[0]: 21 | sent.append(tuple(segs)) 22 | elif len(sent) > 0: 23 | if not is_dev and cter == 9: 24 | sents_dev.append(sent) 25 | cter = 0 26 | else: 27 | sents.append(sent) 28 | cter += 1 29 | sent = [] 30 | elif form == 'mlp1' or form == 'mlp2': 31 | if len(line) > 0: 32 | if form == 'mlp1': 33 | segs = [] 34 | for l_seg in line.split(' '): 35 | if len(l_seg) > 0: 36 | if is_space == 'sea': 37 | segs.append(l_seg.replace('_', ' ')) 38 | else: 39 | segs += l_seg.split('\\\\') 40 | else: 41 | segs = line.split() 42 | for i, seg in enumerate(segs): 43 | sent.append((str(i + 1), seg)) 44 | if not is_dev and cter == 9: 45 | sents_dev.append(sent) 46 | cter = 0 47 | else: 48 | sents.append(sent) 49 | cter += 1 50 | sent = [] 51 | else: 52 | raise Exception('Format error, available: conll, mlp1, mlp2') 53 | if is_dev: 54 | return sents 55 | else: 56 | return sents, sents_dev 57 | 58 | 59 | def raw(path): 60 | sents = [] 61 | for line in codecs.open(path, 'rb', encoding='utf-8'): 62 | line = line.strip() 63 | sents.append(line) 64 | return sents 65 | 66 | 67 | def get_gold(sent, ignore_mwt=False): 68 | line = '' 69 | nt = -1 70 | mwt = '' 71 | segs = [] 72 | for tk in sent: 73 | if '-' in tk[0]: 74 | if nt == 0: 75 | s_mwt = ''.join(segs) 76 | if ignore_mwt and s_mwt != mwt: 77 | line += ' ' + mwt 78 | else: 79 | for seg in segs: 80 | line += ' ' + seg 81 | mwt = tk[1] 82 | sp = tk[0].split('-') 83 | nt = int(sp[1]) - int(sp[0]) + 1 84 | segs = [] 85 | elif nt == -1: 86 | line += ' ' + tk[1] 87 | elif nt > 0: 88 | segs.append(tk[1]) 89 | nt -= 1 90 | elif nt == 0: 91 | s_mwt = ''.join(segs) 92 | if ignore_mwt and s_mwt != mwt: 93 | line += ' ' + mwt 94 | else: 95 | for seg in segs: 96 | line += ' ' + seg 97 | nt = -1 98 | mwt = '' 99 | segs = [] 100 | line += ' ' + tk[1] 101 | return line.strip() 102 | 103 | 104 | def 
test_gold(path, form='conll', is_space=False, ignore_mwt=False): 105 | sents = [] 106 | sent = [] 107 | st = '' 108 | for line in codecs.open(path, 'rb', encoding='utf-8'): 109 | line = line.strip() 110 | if form == 'conll': 111 | segs = line.split('\t') 112 | if len(segs) == 10: 113 | if '.' not in segs[0]: 114 | sent.append(tuple(segs)) 115 | elif len(sent) > 0: 116 | sents.append(sent) 117 | sent = [] 118 | elif form == 'mlp1' or form == 'mlp2': 119 | if len(line) > 0: 120 | if form == 'mlp1': 121 | segs = [] 122 | for l_seg in line.split(' '): 123 | if is_space == 'sea': 124 | segs.append(l_seg.replace('_', ' ')) 125 | else: 126 | segs += l_seg.split('\\\\') 127 | else: 128 | segs = line.split() 129 | for seg in segs: 130 | st += ' ' + seg 131 | sents.append(st.strip()) 132 | st = '' 133 | else: 134 | raise Exception('Format error, available: conll, mlp1, mlp2') 135 | if form == 'conll': 136 | p_sents = [get_gold(s_sent, ignore_mwt=ignore_mwt) for s_sent in sents] 137 | sents = p_sents 138 | return sents 139 | 140 | 141 | def get_raw(path, fin, fout, cat='other', new=True, is_dev=True, form='conll', is_space=False): 142 | fout = codecs.open(path + '/' + fout, 'w', encoding='utf-8') 143 | fout_dev = None 144 | if not is_dev: 145 | fout_dev = codecs.open(path + '/raw_dev.txt', 'w', encoding='utf-8') 146 | cter = 0 147 | if form == 'conll': 148 | if cat == 'gold': 149 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'): 150 | line = line.strip() 151 | line = line.replace('&apos', '\'') 152 | if len(line) > 0 and ('# sentence' in line or '# text' in line): 153 | if new: 154 | if not is_dev and cter == 9: 155 | fout_dev.write(line[line.index('=') + 1:].lstrip() + '\n') 156 | cter = 0 157 | else: 158 | fout.write(line[line.index('=') + 1:].lstrip() + '\n') 159 | cter += 1 160 | else: 161 | if not is_dev and cter == 9: 162 | fout_dev.write(line[line.index(':') + 1:].lstrip() + '\n') 163 | cter = 0 164 | else: 165 | fout.write(line[line.index(':') + 
1:].lstrip() + '\n') 166 | cter += 1 167 | 168 | elif cat == 'zh': 169 | pt = '' 170 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'): 171 | line = line.strip() 172 | line = line.split('\t') 173 | if len(line) == 10: 174 | pt += line[1] 175 | else: 176 | if len(pt) > 0: 177 | if not is_dev and cter == 9: 178 | fout_dev.write(pt + '\n') 179 | cter = 0 180 | else: 181 | fout.write(pt + '\n') 182 | cter += 1 183 | pt = '' 184 | 185 | else: 186 | punc_e = ['!', ')', ',', '.', ';', ':', '?', '»', '...', ']', '..', '....', '%', 'º', '²', '°'] 187 | punc_b = ['¿', '¡', '(', '«', '['] 188 | punc_m = ['"', '\''] 189 | punc_e = [s.decode('utf-8') for s in punc_e] 190 | punc_b = [s.decode('utf-8') for s in punc_b] 191 | punc_m = [s.decode('utf-8') for s in punc_m] 192 | md = {} 193 | for p in punc_m: 194 | md[p] = True 195 | pt = '' 196 | ct = 0 197 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'): 198 | line = line.strip() 199 | segs = line.split('\t') 200 | if len(segs) == 10: 201 | if '-' in segs[0]: 202 | sp = segs[0].split('-') 203 | ct = int(sp[1]) - int(sp[0]) + 1 204 | if len(pt) > 0 and pt[-1] in punc_b: 205 | pt += segs[1] 206 | elif len(pt) > 0 and pt[-1] in punc_m: 207 | if md[pt[-1]]: 208 | pt += ' ' + segs[1] 209 | else: 210 | pt += segs[1] 211 | else: 212 | pt += ' ' + segs[1] 213 | elif ct == 0: 214 | if segs[1] in punc_e: 215 | pt += segs[1] 216 | elif len(pt) > 0 and pt[-1] in punc_b: 217 | pt += segs[1] 218 | if segs[1] in punc_m: 219 | if md[segs[1]]: 220 | md[segs[1]] = False 221 | else: 222 | md[segs[1]] = True 223 | elif segs[1] in punc_m: 224 | if md[segs[1]]: 225 | pt += ' ' + segs[1] 226 | md[segs[1]] = False 227 | else: 228 | pt += segs[1] 229 | md[segs[1]] = True 230 | elif len(pt) > 0 and pt[-1] in punc_m: 231 | if md[pt[-1]]: 232 | pt += ' ' + segs[1] 233 | else: 234 | pt += segs[1] 235 | elif segs[1][0] == '\'': 236 | pt += segs[1] 237 | else: 238 | pt += ' ' + segs[1] 239 | else: 240 | ct -= 1 241 | else: 242 
| if len(pt) > 0: 243 | pt = pt.lstrip() 244 | pt = pt.replace(' ",', '",') 245 | pt = pt.replace(' ".', '".') 246 | pt = pt.replace(':"...', ': "...') 247 | pt = pt.replace(' n\'t', 'n\'t') 248 | pt = pt.replace(' - ', '-') 249 | pt = pt.replace(' -- ', '--') 250 | pt = pt.replace(' / ', '/') 251 | if not is_dev and cter == 9: 252 | fout_dev.write(pt + '\n') 253 | cter = 0 254 | else: 255 | fout.write(pt + '\n') 256 | cter += 1 257 | pt = '' 258 | for p in punc_m: 259 | md[p] = True 260 | 261 | elif form == 'mlp1' or form == 'mlp2': 262 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'): 263 | line = line.strip() 264 | if len(line) > 0: 265 | if form == 'mlp1': 266 | if is_space == 'sea': 267 | line = line.replace('_', ' ') 268 | else: 269 | line = line.replace('\\\\', '') 270 | else: 271 | line = ''.join(line.split()) 272 | if not is_dev and cter == 9: 273 | fout_dev.write(line + '\n') 274 | cter = 0 275 | else: 276 | fout.write(line + '\n') 277 | cter += 1 278 | else: 279 | raise Exception('Format error, available: conll, mlp1, mlp2') 280 | fout.close() 281 | if not is_dev: 282 | fout_dev.close() 283 | -------------------------------------------------------------------------------- /segmenter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Yan Shao, yan.shao@lingfil.uu.se 4 | """ 5 | import reader 6 | import toolbox 7 | from model import Model 8 | from transducer_model import Seq2seq 9 | import sys 10 | import argparse 11 | import os 12 | import codecs 13 | import tensorflow as tf 14 | import cPickle as pickle 15 | 16 | from time import time 17 | 18 | parser = argparse.ArgumentParser(description='A Universal Tokeniser. Written by Y. 
Shao, Uppsala University') 19 | parser.add_argument('action', default='tag', choices=['train', 'test', 'tag'], help='train, test or tag') 20 | 21 | parser.add_argument('-f', '--format', default='conll', help='Data format of different tasks, conll, mlp1 or mlp2') 22 | 23 | parser.add_argument('-p', '--path', default=None, help='Path of the workstation') 24 | 25 | parser.add_argument('-t', '--train', default=None, help='File for training') 26 | parser.add_argument('-d', '--dev', default=None, help='File for validation') 27 | parser.add_argument('-e', '--test', default=None, help='File for evaluation') 28 | parser.add_argument('-r', '--raw', default=None, help='Raw file for tagging') 29 | 30 | parser.add_argument('-m', '--model', default='trained_model', help='Name of the trained model') 31 | parser.add_argument('-crf', '--crf', default=1, type=int, help='Using CRF interface') 32 | 33 | parser.add_argument('-bt', '--bucket_size', default=50, type=int, help='Bucket size') 34 | parser.add_argument('-sl', '--sent_limit', default=300, type=int, help='Long sentences will be chopped') 35 | 36 | parser.add_argument('-tg', '--tags', default='BIES', help='Boundary Tagging, default is BIES') 37 | 38 | parser.add_argument('-ed', '--emb_dimension', default=50, type=int, help='Dimension of the embeddings') 39 | parser.add_argument('-emb', '--embeddings', default=None, help='Path and name of pre-trained char embeddings') 40 | 41 | parser.add_argument('-ng', '--ngram', default=1, type=int, help='Using ngrams') 42 | 43 | parser.add_argument('-cell', '--cell', default='gru', help='Use GRU as the recurrent cell', choices=['gru', 'lstm']) 44 | parser.add_argument('-rnn', '--rnn_cell_dimension', default=200, type=int, help='Dimension of the RNN cells') 45 | parser.add_argument('-layer', '--rnn_layer_number', default=1, type=int, help='Numbers of the RNN layers') 46 | 47 | parser.add_argument('-dr', '--dropout_rate', default=0.5, type=float, help='Dropout rate') 48 | 49 | 
parser.add_argument('-iter', '--epochs', default=30, type=int, help='Numbers of epochs') 50 | parser.add_argument('-iter_trans', '--epochs_trans', default=50, type=int, help='Epochs for training the transducer') 51 | 52 | parser.add_argument('-op', '--optimizer', default='adagrad', help='Optimizer') 53 | parser.add_argument('-lr', '--learning_rate', default=0.2, type=float, help='Initial learning rate') 54 | parser.add_argument('-lr_trans', '--learning_rate_trans', default=0.3, type=float, help='Initial learning rate') 55 | parser.add_argument('-ld', '--decay_rate', default=0.05, type=float, help='Learning rate decay') 56 | parser.add_argument('-mt', '--momentum', default=None, type=float, help='Momentum') 57 | 58 | parser.add_argument('-ncp', '--no_clipping', default=False, action='store_true', help='Do not apply gradient clipping') 59 | 60 | parser.add_argument("-tb","--train_batch", help="Training batch size", default=10, type=int) 61 | parser.add_argument("-eb","--test_batch", help="Testing batch size", default=500, type=int) 62 | parser.add_argument("-rb","--tag_batch", help="Tagging batch size", default=500, type=int) 63 | 64 | parser.add_argument("-g","--gpu", help="the id of gpu, the default is 0", default=0, type=int) 65 | 66 | parser.add_argument('-opth', '--output_path', default=None, help='Output path') 67 | 68 | parser.add_argument('-sea', '--sea', help='Process languages like Vietamese', default=False, action='store_true') 69 | 70 | parser.add_argument('-ss', '--sent_seg', help='Perform sentence seg', default=False, action='store_true') 71 | 72 | parser.add_argument('-ens', '--ensemble', default=False, help='Ensemble several weights', action='store_true') 73 | 74 | parser.add_argument('-sgl', '--segment_large', default=False, help='Segment (very) large file', action='store_true') 75 | 76 | parser.add_argument('-lgs', '--large_size', default=10000, type=int, help='Segment (very) large file') 77 | 78 | parser.add_argument('-ot', '--only_tokenised', 
default=False, 79 | help='Only output the tokenised file when segment (very) large file', action='store_true') 80 | 81 | parser.add_argument('-ts', '--train_size', default=-1, type=int, help='No. of sentences used for training') 82 | 83 | parser.add_argument('-rs', '--reset', default=False, help='Delete and re-initialise the intermediate files', 84 | action='store_true') 85 | parser.add_argument('-rst', '--reset_trans', default=False, help='Retrain the transducers', action='store_true') 86 | 87 | parser.add_argument('-isp', '--ignore_space', default=False, help='Ignore space delimiters', action='store_true') 88 | parser.add_argument('-imt', '--ignore_mwt', default=False, help='Ignore multi-word tokens to be transcribed', 89 | action='store_true') 90 | 91 | parser.add_argument('-sb', '--segmentation_bias', default=-1, type=float, 92 | help='Add segmentation bias to under(over)-splitting') 93 | 94 | parser.add_argument('-tt', '--transduction_type', default='mix', choices=['mix', 'dict', 'trans', 'none'], 95 | help='Different ways of transducing the non-segmental MWTs') 96 | 97 | args = parser.parse_args() 98 | 99 | sys = reload(sys) 100 | sys.setdefaultencoding('utf-8') 101 | print 'Encoding: ', sys.getdefaultencoding() 102 | 103 | if args.action == 'train': 104 | assert args.path is not None 105 | path = args.path 106 | train_file = args.train 107 | dev_file = args.dev 108 | model_file = args.model 109 | print 'Reading data......' 
110 | f_names = os.listdir(path) 111 | if train_file is None or dev_file is None: 112 | for f_n in f_names: 113 | if 'ud-train.conllu' in f_n or 'training.segd' in f_n or 'ud-sample.conllu' in f_n: 114 | train_file = f_n 115 | elif 'ud-dev.conllu' in f_n or 'development.segd' in f_n: 116 | dev_file = f_n 117 | assert train_file is not None 118 | is_space = True 119 | if 'Chinese' in path or 'Japanese' in path or args.format == 'mlp2': 120 | is_space = False 121 | 122 | if args.sea: 123 | is_space = 'sea' 124 | if args.reset or not os.path.isfile(path + '/raw_train.txt') or not os.path.isfile(path + '/raw_dev.txt'): 125 | cat = 'other' 126 | if 'Chinese' in path or 'Japanese' in path: 127 | cat = 'zh' 128 | for line in codecs.open(path + '/' + train_file, 'r', encoding='utf-8'): 129 | if len(line) < 2: 130 | break 131 | if '# sentence' in line or '# text' in line: 132 | cat = 'gold' 133 | 134 | if dev_file is None: 135 | reader.get_raw(path, train_file, '/raw_train.txt', cat, is_dev=False, form=args.format, is_space=is_space) 136 | else: 137 | reader.get_raw(path, train_file, '/raw_train.txt', cat, form=args.format, is_space=is_space) 138 | reader.get_raw(path, dev_file, '/raw_dev.txt', cat, form=args.format, is_space=is_space) 139 | 140 | if args.reset or not os.path.isfile(path + '/tag_train.txt') or not os.path.isfile(path + '/tag_dev.txt') or \ 141 | not os.path.isfile(path + '/tag_dev_gold.txt'): 142 | if dev_file is None: 143 | raws_train = reader.raw(path + '/raw_train.txt') 144 | raws_dev = reader.raw(path + '/raw_dev.txt') 145 | sents_train, sents_dev = reader.gold(path + '/' + train_file, False, form=args.format, is_space=is_space) 146 | else: 147 | raws_train = reader.raw(path + '/raw_train.txt') 148 | sents_train = reader.gold(path + '/' + train_file, form=args.format, is_space=is_space) 149 | 150 | raws_dev = reader.raw(path + '/raw_dev.txt') 151 | sents_dev = reader.gold(path + '/' + dev_file, form=args.format, is_space=is_space) 152 | 153 | if 
is_space != 'sea': 154 | toolbox.raw2tags(raws_train, sents_train, path, 'tag_train.txt', ignore_space=args.ignore_space, 155 | reset=args.reset, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt) 156 | toolbox.raw2tags(raws_dev, sents_dev, path, 'tag_dev.txt', creat_dict=False, gold_path='tag_dev_gold.txt', 157 | ignore_space=args.ignore_space, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt) 158 | else: 159 | toolbox.raw2tags_sea(raws_train, sents_train, path, 'tag_train.txt', reset=args.reset, tag_scheme=args.tags) 160 | toolbox.raw2tags_sea(raws_dev, sents_dev, path, 'tag_dev.txt', gold_path='tag_dev_gold.txt', 161 | tag_scheme=args.tags) 162 | 163 | if args.reset or not os.path.isfile(path + '/chars.txt'): 164 | toolbox.get_chars(path, ['raw_train.txt', 'raw_dev.txt'], sea=is_space) 165 | 166 | char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, args.sent_seg, args.tags, 167 | args.crf) 168 | 169 | if args.embeddings is not None: 170 | print 'Reading embeddings...' 
171 | short_emb = args.embeddings[args.embeddings.index('/') + 1: args.embeddings.index('.')] 172 | if args.reset or not os.path.isfile(path + '/' + short_emb + '_sub.txt'): 173 | toolbox.get_sample_embedding(path, args.embeddings, char2idx) 174 | emb_dim, emb, valid_chars = toolbox.read_sample_embedding(path, short_emb, char2idx) 175 | for vch in valid_chars: 176 | if char2idx[vch] in unk_chars_idx: 177 | unk_chars_idx.remove(char2idx[vch]) 178 | else: 179 | emb_dim = args.emb_dimension 180 | emb = None 181 | 182 | train_x, train_y, max_len_train = toolbox.get_input_vec(path, 'tag_train.txt', char2idx, tag2idx, 183 | limit=args.sent_limit, sent_seg=args.sent_seg, 184 | is_space=is_space, train_size=args.train_size, 185 | ignore_space=args.ignore_space) 186 | 187 | dev_x, max_len_dev = toolbox.get_input_vec_raw(path, 'raw_dev.txt', char2idx, limit=args.sent_limit, 188 | sent_seg=args.sent_seg, is_space=is_space, 189 | ignore_space=args.ignore_space) 190 | if args.sent_seg: 191 | print 'Joint sentence segmentation...' 192 | else: 193 | print 'Training set: %d instances; Dev set: %d instances.' 
% (len(train_x[0]), len(dev_x[0])) 194 | 195 | nums_grams = None 196 | ng_embs = None 197 | 198 | if args.ngram > 1 and (args.reset or not os.path.isfile(path + '/' + str(args.ngram) + 'gram.txt')): 199 | toolbox.get_ngrams(path, args.ngram, is_space) 200 | 201 | ngram = toolbox.read_ngrams(path, args.ngram) 202 | 203 | if args.ngram > 1: 204 | gram2idx = toolbox.get_ngram_dic(ngram) 205 | train_gram = toolbox.get_gram_vec(path, 'tag_train.txt', gram2idx, limit=args.sent_limit,sent_seg=args.sent_seg, 206 | is_space=is_space, ignore_space=args.ignore_space) 207 | dev_gram = toolbox.get_gram_vec(path, 'raw_dev.txt', gram2idx, is_raw=True, limit=args.sent_limit, 208 | sent_seg=args.sent_seg, is_space=is_space, ignore_space=args.ignore_space) 209 | train_x += train_gram 210 | dev_x += dev_gram 211 | nums_grams = [] 212 | for dic in gram2idx: 213 | nums_grams.append(len(dic.keys())) 214 | 215 | max_len = max(max_len_train, max_len_dev) 216 | 217 | b_train_x, b_train_y = toolbox.buckets(train_x, train_y, size=args.bucket_size) 218 | b_train_x, b_train_y, b_lens, b_count = toolbox.pad_bucket(b_train_x, b_train_y, max_len) 219 | 220 | b_dev_x = [toolbox.pad_zeros(dev_x_i, max_len) for dev_x_i in dev_x] 221 | 222 | b_dev_y_gold = [line.strip() for line in codecs.open(path + '/tag_dev_gold.txt', 'r', encoding='utf-8')] 223 | 224 | nums_tag = len(tag2idx) 225 | 226 | config = tf.ConfigProto(allow_soft_placement=True) 227 | gpu_config = "/gpu:" + str(args.gpu) 228 | 229 | transducer = None 230 | transducer_graph = None 231 | trans_model = None 232 | trans_init = None 233 | 234 | if len(trans_dict) > 200 and not args.ignore_mwt: 235 | transducer = toolbox.get_dict_vec(trans_dict, char2idx) 236 | t = time() 237 | 238 | initializer = tf.contrib.layers.xavier_initializer() 239 | 240 | if transducer is not None: 241 | transducer_graph = tf.Graph() 242 | with transducer_graph.as_default(): 243 | with tf.variable_scope("transducer") as scope: 244 | trans_model = Seq2seq(path + '/' + 
model_file + '_transducer') 245 | print 'Defining transducer...' 246 | trans_model.define(char_num=len(char2idx), rnn_dim=args.rnn_cell_dimension, emb_dim=args.emb_dimension, 247 | max_x=len(transducer[0][0]), max_y=len(transducer[1][0])) 248 | trans_init = tf.global_variables_initializer() 249 | transducer_graph.finalize() 250 | 251 | print 'Initialization....' 252 | main_graph = tf.Graph() 253 | with main_graph.as_default(): 254 | with tf.variable_scope("tagger") as scope: 255 | model = Model(nums_chars=len(char2idx) + 2, nums_tags=nums_tag, buckets_char=b_lens, counts=b_count, 256 | crf=args.crf, ngram=nums_grams, batch_size=args.train_batch, sent_seg=args.sent_seg, 257 | is_space=is_space, emb_path=args.embeddings, tag_scheme=args.tags) 258 | 259 | model.main_graph(trained_model=path + '/' + model_file + '_model', scope=scope, 260 | emb_dim=emb_dim, cell=args.cell, rnn_dim=args.rnn_cell_dimension, 261 | rnn_num=args.rnn_layer_number, drop_out=args.dropout_rate, emb=emb) 262 | t = time() 263 | 264 | model.config(optimizer=args.optimizer, decay=args.decay_rate, lr_v=args.learning_rate, 265 | momentum=args.momentum, clipping=not args.no_clipping) 266 | init = tf.global_variables_initializer() 267 | 268 | print 'Done. Time consumed: %d seconds' % int(time() - t) 269 | 270 | main_graph.finalize() 271 | 272 | main_sess = tf.Session(config=config, graph=main_graph) 273 | 274 | if args.crf > 0: 275 | decode_graph = tf.Graph() 276 | with decode_graph.as_default(): 277 | model.decode_graph() 278 | decode_graph.finalize() 279 | 280 | decode_sess = tf.Session(config=config, graph=decode_graph) 281 | 282 | sess = [main_sess, decode_sess] 283 | 284 | else: 285 | sess = [main_sess, None] 286 | 287 | with tf.device(gpu_config): 288 | 289 | if transducer is not None: 290 | print 'Building transducer...' 
291 | t = time() 292 | trans_sess = tf.Session(config=config, graph=transducer_graph) 293 | trans_sess.run(trans_init) 294 | trans_model.train(transducer[0], transducer[1], transducer[2], transducer[3], args.learning_rate_trans, 295 | char2idx, trans_sess, args.epochs_trans, batch_size=10, reset=args.reset_trans) 296 | sess.append(trans_sess) 297 | print 'Done. Time consumed: %d seconds' % int(time() - t) 298 | print 'Training the main segmenter..' 299 | main_sess.run(init) 300 | print 'Initialisation...' 301 | print 'Done. Time consumed: %d seconds' % int(time() - t) 302 | t = time() 303 | b_dev_raw = [line.strip() for line in codecs.open(path + '/raw_dev.txt', 'r', encoding='utf-8')] 304 | model.train(b_train_x, b_train_y, b_dev_x, b_dev_raw, b_dev_y_gold, idx2tag, idx2char, unk_chars_idx, trans_dict, 305 | sess, args.epochs, path + '/' + model_file + '_weights', transducer=trans_model, 306 | lr=args.learning_rate, decay=args.decay_rate, sent_seg=args.sent_seg, outpath=args.output_path) 307 | 308 | else: 309 | 310 | assert args.path is not None 311 | assert args.model is not None 312 | path = args.path 313 | assert os.path.isfile(path + '/chars.txt') 314 | 315 | model_file = args.model 316 | 317 | if args.ensemble: 318 | if not os.path.isfile(path + '/' + model_file + '_1_model') or not os.path.isfile(path + '/' + model_file + 319 | '_1_weights.index'): 320 | raise Exception('Not any model file or weights file under the name of ' + model_file + '.') 321 | fin = open(path + '/' + model_file + '_1_model', 'rb') 322 | else: 323 | if not os.path.isfile(path + '/' + model_file + '_model') or not os.path.isfile(path + '/' + model_file + 324 | '_weights.index'): 325 | raise Exception('No model file or weights file under the name of ' + model_file + '.') 326 | fin = open(path + '/' + model_file + '_model', 'rb') 327 | 328 | weight_path = path + '/' + model_file 329 | 330 | param_dic = pickle.load(fin) 331 | fin.close() 332 | 333 | nums_chars = param_dic['nums_chars'] 334 
| nums_tags = param_dic['nums_tags'] 335 | crf = param_dic['crf'] 336 | emb_dim = param_dic['emb_dim'] 337 | cell = param_dic['cell'] 338 | rnn_dim = param_dic['rnn_dim'] 339 | rnn_num = param_dic['rnn_num'] 340 | drop_out = param_dic['drop_out'] 341 | buckets_char = param_dic['buckets_char'] 342 | nums_ngrams = param_dic['ngram'] 343 | is_space = param_dic['is_space'] 344 | sent_seg = param_dic['sent_seg'] 345 | emb_path = param_dic['emb_path'] 346 | tag_scheme = param_dic['tag_scheme'] 347 | 348 | if args.embeddings is not None: 349 | emb_path = args.embeddings 350 | 351 | ngram = 1 352 | grams, gram2idx = None, None 353 | if nums_ngrams is not None: 354 | ngram = len(nums_ngrams) + 1 355 | 356 | char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, sent_seg, tag_scheme, crf) 357 | 358 | trans_char_num = len(char2idx) 359 | 360 | if ngram > 1: 361 | grams = toolbox.read_ngrams(path, ngram) 362 | 363 | new_chars, new_grams = None, None 364 | 365 | test_x, test_y, raw_x, test_y_gold = None, None, None, None 366 | 367 | sub_dict = None 368 | 369 | max_step = None 370 | 371 | raw_file = None 372 | 373 | if args.action == 'test': 374 | test_file = args.test 375 | f_names = os.listdir(path) 376 | if test_file is None: 377 | for f_n in f_names: 378 | if 'ud-test.conllu' in f_n: 379 | test_file = f_n 380 | assert test_file is not None 381 | 382 | cat = 'other' 383 | if 'Chinese' in path or 'Japanese' in path: 384 | cat = 'zh' 385 | for line in codecs.open(path + '/' + test_file, 'r', encoding='utf-8'): 386 | if len(line) < 2: 387 | break 388 | if '# sentence' in line or '# text' in line: 389 | cat = 'gold' 390 | reader.get_raw(path, test_file, 'raw_test.txt', cat, form=args.format) 391 | 392 | raws_test = reader.raw(path + '/raw_test.txt') 393 | test_y_gold = reader.test_gold(path + '/' + test_file, form=args.format, is_space=is_space, 394 | ignore_mwt=args.ignore_mwt) 395 | 396 | new_chars = toolbox.get_new_chars(path + 
'/raw_test.txt', char2idx, is_space) 397 | 398 | if emb_path is not None: 399 | valid_chars = toolbox.get_valid_chars(new_chars + char2idx.keys(), emb_path) 400 | else: 401 | valid_chars = None 402 | 403 | char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars) 404 | 405 | test_x, max_len_test = toolbox.get_input_vec_raw(path, 'raw_test.txt', char2idx, limit=args.sent_limit + 100, 406 | sent_seg=sent_seg, is_space=is_space, 407 | ignore_space=args.ignore_space) 408 | 409 | max_step = max_len_test 410 | 411 | if sent_seg: 412 | print 'Joint sentence segmentation...' 413 | else: 414 | print 'Test set: %d instances.' % len(test_x[0]) 415 | 416 | if ngram > 1: 417 | gram2idx = toolbox.get_ngram_dic(grams) 418 | new_grams = toolbox.get_new_grams(path + '/' + test_file, gram2idx, is_space=is_space) 419 | 420 | test_grams = toolbox.get_gram_vec(path, 'raw_test.txt', gram2idx, is_raw=True, limit=args.sent_limit + 100, 421 | sent_seg=sent_seg, is_space=is_space, ignore_space=args.ignore_space) 422 | test_x += test_grams 423 | 424 | for k in range(len(test_x)): 425 | test_x[k] = toolbox.pad_zeros(test_x[k], max_step) 426 | 427 | elif args.action == 'tag': 428 | assert args.raw is not None 429 | 430 | raw_file = args.raw 431 | new_chars = toolbox.get_new_chars(raw_file, char2idx, is_space) 432 | 433 | if emb_path is not None: 434 | valid_chars = toolbox.get_valid_chars(new_chars, emb_path) 435 | else: 436 | valid_chars = None 437 | 438 | char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx, 439 | valid_chars) 440 | 441 | if not args.segment_large: 442 | 443 | if sent_seg: 444 | raw_x, raw_len = toolbox.get_input_vec_tag(None, raw_file, char2idx, limit=args.sent_limit + 100, 445 | is_space=is_space) 446 | else: 447 | raw_x, raw_len = toolbox.get_input_vec_raw(None, raw_file, char2idx, limit=args.sent_limit + 100, 448 | sent_seg=sent_seg, is_space=is_space) 449 | 
450 | if sent_seg: 451 | print 'Joint sentence segmentation...' 452 | else: 453 | print 'Raw setences: %d instances.' % len(raw_x[0]) 454 | 455 | max_step = raw_len 456 | 457 | else: 458 | 459 | max_step = args.sent_limit 460 | 461 | if ngram > 1: 462 | gram2idx = toolbox.get_ngram_dic(grams) 463 | new_grams = toolbox.get_new_grams(raw_file, gram2idx, is_raw=True, is_space=is_space) 464 | 465 | if not args.segment_large: 466 | if sent_seg: 467 | raw_grams = toolbox.get_gram_vec_tag(None, raw_file, gram2idx, limit=args.sent_limit + 100, 468 | is_space=is_space) 469 | else: 470 | raw_grams = toolbox.get_gram_vec(None, raw_file, gram2idx, is_raw=True, limit=args.sent_limit + 100, 471 | sent_seg=sent_seg, is_space=is_space) 472 | 473 | raw_x += raw_grams 474 | 475 | if not args.segment_large: 476 | for k in range(len(raw_x)): 477 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step) 478 | 479 | config = tf.ConfigProto(allow_soft_placement=True) 480 | gpu_config = "/gpu:" + str(args.gpu) 481 | 482 | transducer = None 483 | transducer_graph = None 484 | trans_model = None 485 | trans_init = None 486 | 487 | if len(trans_dict) > 200: 488 | transducer = toolbox.get_dict_vec(trans_dict, char2idx) 489 | t = time() 490 | 491 | initializer = tf.contrib.layers.xavier_initializer() 492 | 493 | if transducer is not None: 494 | transducer_graph = tf.Graph() 495 | with transducer_graph.as_default(): 496 | with tf.variable_scope("transducer") as scope: 497 | trans_model = Seq2seq(path + '/' + model_file + '_transducer') 498 | trans_fin = open(path + '/' + model_file + '_transducer_model', 'rb') 499 | trans_param_dic = pickle.load(trans_fin) 500 | trans_fin.close() 501 | 502 | tr_char_num = trans_param_dic['char_num'] 503 | tr_rnn_dim = trans_param_dic['rnn_dim'] 504 | tr_emb_dim = trans_param_dic['emb_dim'] 505 | tr_max_x = trans_param_dic['max_x'] 506 | tr_max_y = trans_param_dic['max_y'] 507 | 508 | print 'Defining transducer...' 
509 | trans_model.define(char_num=tr_char_num, rnn_dim=tr_rnn_dim, emb_dim=tr_emb_dim, 510 | max_x=tr_max_x, max_y=tr_max_y, write_trans_model=False) 511 | trans_init = tf.global_variables_initializer() 512 | transducer_graph.finalize() 513 | 514 | print 'Initialization....' 515 | main_graph = tf.Graph() 516 | with main_graph.as_default(): 517 | with tf.variable_scope("tagger") as scope: 518 | model = Model(nums_chars=nums_chars, nums_tags=nums_tags, buckets_char=[max_step], counts=[200], 519 | crf=crf, ngram=nums_ngrams, batch_size=args.tag_batch, is_space=is_space) 520 | 521 | model.main_graph(trained_model=None, scope=scope, emb_dim=emb_dim, cell=cell, 522 | rnn_dim=rnn_dim, rnn_num=rnn_num, drop_out=drop_out) 523 | 524 | model.define_updates(new_chars=new_chars, emb_path=emb_path, char2idx=char2idx) 525 | 526 | init = tf.global_variables_initializer() 527 | 528 | print 'Done. Time consumed: %d seconds' % int(time() - t) 529 | main_graph.finalize() 530 | 531 | idx=None 532 | 533 | if args.ensemble: 534 | idx = 1 535 | main_sess = [] 536 | while os.path.isfile(path + '/' + model_file + '_' + str(idx) + '_weights.index'): 537 | main_sess.append(tf.Session(config=config, graph=main_graph)) 538 | idx += 1 539 | else: 540 | main_sess = tf.Session(config=config, graph=main_graph) 541 | 542 | if crf: 543 | decode_graph = tf.Graph() 544 | 545 | with decode_graph.as_default(): 546 | model.decode_graph() 547 | decode_graph.finalize() 548 | 549 | decode_sess = tf.Session(config=config, graph=decode_graph) 550 | 551 | sess = [main_sess, decode_sess] 552 | 553 | else: 554 | sess = [main_sess, None] 555 | 556 | with tf.device(gpu_config): 557 | ens_model = None 558 | print 'Loading weights....' 
559 | if args.ensemble: 560 | for i in range(1, idx): 561 | print 'Ensemble: ' + str(i) 562 | main_sess[i - 1].run(init) 563 | model.run_updates(main_sess[i - 1], weight_path + '_' + str(i) + '_weights') 564 | else: 565 | main_sess.run(init) 566 | model.run_updates(main_sess, weight_path + '_weights') 567 | 568 | if transducer is not None: 569 | print 'Loading transducer...' 570 | t = time() 571 | trans_sess = tf.Session(config=config, graph=transducer_graph) 572 | trans_sess.run(trans_init) 573 | if os.path.isfile(path + '/' + model_file + '_transducer_weights'): 574 | trans_weight_path = path + '/' + model_file + '_transducer_weights' 575 | trans_weight_path = trans_weight_path.replace('//', '/') 576 | trans_model.saver.restore(trans_sess, trans_weight_path) 577 | sess.append(trans_sess) 578 | 579 | if args.action == 'test': 580 | test_y_raw = [line.strip() for line in codecs.open(path + '/raw_test.txt', 'rb', encoding='utf-8')] 581 | model.test(test_x, test_y_raw, test_y_gold, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess, 582 | transducer=trans_model, ensemble=args.ensemble, batch_size=args.test_batch, sent_seg=sent_seg, 583 | bias=args.segmentation_bias, outpath=args.output_path, trans_type=args.transduction_type) 584 | 585 | if args.action == 'tag': 586 | 587 | if not args.segment_large: 588 | raw_sents = [] 589 | for line in codecs.open(raw_file, 'rb', encoding='utf-8'): 590 | line = line.strip() 591 | if len(line) > 0: 592 | raw_sents.append(line) 593 | model.tag(raw_x, raw_sents, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess, 594 | transducer=trans_model, outpath=args.output_path, ensemble=args.ensemble, 595 | batch_size=args.tag_batch, sent_seg=sent_seg, seg_large=args.segment_large, form=args.format) 596 | else: 597 | count = 0 598 | c_line = 0 599 | l_writer = codecs.open(args.output_path, 'w', encoding='utf-8') 600 | out = [] 601 | with codecs.open(raw_file, 'r', encoding='utf-8') as l_file: 602 | lines = [] 603 | for line 
in l_file: 604 | line = line.strip() 605 | if len(line) > 0: 606 | lines.append(line) 607 | else: 608 | c_line += 1 609 | if c_line >= args.large_size: 610 | count += len(lines) 611 | c_line = 0 612 | print count 613 | if args.sent_seg: 614 | raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines, 615 | limit=args.sent_limit, is_space=is_space) 616 | else: 617 | raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines, 618 | limit=args.sent_limit, sent_seg=sent_seg, 619 | is_space=is_space) 620 | if ngram > 1: 621 | if sent_seg: 622 | raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines, 623 | limit=args.sent_limit, is_space=is_space) 624 | else: 625 | raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True, 626 | limit=args.sent_limit, sent_seg=sent_seg, 627 | is_space=is_space) 628 | raw_x += raw_grams 629 | 630 | for k in range(len(raw_x)): 631 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step) 632 | 633 | predition, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict, 634 | trans_dict, sess, transducer=trans_model, 635 | outpath=args.output_path, ensemble=args.ensemble, 636 | batch_size=args.tag_batch, sent_seg=sent_seg, 637 | seg_large=args.segment_large, form=args.format) 638 | 639 | if args.only_tokenised: 640 | for l_out in predition: 641 | if len(l_out.strip()) > 0: 642 | l_writer.write(l_out + '\n') 643 | else: 644 | for tagged_t, multi_t in zip(predition, multi): 645 | if len(tagged_t.strip()) > 0: 646 | l_writer.write('#sent_tok: ' + tagged_t + '\n') 647 | idx = 1 648 | tgs = multi_t.split(' ') 649 | pl = '' 650 | for _ in range(8): 651 | pl += '\t' + '_' 652 | for tg in tgs: 653 | if '!#!' 
in tg: 654 | segs = tg.split('!#!') 655 | l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + 656 | segs[0] + pl + '\n') 657 | else: 658 | l_writer.write(str(idx) + '\t' + tg + pl + '\n') 659 | idx += 1 660 | l_writer.write('\n') 661 | lines = [] 662 | if len(lines) > 0: 663 | 664 | if args.sent_seg: 665 | raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines, 666 | limit=args.sent_limit, is_space=is_space) 667 | else: 668 | raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines, 669 | limit=args.sent_limit, sent_seg=sent_seg, 670 | is_space=is_space) 671 | if ngram > 1: 672 | if sent_seg: 673 | raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines, 674 | limit=args.sent_limit, is_space=is_space) 675 | else: 676 | raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True, 677 | limit=args.sent_limit, sent_seg=sent_seg, 678 | is_space=is_space) 679 | raw_x += raw_grams 680 | 681 | for k in range(len(raw_x)): 682 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step) 683 | 684 | prediction, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict, 685 | trans_dict, sess, transducer=trans_model, 686 | outpath=args.output_path, ensemble=args.ensemble, 687 | batch_size=args.tag_batch, sent_seg=sent_seg, 688 | seg_large=args.segment_large, form=args.format) 689 | 690 | if args.only_tokenised: 691 | for l_out in prediction: 692 | if len(l_out.strip()) > 0: 693 | l_writer.write(l_out + '\n') 694 | else: 695 | for tagged_t, multi_t in zip(prediction, multi): 696 | if len(tagged_t.strip()) > 0: 697 | l_writer.write('#sent_tok: ' + tagged_t + '\n') 698 | idx = 1 699 | tgs = multi_t.split(' ') 700 | pl = '' 701 | for _ in range(8): 702 | pl += '\t' + '_' 703 | for tg in tgs: 704 | if '!#!' 
in tg: 705 | segs = tg.split('!#!') 706 | l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + 707 | segs[0] + pl + '\n') 708 | else: 709 | l_writer.write(str(idx) + '\t' + tg + pl + '\n') 710 | idx += 1 711 | l_writer.write('\n') 712 | l_writer.close() 713 | 714 | print 'Done.' 715 | -------------------------------------------------------------------------------- /toolbox.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: Yan Shao, yan.shao@lingfil.uu.se 4 | """ 5 | import codecs 6 | import sys 7 | import numpy as np 8 | import random 9 | import os 10 | import math 11 | from reader import get_gold 12 | 13 | sys = reload(sys) 14 | sys.setdefaultencoding('utf-8') 15 | 16 | punc = ['!', ')', ',', '.', ';', ':', '?', '»', '...', '..', '....', '%', 'º', '²', '°', '¿', '¡', '(', '«', 17 | '"', '\'', '-', '。', '·', '।', '۔'] 18 | 19 | 20 | def pre_token(line): 21 | out = [] 22 | for seg in line.split(' '): 23 | f_out = [] 24 | b_out = [] 25 | while len(seg) > 0 and (seg[0] in punc or ('0' <= seg[0] <= '9')): 26 | f_out.append(seg[0]) 27 | seg = seg[1:] 28 | while len(seg) > 0 and (seg[-1] in punc or ('0' <= seg[-1] <= '9')): 29 | b_out = [seg[-1]] + b_out 30 | seg = seg[:-1] 31 | if len(seg) > 0: 32 | out += f_out + [seg] + b_out 33 | else: 34 | out += f_out + b_out 35 | return out 36 | 37 | 38 | def get_chars(path, filelist, sea=False): 39 | char_set = {} 40 | out_char = codecs.open(path + '/chars.txt', 'w', encoding='utf-8') 41 | for i, file_name in enumerate(filelist): 42 | for line in codecs.open(path + '/' + file_name, 'rb', encoding='utf-8'): 43 | line = line.strip() 44 | if sea=='sea': 45 | line = pre_token(line) 46 | for ch in line: 47 | if ch in char_set: 48 | if i == 0: 49 | char_set[ch] += 1 50 | else: 51 | char_set[ch] = 1 52 | for k, v in char_set.items(): 53 | out_char.write(k + '\t' + str(v) + '\n') 54 | out_char.close() 55 | 56 | 57 | def get_dicts(path, 
sent_seg, tag_scheme='BIES', crf=1): 58 | char2idx = {'

': 0, '': 1, '<#>': 2} 59 | unk_chars_idx = [] 60 | idx = 3 61 | for line in codecs.open(path + '/chars.txt', 'r', encoding='utf-8'): 62 | segs = line.split('\t') 63 | if len(segs[0].strip()) == 0: 64 | if ' ' not in char2idx: 65 | char2idx[' '] = idx 66 | idx += 1 67 | else: 68 | char2idx[segs[0]] = idx 69 | if int(segs[1]) == 1: 70 | unk_chars_idx.append(idx) 71 | idx += 1 72 | idx2char = {k: v for v, k in char2idx.items()} 73 | if tag_scheme == 'BI': 74 | if crf > 0: 75 | tag2idx = {'

': 0, 'B': 1, 'I': 2} 76 | idx = 3 77 | else: 78 | tag2idx = {'B': 0, 'I': 1} 79 | idx = 2 80 | else: 81 | if crf > 0: 82 | tag2idx = {'

': 0, 'B': 1, 'I': 2, 'E': 3, 'S': 4} 83 | idx = 5 84 | else: 85 | tag2idx = {'B': 0, 'I':1, 'E':2, 'S':3} 86 | idx = 4 87 | for line in codecs.open(path + '/tags.txt', 'r', encoding='utf-8'): 88 | line = line.strip() 89 | if line not in tag2idx: 90 | tag2idx[line] = idx 91 | idx += 1 92 | if sent_seg: 93 | tag2idx['T'] = idx 94 | tag2idx['U'] = idx + 1 95 | idx2tag = {k: v for v, k in tag2idx.items()} 96 | 97 | trans_dict = {} 98 | key = '' 99 | if os.path.isfile(path + '/dict.txt'): 100 | for line in codecs.open(path + '/dict.txt', 'r', encoding='utf-8'): 101 | line = line.strip() 102 | if len(line) > 0: 103 | segs = line.split('\t') 104 | if len(segs) == 1: 105 | key = segs[0] 106 | trans_dict[key] = None 107 | elif len(segs) == 2: 108 | if trans_dict[key] is None: 109 | trans_dict[key] = segs[0].replace(' ', ' ') 110 | 111 | return char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict 112 | 113 | 114 | def ngrams(raw, gram, is_space): 115 | gram_set = {} 116 | li = gram/2 117 | ri = gram - li - 1 118 | p = '' 119 | last_line = '' 120 | is_first = True 121 | for line in raw: 122 | for i in range(len(line)): 123 | if i - li < 0: 124 | if is_space != 'sea': 125 | lp = p * (li - i) + line[:i] 126 | else: 127 | lp = [p] * (li - i) + line[:i] 128 | else: 129 | lp = line[i - li:i] 130 | if i + ri + 1 > len(line): 131 | if is_space != 'sea': 132 | rp = line[i:] + p*(i + ri + 1 - len(line)) 133 | else: 134 | rp = line[i:] + [p] * (i + ri + 1 - len(line)) 135 | else: 136 | rp = line[i:i+ri+1] 137 | ch = lp + rp 138 | if is_space == 'sea': 139 | ch = '_'.join(ch) 140 | if ch in gram_set: 141 | gram_set[ch] += 1 142 | else: 143 | gram_set[ch] = 1 144 | if is_first: 145 | is_first = False 146 | else: 147 | if is_space is True: 148 | last_line += ' ' 149 | start_idx = len(last_line) - ri 150 | if start_idx < 0: 151 | start_idx = 0 152 | end_idx = li + len(last_line) 153 | j_line = last_line + line 154 | for i in range(start_idx, end_idx): 155 | if i - li < 0: 156 | 
if is_space != 'sea': 157 | j_lp = p * (-i) + j_line[start_idx:i] 158 | else: 159 | j_lp = [p] * (-i) + j_line[start_idx:i] 160 | else: 161 | j_lp = j_line[i - li:i] 162 | if i + ri + 1 > len(j_line): 163 | if is_space != 'sea': 164 | j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line)) 165 | else: 166 | j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line)) 167 | else: 168 | j_rp = j_line[i:ri + 1 + i] 169 | j_ch = j_lp + j_rp 170 | if is_space == 'sea': 171 | ch = '_'.join(j_ch) 172 | if ch in gram_set: 173 | gram_set[ch] += 1 174 | else: 175 | gram_set[ch] = 1 176 | last_line = line 177 | return gram_set 178 | 179 | 180 | def get_ngrams(path, ng, is_space): 181 | raw = [] 182 | for line in codecs.open(path + '/raw_train.txt', 'r', encoding='utf-8'): 183 | if is_space == 'sea': 184 | segs = pre_token(line.strip()) 185 | else: 186 | segs = line.strip() 187 | raw.append(segs) 188 | if ng > 1: 189 | for i in range(2, ng + 1): 190 | out_gram = codecs.open(path + '/' + str(i) + 'gram.txt', 'w', encoding='utf-8') 191 | grams = ngrams(raw, i, is_space) 192 | for k, v in grams.items(): 193 | out_gram.write(k + '\t' + str(v) + '\n') 194 | out_gram.close() 195 | 196 | 197 | def read_ngrams(path, ng): 198 | ngs = [] 199 | for i in range(2, ng + 1): 200 | ng = {} 201 | for line in codecs.open(path + '/' + str(i) + 'gram.txt', 'r', encoding='utf-8'): 202 | line = line.rstrip() 203 | segs = line.split('\t') 204 | while len(segs[0]) < i: 205 | segs[0] += ' ' 206 | ng[segs[0]] = int(segs[1]) 207 | ngs.append(ng) 208 | return ngs 209 | 210 | 211 | def get_sample_embedding(path, emb, chars2idx): 212 | chars = chars2idx.keys() 213 | short_emb = emb[emb.index('/') + 1: emb.index('.')] 214 | emb_dic = {} 215 | valid_chars=[] 216 | for line in codecs.open(emb, 'rb', encoding='utf-8'): 217 | line = line.strip() 218 | sets = line.split(' ') 219 | emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32') 220 | fout = codecs.open(path + '/' + short_emb + '_sub.txt', 'w', 
encoding='utf-8') 221 | for ch in chars: 222 | p_line = ch 223 | if ch in emb_dic: 224 | valid_chars.append(ch) 225 | for emb in emb_dic[ch]: 226 | p_line += ' ' + unicode(emb) 227 | fout.write(p_line + '\n') 228 | fout.close() 229 | 230 | 231 | def read_sample_embedding(path, short_emb, char2idx): 232 | emb_values = [] 233 | valid_chars = [] 234 | emb_dic={} 235 | for line in codecs.open(path + '/' + short_emb + '_sub.txt', 'rb', encoding='utf-8'): 236 | first_ch = line[0] 237 | line = line.rstrip() 238 | sets = line.split(' ') 239 | if first_ch == ' ': 240 | emb_dic[' '] = np.asarray(sets, dtype='float32') 241 | else: 242 | emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32') 243 | emb_dim = len(emb_dic.items()[0][1]) 244 | for ch in char2idx.keys(): 245 | if ch in emb_dic: 246 | emb_values.append(emb_dic[ch]) 247 | valid_chars.append(ch) 248 | else: 249 | rand = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim) 250 | emb_values.append(np.asarray(rand, dtype='float32')) 251 | emb_dim = len(emb_values[0]) 252 | return emb_dim, emb_values, valid_chars 253 | 254 | 255 | def get_sent_raw(path, fname, is_space=True): 256 | long_line = '' 257 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'): 258 | line = line.strip() 259 | if is_space: 260 | long_line += ' ' + line 261 | else: 262 | long_line += line 263 | if is_space: 264 | long_line = long_line[1:] 265 | 266 | return long_line 267 | 268 | 269 | def chop(line, ad_s, limit): 270 | out = [] 271 | chopped = False 272 | while len(line) > 0: 273 | if chopped: 274 | s_line = line[:limit - 1] 275 | s_line = [ad_s] + s_line 276 | else: 277 | chopped = True 278 | s_line = line[:limit] 279 | out.append(s_line) 280 | line = line[limit - 10:] 281 | if len(line) < 10: 282 | line = '' 283 | while len(out) > 0 and len(out[-1]) < limit-1: 284 | out[-1].append(0) 285 | return out 286 | 287 | 288 | def get_input_vec(path, fname, char2idx, tag2idx, limit=500, 
sent_seg=False, is_space=True, train_size=-1, ignore_space=False): 289 | ct = 0 290 | max_len = 0 291 | space_idx = None 292 | if is_space is True: 293 | assert ' ' in char2idx 294 | space_idx = char2idx[' '] 295 | x_indices = [] 296 | y_indices = [] 297 | s_count = 0 298 | l_count = 0 299 | x = [] 300 | y = [] 301 | 302 | n_sent = 0 303 | 304 | if sent_seg: 305 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'): 306 | line = line.strip() 307 | if len(line) == 0: 308 | ct = 0 309 | elif ct == 0: 310 | if is_space == 'sea': 311 | line = pre_token(line) 312 | for ch in line: 313 | if len(ch.strip()) == 0: 314 | x.append(char2idx[' ']) 315 | elif ch in char2idx: 316 | x.append(char2idx[ch]) 317 | else: 318 | x.append(char2idx['']) 319 | if is_space is True and not ignore_space: 320 | x = [space_idx] + x 321 | x_indices += x 322 | x = [] 323 | ct = 1 324 | elif ct == 1: 325 | for ch in line: 326 | y.append(tag2idx[ch]) 327 | if y[-1] == tag2idx['S']: 328 | y[-1] = tag2idx['T'] 329 | else: 330 | y[-1] = tag2idx['U'] 331 | if is_space is True and not ignore_space: 332 | y = [tag2idx['X']] + y 333 | y_indices += y 334 | y = [] 335 | n_sent += 1 336 | if 0 < train_size <= n_sent: 337 | break 338 | x_indices = chop(x_indices, char2idx['<#>'], limit) 339 | y_indices = chop(y_indices, tag2idx['I'], limit) 340 | max_len = limit 341 | else: 342 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'): 343 | line = line.strip() 344 | if len(line) == 0: 345 | ct = 0 346 | elif ct == 0: 347 | if is_space == 'sea': 348 | line = pre_token(line) 349 | max_len = max(max_len, len(line)) 350 | s_count += 1 351 | if len(line) > limit: 352 | l_count += 1 353 | chopped = False 354 | while len(line) > 0: 355 | s_line = line[:limit - 1] 356 | line = line[limit - 10:] 357 | if len(line) < 10: 358 | line = '' 359 | if not chopped: 360 | chopped = True 361 | else: 362 | x.append(char2idx['<#>']) 363 | for ch in s_line: 364 | if len(ch.strip()) == 0: 365 | 
x.append(char2idx[' ']) 366 | elif ch in char2idx: 367 | x.append(char2idx[ch]) 368 | else: 369 | x.append(char2idx['']) 370 | x_indices.append(x) 371 | x = [] 372 | ct = 1 373 | elif ct == 1: 374 | chopped = False 375 | while len(line) > 0: 376 | s_line = line[:limit - 1] 377 | line = line[limit - 10:] 378 | if len(line) < 10: 379 | line = '' 380 | if not chopped: 381 | chopped = True 382 | else: 383 | y.append(tag2idx['I']) 384 | for ch in s_line: 385 | y.append(tag2idx[ch]) 386 | y_indices.append(y) 387 | y = [] 388 | n_sent += 1 389 | if 0 < train_size <= n_sent: 390 | break 391 | max_len = min(max_len, limit) 392 | if l_count > 0: 393 | print '%d (out of %d) sentences are chopped.' % (l_count, s_count) 394 | return [x_indices], [y_indices], max_len 395 | 396 | 397 | def get_input_vec_sent(path, fname, char2idx, win_size, is_space=True): 398 | pre_line = '' 399 | c_line = '' 400 | x = [] 401 | y = [] 402 | is_first = True 403 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'): 404 | line = line.strip() 405 | if is_space == 'sea': 406 | line = pre_token(line) 407 | start_idx = len(pre_line) 408 | if is_space is True: 409 | j_line = pre_line + ' ' + c_line + ' ' + line 410 | end_idx = start_idx + len(c_line) + 1 411 | if is_first: 412 | is_first = False 413 | j_line = j_line[1:] 414 | end_idx -= 1 415 | else: 416 | j_line = pre_line + c_line + line 417 | end_idx = start_idx + len(c_line) 418 | for i in range(start_idx, end_idx): 419 | indices = [] 420 | for j in range(i - win_size, i + win_size + 1): 421 | if j < 0 or j >= len(j_line): 422 | indices.append(char2idx['

']) 423 | else: 424 | if j_line[j] in char2idx: 425 | indices.append(char2idx[j_line[j]]) 426 | else: 427 | indices.append(char2idx['']) 428 | x.append(indices) 429 | if i == end_idx - 1: 430 | y.append(1) 431 | else: 432 | y.append(0) 433 | pre_line = c_line 434 | c_line = line 435 | if is_space is True: 436 | j_line = pre_line + ' ' + c_line 437 | else: 438 | j_line = pre_line + c_line 439 | start_idx = len(pre_line) 440 | end_idx = start_idx + len(c_line) 441 | for i in range(start_idx, end_idx): 442 | indices = [] 443 | for j in range(i - win_size, i + win_size + 1): 444 | if j < 0 or j >= len(j_line): 445 | indices.append(char2idx['

']) 446 | else: 447 | if j_line[j] in char2idx: 448 | indices.append(char2idx[j_line[j]]) 449 | else: 450 | indices.append(char2idx['']) 451 | x.append(indices) 452 | if i == end_idx - 1: 453 | y.append(1) 454 | else: 455 | y.append(0) 456 | 457 | assert len(x) == len(y) 458 | return x, y 459 | 460 | 461 | def get_input_vec_sent_raw(raws, char2idx, win_size): 462 | x = [] 463 | for i in range(len(raws)): 464 | indices = [] 465 | for j in range(i - win_size, i + win_size + 1): 466 | if j < 0 or j >= len(raws): 467 | indices.append(char2idx['

']) 468 | else: 469 | if raws[j] in char2idx: 470 | indices.append(char2idx[raws[j]]) 471 | else: 472 | indices.append(char2idx['']) 473 | x.append(indices) 474 | return x 475 | 476 | 477 | def get_input_vec_raw(path, fname, char2idx, lines=None, limit=500, sent_seg=False, is_space=True, ignore_space=False): 478 | max_len = 0 479 | space_idx = None 480 | is_first = True 481 | if is_space is True: 482 | assert ' ' in char2idx 483 | space_idx = char2idx[' '] 484 | x_indices = [] 485 | s_count = 0 486 | l_count = 0 487 | x = [] 488 | if lines is None: 489 | assert fname is not None 490 | if path is None: 491 | real_path = fname 492 | else: 493 | real_path = path + '/' + fname 494 | lines = codecs.open(real_path, 'r', encoding='utf-8') 495 | if sent_seg: 496 | for line in lines: 497 | line = line.strip() 498 | if is_space == 'sea': 499 | line = pre_token(line) 500 | elif ignore_space: 501 | line = ''.join(line.split()) 502 | for ch in line: 503 | if len(ch.strip()) == 0: 504 | x.append(char2idx[' ']) 505 | elif ch in char2idx: 506 | x.append(char2idx[ch]) 507 | else: 508 | x.append(char2idx['']) 509 | if is_space is True and not ignore_space: 510 | if is_first: 511 | is_first = False 512 | else: 513 | x = [space_idx] + x 514 | x_indices += x 515 | x = [] 516 | x_indices = chop(x_indices, char2idx['<#>'], limit) 517 | max_len = limit 518 | else: 519 | for line in lines: 520 | line = line.strip() 521 | if len(line) > 0: 522 | if is_space == 'sea': 523 | line = pre_token(line) 524 | elif ignore_space: 525 | line = ''.join(line.split()) 526 | max_len = max(max_len, len(line)) 527 | s_count += 1 528 | 529 | for ch in line: 530 | if len(ch.strip()) == 0: 531 | x.append(char2idx[' ']) 532 | elif ch in char2idx: 533 | x.append(char2idx[ch]) 534 | else: 535 | x.append(char2idx['']) 536 | 537 | if len(line) > limit: 538 | l_count += 1 539 | chop_x = chop(x, char2idx['<#>'], limit) 540 | x_indices += chop_x 541 | else: 542 | x_indices.append(x) 543 | x = [] 544 | max_len = 
min(max_len, limit) 545 | if l_count > 0: 546 | print '%d (out of %d) sentences are chopped.' % (l_count, s_count) 547 | return [x_indices], max_len 548 | 549 | 550 | def get_input_vec_tag(path, fname, char2idx, lines=None, limit=500, is_space=True): 551 | space_idx = None 552 | if is_space is True: 553 | assert ' ' in char2idx 554 | space_idx = char2idx[' '] 555 | x_indices = [] 556 | out = [] 557 | x = [] 558 | is_first = True 559 | if lines is None: 560 | assert fname is not None 561 | if path is None: 562 | real_path = fname 563 | else: 564 | real_path = path + '/' + fname 565 | lines = codecs.open(real_path, 'r', encoding='utf-8') 566 | for line in lines: 567 | line = line.strip() 568 | if len(line) > 0: 569 | if is_space == 'sea': 570 | line = pre_token(line) 571 | if len(line) > 0: 572 | for ch in line: 573 | if len(ch.strip()) == 0: 574 | x.append(char2idx[' ']) 575 | elif ch in char2idx: 576 | x.append(char2idx[ch]) 577 | else: 578 | x.append(char2idx['']) 579 | if is_space is True: 580 | if is_first: 581 | is_first = False 582 | else: 583 | x = [space_idx] + x 584 | x_indices += x 585 | x = [] 586 | elif len(x_indices) > 0: 587 | x_indices = chop(x_indices, char2idx['<#>'], limit) 588 | out += x_indices 589 | x_indices = [] 590 | is_first = True 591 | 592 | if len(x_indices) > 0: 593 | x_indices = chop(x_indices, char2idx['<#>'], limit) 594 | out += x_indices 595 | 596 | return [out], limit 597 | 598 | 599 | def get_vecs(str, char2idx): 600 | out = [] 601 | for ch in str: 602 | if ch in char2idx: 603 | out.append(char2idx[ch]) 604 | return out 605 | 606 | 607 | def get_dict_vec(trans_dict, char2idx): 608 | max_x, max_y = 0, 0 609 | x = [] 610 | y = [] 611 | for k, v in trans_dict.items(): 612 | x.append(get_vecs(k, char2idx)) 613 | y.append(get_vecs(v.replace('  ', ' '), char2idx) + [2]) 614 | if len(k) > max_x: 615 | max_x = len(k) 616 | if len(v) > max_y: 617 | max_y = len(v) 618 | max_x += 5 619 | max_y += 5 620 | x = pad_zeros(x, max_x) 621 | y = 
pad_zeros(y, max_y) 622 | assert len(x) == len(y) 623 | num = len(x) 624 | xy = zip(x, y) 625 | random.shuffle(xy) 626 | xy = zip(*xy) 627 | t_x = xy[0][:int(num * 0.95)] 628 | t_y = xy[1][:int(num * 0.95)] 629 | v_x = xy[0][int(num * 0.95):] 630 | v_y = xy[1][int(num * 0.95):] 631 | return t_x, t_y, v_x, v_y 632 | 633 | 634 | def get_ngram_dic(ng): 635 | gram_dics = [] 636 | for i, gram in enumerate(ng): 637 | g_dic = {'

': 0, '': 1, '<#>': 2} 638 | idx = 3 639 | for g in gram.keys(): 640 | if gram[g] > 1: 641 | g_dic[g] = idx 642 | else: 643 | g_dic[g] = 1 644 | idx += 1 645 | gram_dics.append(g_dic) 646 | return gram_dics 647 | 648 | 649 | def gram_vec(raw, dic, limit=500, sent_seg=False, is_space=True): 650 | out = [] 651 | if is_space == 'sea': 652 | ngram = len(dic.keys()[0].split('_')) 653 | else: 654 | ngram = 0 655 | for k in dic.keys(): 656 | if '' not in k: 657 | ngram = len(k) 658 | break 659 | li = ngram/2 660 | ri = ngram - li - 1 661 | p = '' 662 | indices = [] 663 | is_first = True 664 | if sent_seg: 665 | last_line = '' 666 | for line in raw: 667 | for i in range(len(line)): 668 | if i - li < 0: 669 | if is_space != 'sea': 670 | lp = p * (li - i) + line[:i] 671 | else: 672 | lp = [p] * (li - i) + line[:i] 673 | else: 674 | lp = line[i - li:i] 675 | if i + ri + 1 > len(line): 676 | if is_space != 'sea': 677 | rp = line[i:] + p * (i + ri + 1 - len(line)) 678 | else: 679 | rp = line[i:] + [p] * (i + ri + 1 - len(line)) 680 | else: 681 | rp = line[i:i + ri + 1] 682 | ch = lp + rp 683 | if is_space == 'sea': 684 | ch = '_'.join(ch) 685 | if ch in dic: 686 | indices.append(dic[ch]) 687 | else: 688 | indices.append(dic['']) 689 | if is_first: 690 | is_first = False 691 | else: 692 | start_idx = len(last_line) - ri 693 | if start_idx < 0: 694 | start_idx = 0 695 | if is_space: 696 | last_line += ' ' 697 | j_line = last_line + line 698 | end_idx = len(last_line) + li 699 | j_indices = [] 700 | for i in range(start_idx, end_idx): 701 | if i - li < 0: 702 | if is_space != 'sea': 703 | j_lp = p * (-i) + j_line[start_idx:i] 704 | else: 705 | j_lp = [p] * (-i) + j_line[start_idx:i] 706 | else: 707 | j_lp = j_line[i - li:i] 708 | if i + ri + 1 > len(j_line): 709 | if is_space != 'sea': 710 | j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line)) 711 | else: 712 | j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line)) 713 | else: 714 | j_rp = j_line[i:ri + 1 + i] 715 | j_ch 
= j_lp + j_rp 716 | if is_space == 'sea': 717 | j_ch = '_'.join(j_ch) 718 | if j_ch in dic: 719 | j_indices.append(dic[j_ch]) 720 | else: 721 | j_indices.append(dic['']) 722 | if ri > 0: 723 | out = out[: - ri] + j_indices[:ri] 724 | if is_space: 725 | indices = j_indices[ - (li + 1):] + indices[li:] 726 | else: 727 | indices = j_indices[ - li:] + indices[li:] 728 | out += indices 729 | indices = [] 730 | last_line = line 731 | out = chop(out, dic['<#>'], limit) 732 | 733 | else: 734 | for line in raw: 735 | chopped = False 736 | while len(line) > 0: 737 | s_line = line[:limit - 1] 738 | line = line[limit - 10:] 739 | if len(line) < 10: 740 | line = '' 741 | if not chopped: 742 | chopped = True 743 | else: 744 | indices.append(dic['<#>']) 745 | for i in range(len(s_line)): 746 | if i - li < 0: 747 | if is_space != 'sea': 748 | lp = p * (li - i) + s_line[:i] 749 | else: 750 | lp = [p] * (li - i) + s_line[:i] 751 | else: 752 | lp = s_line[i - li:i] 753 | if i + ri + 1 > len(s_line): 754 | if is_space != 'sea': 755 | rp = s_line[i:] + p * (i + ri + 1 - len(s_line)) 756 | else: 757 | rp = s_line[i:] + [p] * (i + ri + 1 - len(s_line)) 758 | else: 759 | rp = s_line[i:i + ri + 1] 760 | ch = lp + rp 761 | if is_space == 'sea': 762 | ch = '_'.join(ch) 763 | if ch in dic: 764 | indices.append(dic[ch]) 765 | else: 766 | indices.append(dic['']) 767 | out.append(indices) 768 | indices = [] 769 | return out 770 | 771 | 772 | def get_gram_vec(path, fname, gram2index, lines=None, is_raw=False, limit=500, sent_seg=False, is_space=True, ignore_space=False): 773 | raw = [] 774 | i = 0 775 | if lines is None: 776 | assert fname is not None 777 | if path is None: 778 | real_path = fname 779 | else: 780 | real_path = path + '/' + fname 781 | lines = codecs.open(real_path, 'r', encoding='utf-8') 782 | for line in lines: 783 | line = line.strip() 784 | if is_space == 'sea': 785 | line = pre_token(line) 786 | elif ignore_space: 787 | line = ''.join(line.split()) 788 | if i == 0 or is_raw: 
789 | raw.append(line) 790 | i += 1 791 | if len(line) > 0: 792 | i += 1 793 | else: 794 | i = 0 795 | out = [] 796 | for g_dic in gram2index: 797 | out.append(gram_vec(raw, g_dic, limit, sent_seg, is_space)) 798 | return out 799 | 800 | 801 | def get_gram_vec_tag(path, fname, gram2index, lines=None, limit=500, is_space=True, ignore_space=False): 802 | raw = [] 803 | out = [[] for _ in range(len(gram2index))] 804 | if lines is None: 805 | assert fname is not None 806 | if path is None: 807 | real_path = fname 808 | else: 809 | real_path = path + '/' + fname 810 | lines = codecs.open(real_path, 'r', encoding='utf-8') 811 | for line in lines: 812 | line = line.strip() 813 | if is_space == 'sea': 814 | line = pre_token(line) 815 | elif ignore_space: 816 | line = ''.join(line.split()) 817 | if len(line) > 0: 818 | raw.append(line) 819 | else: 820 | for i, g_dic in enumerate(gram2index): 821 | out[i] += gram_vec(raw, g_dic, limit, True, is_space) 822 | raw = [] 823 | if len(raw) > 0: 824 | for i, g_dic in enumerate(gram2index): 825 | out[i] += gram_vec(raw, g_dic, limit, True, is_space) 826 | return out 827 | 828 | 829 | def read_vocab_tag(path): 830 | ''' 831 | Read tags from index files and create dictionaries 832 | :param path: 833 | :return tag2idx, idx2tag 834 | ''' 835 | tag2idx = {} 836 | for i, line in enumerate(codecs.open(path, 'rb', encoding='utf-8')): 837 | line = line.strip() 838 | tag2idx[line] = i 839 | idx2tag = {k: v for v, k in tag2idx.items()} 840 | return tag2idx, idx2tag 841 | 842 | 843 | def get_tags(can, action='sep', tag_scheme='BIES', ignore_mwt=False): 844 | tags = [] 845 | if tag_scheme == 'BI': 846 | for i in range(len(can)): 847 | if i == 0: 848 | if action == 'sep' or ignore_mwt: 849 | tags.append('B') 850 | else: 851 | tags.append('K') 852 | else: 853 | if action == 'sep' or ignore_mwt: 854 | tags.append('I') 855 | else: 856 | tags.append('Z') 857 | else: 858 | for i in range(len(can)): 859 | if len(can) == 1: 860 | if action == 'sep' or 
ignore_mwt: 861 | tags.append('S') 862 | else: 863 | tags.append('D') 864 | elif i == 0: 865 | if action == 'sep' or ignore_mwt: 866 | tags.append('B') 867 | else: 868 | tags.append('K') 869 | elif i == len(can) - 1: 870 | if action == 'sep' or ignore_mwt: 871 | tags.append('E') 872 | else: 873 | tags.append('J') 874 | else: 875 | if action == 'sep' or ignore_mwt: 876 | tags.append('I') 877 | else: 878 | tags.append('Z') 879 | return tags 880 | 881 | 882 | def update_dict(trans_dic, can, trans): 883 | can = can.lower() 884 | if can not in trans_dic: 885 | trans_dic[can] = {} 886 | if trans not in trans_dic[can]: 887 | trans_dic[can][trans] = 1 888 | else: 889 | trans_dic[can][trans] += 1 890 | return trans_dic 891 | 892 | 893 | def raw2tags(raw, sents, path, train_file, creat_dict=True, gold_path=None, ignore_space=False, reset=False, 894 | tag_scheme='BIES', ignore_mwt=False): 895 | wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8') 896 | if creat_dict and not ignore_mwt: 897 | wd = codecs.open(path + '/dict.txt', 'w', encoding='utf-8') 898 | wg = None 899 | if gold_path is not None: 900 | wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8') 901 | wtg = None 902 | if reset or not os.path.isfile(path + '/tags.txt'): 903 | wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8') 904 | final_dic = {} 905 | assert len(raw) == len(sents) 906 | invalid = 0 907 | s_tags = set() 908 | 909 | def matched(can, sent_l, tags, trans_dic): 910 | if '-' in sent_l[0][0]: 911 | nums = sent_l[0][0].split('-') 912 | count = int(nums[1]) - int(nums[0]) 913 | sent_l.pop(0) 914 | segs = [] 915 | while count >= 0: 916 | segs.append(sent_l[0][1]) 917 | sent_l.pop(0) 918 | count -= 1 919 | j_seg = ''.join(segs) 920 | if j_seg == can: 921 | for seg in segs: 922 | tags += get_tags(seg, tag_scheme=tag_scheme) 923 | elif can.replace('-', '') == j_seg: 924 | for c_split in can.split('-'): 925 | tags += get_tags(c_split, tag_scheme=tag_scheme) 926 | if tag_scheme == 
'BI': 927 | tags.append('I') 928 | else: 929 | tags.append('X') 930 | tags.pop() 931 | else: 932 | tags += get_tags(can, action='trans', tag_scheme=tag_scheme, ignore_mwt=ignore_mwt) 933 | if not ignore_mwt: 934 | trans = ' '.join(segs) 935 | trans_dic = update_dict(trans_dic, can, trans) 936 | else: 937 | tags += get_tags(can, tag_scheme=tag_scheme) 938 | sent_l.pop(0) 939 | 940 | return tags, trans_dic 941 | 942 | for raw_l, sent_l in zip(raw, sents): 943 | if ignore_space: 944 | raw_l = ''.join(raw_l.split()) 945 | tags = [] 946 | cans = raw_l.split(' ') 947 | trans_dic = {} 948 | gold = get_gold(sent_l, ignore_mwt=ignore_mwt) 949 | pre = '' 950 | for can in cans: 951 | t_can = can.strip() 952 | purged = len(can) - len(t_can) 953 | if purged > 0: 954 | can = t_can 955 | while purged > 0: 956 | if tag_scheme == 'BI': 957 | tags.append('I') 958 | else: 959 | tags.append('X') 960 | purged -= 1 961 | done = False 962 | if len(pre) > 0: 963 | can = pre + ' ' + can 964 | while not done: 965 | if can == sent_l[0][1]: 966 | tags, trans_dic = matched(can, sent_l, tags, trans_dic) 967 | done = True 968 | pre = '' 969 | else: 970 | if len(can) >= len(sent_l[0][1]): 971 | s_l = len(sent_l[0][1]) 972 | s_can = can[:s_l] 973 | if s_can != sent_l[0][1]: 974 | done = True 975 | tags, trans_dic = matched(s_can, sent_l, tags, trans_dic) 976 | can = can[s_l:] 977 | if len(can) == 0: 978 | done = True 979 | pre = '' 980 | else: 981 | pre = can 982 | done = True 983 | if len(pre) == 0: 984 | if tag_scheme == 'BI': 985 | tags.append('I') 986 | else: 987 | tags.append('X') 988 | if len(tags) > 0: 989 | tags.pop() 990 | if len(tags) == len(raw_l): 991 | for tg in tags: 992 | s_tags.add(tg) 993 | wt.write(raw_l + '\n') 994 | wt.write(''.join(tags) + '\n') 995 | wt.write('\n') 996 | for key in trans_dic: 997 | if key not in final_dic: 998 | final_dic[key] = trans_dic[key] 999 | else: 1000 | for tr in trans_dic[key]: 1001 | if tr in final_dic[key]: 1002 | final_dic[key][tr] += 
trans_dic[key][tr] 1003 | else: 1004 | final_dic[key][tr] = trans_dic[key][tr] 1005 | else: 1006 | invalid += 1 1007 | if wg is not None: 1008 | wg.write(gold + '\n') 1009 | if wg is not None: 1010 | wg.close() 1011 | if wtg is not None: 1012 | for stg in s_tags: 1013 | wtg.write(stg + '\n') 1014 | wtg.close() 1015 | if creat_dict and not ignore_mwt: 1016 | for key in final_dic: 1017 | wd.write(key + '\n') 1018 | s_dic = sorted(final_dic[key].items(), key=lambda x: x[1], reverse=True) 1019 | for i in s_dic: 1020 | wd.write(i[0] + '\t' + str(i[1]) + '\n') 1021 | wd.write('\n') 1022 | wt.close() 1023 | print 'invalid sentences: ', invalid, len(raw) 1024 | 1025 | 1026 | def raw2tags_sea(raw, sents, path, train_file, gold_path=None, reset=False, tag_scheme='BIES'): 1027 | wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8') 1028 | wg = None 1029 | if gold_path is not None: 1030 | wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8') 1031 | assert len(raw) == len(sents) 1032 | invalid = 0 1033 | wtg = None 1034 | if reset or not os.path.isfile(path + '/tags.txt'): 1035 | wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8') 1036 | 1037 | s_tags = set() 1038 | 1039 | def matched(can, sent_l, tags): 1040 | segs = can.split(' ') 1041 | sent_l.pop(0) 1042 | if len(segs) == 1: 1043 | tags.append('S') 1044 | elif len(segs) > 1: 1045 | if tag_scheme == 'BI': 1046 | tags += ['B'] + ['I'] * (len(segs) - 1) 1047 | else: 1048 | mid_t = ['I'] * (len(segs) - 2) 1049 | tags += ['B'] + mid_t + ['E'] 1050 | return tags 1051 | 1052 | for raw_l, sent_l in zip(raw, sents): 1053 | tags = [] 1054 | cans = pre_token(raw_l) 1055 | gold = get_gold(sent_l) 1056 | pre = '' 1057 | for can in cans: 1058 | t_can = can.strip() 1059 | purged = len(can) - len(t_can) 1060 | if purged > 0: 1061 | can = t_can 1062 | while purged > 0: 1063 | if tag_scheme == 'BI': 1064 | tags.append('I') 1065 | else: 1066 | tags.append('X') 1067 | purged -= 1 1068 | if len(pre) > 0: 1069 | can 
= pre + ' ' + can 1070 | j_can = ''.join(can.split()) 1071 | if sent_l: 1072 | j_sent = ''.join(sent_l[0][1].split()) 1073 | if j_can == j_sent: 1074 | tags = matched(can, sent_l, tags) 1075 | pre = '' 1076 | else: 1077 | assert len(j_can) < len(j_sent) 1078 | pre = can 1079 | if len(tags) == len(cans): 1080 | for tg in tags: 1081 | s_tags.add(tg) 1082 | wt.write(raw_l + '\n') 1083 | wt.write(''.join(tags) + '\n') 1084 | wt.write('\n') 1085 | else: 1086 | invalid += 1 1087 | if wg is not None: 1088 | wg.write(gold + '\n') 1089 | if wg is not None: 1090 | wg.close() 1091 | if wtg is not None: 1092 | for stg in s_tags: 1093 | wtg.write(stg + '\n') 1094 | wtg.close() 1095 | wt.close() 1096 | 1097 | print 'invalid sentences: ', invalid, len(raw) 1098 | 1099 | 1100 | def pad_zeros(l, max_len): 1101 | padded = None 1102 | if type(l) is list: 1103 | padded = [] 1104 | for item in l: 1105 | if len(item) <= max_len: 1106 | padded.append(np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0)) 1107 | else: 1108 | padded.append(np.asarray(item[:max_len])) 1109 | padded = np.asarray(padded) 1110 | elif type(l) is dict: 1111 | padded = {} 1112 | for k, v in l.iteritems(): 1113 | padded[k] = [np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0) for item in v] 1114 | return padded 1115 | 1116 | def unpad_zeros(l): 1117 | out = [] 1118 | for tags in l: 1119 | out.append([np.trim_zeros(line) for line in tags]) 1120 | return out 1121 | 1122 | 1123 | def buckets(x, y, size=50): 1124 | assert len(x[0]) == len(y[0]) 1125 | num_inputs = len(x) 1126 | samples = x + y 1127 | num_items = len(samples) 1128 | xy = zip(*samples) 1129 | xy.sort(key=lambda i: len(i[0])) 1130 | t_len = size 1131 | idx = 0 1132 | bucks = [[[]] for _ in range(num_items)] 1133 | for item in xy: 1134 | if len(item[0]) > t_len: 1135 | if len(bucks[0][idx]) > 0: 1136 | for buck in bucks: 1137 | buck.append([]) 1138 | idx += 1 1139 | while len(item[0]) > t_len: 1140 | t_len += size 1141 | 
for i in range(num_items): 1142 | #print item[i] 1143 | bucks[i][idx].append(item[i]) 1144 | 1145 | return bucks[:num_inputs], bucks[num_inputs:] 1146 | 1147 | 1148 | def pad_bucket(x, y, limit, bucket_len_c=None): 1149 | assert len(x[0]) == len(y[0]) 1150 | num_inputs = len(x) 1151 | num_tags = len(y) 1152 | padded = [[] for _ in range(num_tags + num_inputs)] 1153 | bucket_counts = [] 1154 | samples = x + y 1155 | xy = zip(*samples) 1156 | if bucket_len_c is None: 1157 | bucket_len_c = [] 1158 | for i, item in enumerate(xy): 1159 | max_len = len(item[0][-1]) 1160 | if i == len(xy) - 1: 1161 | max_len = limit 1162 | bucket_len_c.append(max_len) 1163 | bucket_counts.append(len(item[0])) 1164 | for idx in range(num_tags + num_inputs): 1165 | padded[idx].append(pad_zeros(item[idx], max_len)) 1166 | print 'Number of buckets: ', len(bucket_len_c) 1167 | else: 1168 | idy = 0 1169 | for item in xy: 1170 | max_len = len(item[0][-1]) 1171 | while idy < len(bucket_len_c) and max_len > bucket_len_c[idy]: 1172 | idy += 1 1173 | bucket_counts.append(len(item[0])) 1174 | if idy >= len(bucket_len_c): 1175 | for idx in range(num_tags + num_inputs): 1176 | padded[idx].append(pad_zeros(item[idx], max_len)) 1177 | bucket_len_c.append(max_len) 1178 | else: 1179 | for idx in range(num_tags + num_inputs): 1180 | padded[idx].append(pad_zeros(item[idx], bucket_len_c[idy])) 1181 | return padded[:num_inputs], padded[num_inputs:], bucket_len_c, bucket_counts 1182 | 1183 | 1184 | def get_real_batch(counts, b_size): 1185 | real_batch_sizes = [] 1186 | for c in counts: 1187 | if c < b_size: 1188 | real_batch_sizes.append(c) 1189 | else: 1190 | real_batch_sizes.append(b_size) 1191 | return real_batch_sizes 1192 | 1193 | 1194 | def merge_bucket(x): 1195 | out = [] 1196 | for item in x: 1197 | m = [] 1198 | for i in item: 1199 | m += i 1200 | out.append(m) 1201 | return out 1202 | 1203 | 1204 | def decode_tags(idx, index2tags): 1205 | out = [] 1206 | for id in idx: 1207 | sents = [] 1208 | for 
line in id: 1209 | sent = [] 1210 | for item in line: 1211 | tag = index2tags[item] 1212 | tag = tag.replace('E', 'I') 1213 | tag = tag.replace('S', 'B') 1214 | tag = tag.replace('J', 'Z') 1215 | tag = tag.replace('D', 'K') 1216 | sent.append(tag) 1217 | sents.append(sent) 1218 | out.append(sents) 1219 | return out 1220 | 1221 | 1222 | def decode_chars(idx, idx2chars): 1223 | out = [] 1224 | for line in idx: 1225 | line = np.trim_zeros(line) 1226 | out.append([idx2chars[item] for item in line]) 1227 | return out 1228 | 1229 | 1230 | def generate_output(chars, tags, trans_dict, transducer_dict=None, multi_tok=False, trans_type='mix'): 1231 | out = [] 1232 | mult_out = [] 1233 | raw_out = [] 1234 | sent_seg = False 1235 | 1236 | def map_trans(c_trans, type=trans_type): 1237 | if c_trans in trans_dict and (type == 'mix' or type == 'dict'): 1238 | c_trans = trans_dict[c_trans] 1239 | elif transducer_dict is not None and (type == 'mix' or type == 'trans'): 1240 | c_trans = transducer_dict(c_trans) 1241 | sp = c_trans.split() 1242 | c_trans = ' '.join(sp) 1243 | 1244 | return c_trans 1245 | 1246 | def add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=False): 1247 | c_trans = c_trans.strip() 1248 | if len(c_trans) > 0: 1249 | if trans: 1250 | o_trans = c_trans 1251 | c_trans = map_trans(c_trans) 1252 | if multi_tok: 1253 | num_tr = len(c_trans.split(' ')) 1254 | mt_p_line += ' ' + o_trans + '!#!' 
+ str(num_tr) + ' ' + c_trans 1255 | else: 1256 | if multi_tok: 1257 | mt_p_line += ' ' + c_trans 1258 | p_line += ' ' + c_trans 1259 | return p_line, mt_p_line 1260 | 1261 | def split_sent(lines, s_str): 1262 | for i in range(len(lines)): 1263 | s_line = lines[i].strip() 1264 | while s_line and s_line[-1] == s_str: 1265 | s_line = s_line[:-1] 1266 | sents = s_line.split(s_str) 1267 | lines[i] = [sent.strip() for sent in sents] 1268 | return lines 1269 | 1270 | for i, tag in enumerate(tags): 1271 | assert len(chars) == len(tag) 1272 | sub_out = [] 1273 | sub_raw_out = [] 1274 | multi_sub_out = [] 1275 | j_chars = [] 1276 | j_tags = [] 1277 | is_first = True 1278 | for chs, tgs in zip(chars, tag): 1279 | if chs[0] == '<#>': 1280 | assert len(j_chars) > 0 1281 | if is_first: 1282 | is_first = False 1283 | j_chars[-1] = j_chars[-1][:-5] + chs[6:] 1284 | j_tags[-1] = j_tags[-1][:-5] + tgs[6:] 1285 | else: 1286 | j_chars[-1] = j_chars[-1][:-5] + chs[5:] 1287 | j_tags[-1] = j_tags[-1][:-5] + tgs[5:] 1288 | else: 1289 | j_chars.append(chs) 1290 | j_tags.append(tgs) 1291 | is_first = True 1292 | chars = j_chars 1293 | tag = j_tags 1294 | for chs, tgs in zip(chars, tag): 1295 | assert len(chs) == len(tgs) 1296 | c_word = '' 1297 | c_trans = '' 1298 | p_line = '' 1299 | r_line = '' 1300 | mt_p_line = '' 1301 | for ch, tg in zip(chs, tgs): 1302 | r_line += ch 1303 | if tg == 'I': 1304 | if len(c_trans) > 0: 1305 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1306 | c_trans = '' 1307 | c_word = ch 1308 | else: 1309 | c_word += ch 1310 | elif tg == 'Z': 1311 | if len(c_word) > 0: 1312 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1313 | c_word = '' 1314 | c_trans = ch 1315 | else: 1316 | c_trans += ch 1317 | elif tg == 'B': 1318 | if len(c_word) > 0: 1319 | c_word = c_word.strip() 1320 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1321 | elif len(c_trans) > 0: 1322 | c_trans = c_trans.strip() 
1323 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1324 | c_trans = '' 1325 | c_word = ch 1326 | elif tg == 'K': 1327 | if len(c_word) > 0: 1328 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1329 | c_word = '' 1330 | elif len(c_trans) > 0: 1331 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1332 | c_trans = ch 1333 | elif tg == 'T': 1334 | sent_seg = True 1335 | if len(c_word) > 0: 1336 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1337 | c_word = '' 1338 | elif len(c_trans) > 0: 1339 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1340 | c_trans = '' 1341 | p_line += ' ' + ch + '' 1342 | if multi_tok: 1343 | mt_p_line += ' ' + ch + '' 1344 | r_line += '' 1345 | elif tg == 'U': 1346 | sent_seg = True 1347 | if len(c_word) > 0: 1348 | c_word += ch 1349 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1350 | c_word = '' 1351 | elif len(c_trans) > 0: 1352 | c_trans += ch 1353 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1354 | c_trans = '' 1355 | elif len(ch.strip()) > 0: 1356 | p_line += ch 1357 | if multi_tok: 1358 | mt_p_line += ch 1359 | p_line += '' 1360 | if multi_tok: 1361 | mt_p_line += '' 1362 | r_line += '' 1363 | elif tg == 'X' and len(ch.strip()) > 0: 1364 | if len(c_word) > 0: 1365 | c_word += ch 1366 | elif len(c_trans) > 0: 1367 | c_trans += ch 1368 | else: 1369 | c_word = ch 1370 | elif len(ch.strip()) > 0: 1371 | if len(c_word) > 0: 1372 | c_word += ' ' + ch 1373 | elif len(c_trans) > 0: 1374 | c_trans += ' ' + ch 1375 | else: 1376 | c_word = ch 1377 | if len(c_word) > 0: 1378 | c_word = c_word.strip() 1379 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok) 1380 | elif len(c_trans) > 0: 1381 | c_trans = c_trans.strip() 1382 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True) 1383 | 
sub_out.append(p_line.strip()) 1384 | sub_raw_out.append(r_line.strip()) 1385 | if multi_tok: 1386 | multi_sub_out.append(mt_p_line.strip()) 1387 | out.append(sub_out) 1388 | raw_out.append(sub_raw_out) 1389 | if multi_tok: 1390 | mult_out.append(multi_sub_out) 1391 | out[0][-1].rstrip('') 1392 | raw_out[0][-1].rstrip('') 1393 | if sent_seg: 1394 | out = split_sent(out[0], '') 1395 | raw_out = split_sent(raw_out[0], '') 1396 | if multi_tok: 1397 | mult_out[0][-1].rstrip('') 1398 | if sent_seg: 1399 | mult_out = split_sent(mult_out[0], '') 1400 | return out, raw_out, mult_out 1401 | else: 1402 | return out, raw_out 1403 | 1404 | 1405 | def generate_output_sea(chars, tags): 1406 | out = [] 1407 | raw_out = [] 1408 | sent_seg = False 1409 | 1410 | def split_sent(lines, s_str): 1411 | for i in range(len(lines)): 1412 | s_line = lines[i].strip() 1413 | while s_line and s_line[-1] == s_str: 1414 | s_line = s_line[:-1] 1415 | sents = s_line.split(s_str) 1416 | lines[i] = [sent.strip() for sent in sents] 1417 | return lines 1418 | 1419 | for i, tag in enumerate(tags): 1420 | assert len(chars) == len(tag) 1421 | sub_out = [] 1422 | sub_raw_out = [] 1423 | j_chars = [] 1424 | j_tags = [] 1425 | is_first = True 1426 | for chs, tgs in zip(chars, tag): 1427 | if chs[0] == '<#>': 1428 | assert len(j_chars) > 0 1429 | if is_first: 1430 | is_first = False 1431 | j_chars[-1] = j_chars[-1][:-5] + chs[6:] 1432 | j_tags[-1] = j_tags[-1][:-5] + tgs[6:] 1433 | else: 1434 | j_chars[-1] = j_chars[-1][:-5] + chs[5:] 1435 | j_tags[-1] = j_tags[-1][:-5] + tgs[5:] 1436 | else: 1437 | j_chars.append(chs) 1438 | j_tags.append(tgs) 1439 | is_first = True 1440 | chars = j_chars 1441 | tag = j_tags 1442 | for chs, tgs in zip(chars, tag): 1443 | assert len(chs) == len(tgs) 1444 | p_line = '' 1445 | r_line = '' 1446 | for ch, tg in zip(chs, tgs): 1447 | r_line += ' ' + ch 1448 | if tg == 'I': 1449 | if ch == '.' 
or (ch >= '0' and ch <= '9'): 1450 | p_line += ch 1451 | else: 1452 | p_line += ' ' + ch 1453 | elif tg == 'B': 1454 | p_line += ' ' + ch 1455 | elif tg == 'T': 1456 | sent_seg = True 1457 | p_line += ' ' + ch + '' 1458 | r_line += '' 1459 | elif tg == 'U': 1460 | sent_seg = True 1461 | p_line += ch + '' 1462 | r_line += '' 1463 | elif len(ch.strip()) > 0: 1464 | p_line += ' ' + ch 1465 | sub_out.append(p_line.strip()) 1466 | sub_raw_out.append(r_line.strip()) 1467 | out.append(sub_out) 1468 | raw_out.append(sub_raw_out) 1469 | out[0][-1].rstrip('') 1470 | raw_out[0][-1].rstrip('') 1471 | if sent_seg: 1472 | out = split_sent(out[0], '') 1473 | raw_out = split_sent(raw_out[0], '') 1474 | return out, raw_out 1475 | 1476 | 1477 | def trim_output(out, length): 1478 | assert len(out) == len(length) 1479 | trimmed_out = [] 1480 | for item, l in zip(out, length): 1481 | trimmed_out.append(item[:l]) 1482 | return trimmed_out 1483 | 1484 | 1485 | def generate_trans_out(x, idx2char): 1486 | out = '' 1487 | for i in x: 1488 | if i == 3: 1489 | out += ' ' 1490 | elif i in idx2char: 1491 | out += idx2char[i] 1492 | if '<#>' in out: 1493 | out = out[:out.index('<#>')] 1494 | out = out.replace(' ', ' ') 1495 | out = out.replace(' ', ' ') 1496 | return out 1497 | 1498 | 1499 | def generate_sent_out(raw, predictions): 1500 | out = [] 1501 | line = '' 1502 | assert len(raw) == len(predictions) 1503 | for ch, tag in zip(raw, predictions): 1504 | line += ch 1505 | if tag == 1: 1506 | line = line.strip() 1507 | out.append(line) 1508 | line = '' 1509 | if len(line) > 0: 1510 | line = line.strip() 1511 | out.append(line) 1512 | return out 1513 | 1514 | 1515 | def viterbi(max_scores, max_scores_pre, length, batch_size): 1516 | best_paths = [] 1517 | for m in range(batch_size): 1518 | path = [] 1519 | last_max_node = np.argmax(max_scores[m][length[m] - 1]) 1520 | path.append(last_max_node) 1521 | for t in range(1, length[m])[::-1]: 1522 | last_max_node = max_scores_pre[m][t][last_max_node] 
1523 | path.append(last_max_node) 1524 | path = path[::-1] 1525 | best_paths.append(path) 1526 | return best_paths 1527 | 1528 | 1529 | def get_new_chars(path, char2idx, is_space): 1530 | new_chars = set() 1531 | for line in codecs.open(path, 'rb', encoding='utf-8'): 1532 | line = line.strip() 1533 | if is_space == 'sea': 1534 | line = pre_token(line) 1535 | for ch in line: 1536 | if ch not in char2idx: 1537 | new_chars.add(ch) 1538 | return new_chars 1539 | 1540 | 1541 | def get_valid_chars(chars, emb_path): 1542 | valid_chars = [] 1543 | total = [] 1544 | for line in codecs.open(emb_path, 'rb', encoding='utf-8'): 1545 | line = line.strip() 1546 | sets = line.split(' ') 1547 | total.append(sets[0]) 1548 | for ch in chars: 1549 | if ch in total: 1550 | valid_chars.append(ch) 1551 | return valid_chars 1552 | 1553 | 1554 | def get_new_embeddings(new_chars, emb_dim, emb_path): 1555 | assert os.path.isfile(emb_path) 1556 | emb = {} 1557 | new_emb = [] 1558 | for line in codecs.open(emb_path, 'rb', encoding='utf-8'): 1559 | line = line.strip() 1560 | sets = line.split(' ') 1561 | emb[sets[0]] = np.asarray(sets[1:], dtype='float32') 1562 | if '' not in emb: 1563 | unk = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim) 1564 | emb[''] = np.asarray(unk, dtype='float32') 1565 | for ch in new_chars: 1566 | if ch in emb: 1567 | new_emb.append(emb[ch]) 1568 | else: 1569 | new_emb.append(emb['']) 1570 | return new_emb 1571 | 1572 | 1573 | def update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars=None): 1574 | l_quos = ['"', '«', '„'] 1575 | r_quos = ['"', '»', '“'] 1576 | l_quos = [unicode(ch) for ch in l_quos] 1577 | r_quos = [unicode(ch) for ch in r_quos] 1578 | sub_dict = {} 1579 | old_chars = char2idx.keys() 1580 | dim = len(char2idx) + 10 1581 | if valid_chars is not None: 1582 | for ch in valid_chars: 1583 | if char2idx[ch] in unk_chars_idx: 1584 | unk_chars_idx.remove(ch) 1585 | for char in new_chars: 1586 | if char not 
in char2idx and len(char.strip()) > 0: 1587 | char2idx[char] = dim 1588 | if valid_chars is None or char not in valid_chars: 1589 | unk_chars_idx.append(dim) 1590 | dim += 1 1591 | idx2char = {k: v for v, k in char2idx.items()} 1592 | for ch in new_chars: 1593 | if ch in l_quos: 1594 | for l_ch in l_quos: 1595 | if l_ch in old_chars: 1596 | sub_dict[char2idx[ch]] = char2idx[l_ch] 1597 | if char2idx[ch] in unk_chars_idx: 1598 | unk_chars_idx.remove(char2idx[ch]) 1599 | break 1600 | elif ch in r_quos: 1601 | for r_ch in r_quos: 1602 | if r_ch in old_chars: 1603 | sub_dict[char2idx[ch]] = char2idx[r_ch] 1604 | if char2idx[ch] in unk_chars_idx: 1605 | unk_chars_idx.remove(char2idx[ch]) 1606 | break 1607 | return char2idx, idx2char, unk_chars_idx, sub_dict 1608 | 1609 | 1610 | def get_new_grams(path, gram2idx, is_raw=False, is_space=True): 1611 | raw = [] 1612 | i = 0 1613 | for line in codecs.open(path, 'rb', encoding='utf-8'): 1614 | line = line.strip() 1615 | if is_space == 'sea': 1616 | line = pre_token(line) 1617 | if i == 0 or is_raw: 1618 | raw.append(line) 1619 | i += 1 1620 | if len(line) > 0: 1621 | i += 1 1622 | else: 1623 | i = 0 1624 | new_grams = [] 1625 | for g_dic in gram2idx: 1626 | new_g = [] 1627 | if is_space == 'sea': 1628 | n = len(g_dic.keys()[0].split('_')) 1629 | else: 1630 | n = 0 1631 | for k in g_dic.keys(): 1632 | if '' not in k: 1633 | n = len(k) 1634 | break 1635 | grams = ngrams(raw, n, is_space) 1636 | for g in grams: 1637 | if g not in g_dic: 1638 | new_g.append(g) 1639 | new_grams.append(new_g) 1640 | return new_grams 1641 | 1642 | 1643 | def printer(raw, tagged, multi_out, outpath, sent_seg, form='conll'): 1644 | assert len(tagged) == len(multi_out) 1645 | validator(raw, multi_out) 1646 | wt = codecs.open(outpath, 'w', encoding='utf-8') 1647 | if form == 'conll': 1648 | if not sent_seg: 1649 | for raw_t, tagged_t, multi_t in zip(raw, tagged, multi_out): 1650 | if len(multi_t) > 0: 1651 | wt.write('#sent_raw: ' + raw_t + '\n') 1652 | 
wt.write('#sent_tok: ' + tagged_t + '\n') 1653 | idx = 1 1654 | tgs = multi_t.split(' ') 1655 | pl = '' 1656 | for _ in range(8): 1657 | pl += '\t' + '_' 1658 | for tg in tgs: 1659 | if '!#!' in tg: 1660 | segs = tg.split('!#!') 1661 | wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n') 1662 | else: 1663 | wt.write(str(idx) + '\t' + tg + pl + '\n') 1664 | idx += 1 1665 | wt.write('\n') 1666 | else: 1667 | for tagged_t, multi_t in zip(tagged, multi_out): 1668 | if len(tagged_t.strip()) > 0: 1669 | wt.write('#sent_tok: '+ tagged_t + '\n') 1670 | idx = 1 1671 | tgs = multi_t.split(' ') 1672 | pl = '' 1673 | for _ in range(8): 1674 | pl += '\t' + '_' 1675 | for tg in tgs: 1676 | if '!#!' in tg: 1677 | segs = tg.split('!#!') 1678 | wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n') 1679 | else: 1680 | wt.write(str(idx) + '\t' + tg + pl + '\n') 1681 | idx += 1 1682 | wt.write('\n') 1683 | else: 1684 | for tg in tagged: 1685 | wt.write(tg + '\n') 1686 | wt.close() 1687 | 1688 | 1689 | def biased_out(prediction, bias): 1690 | out = [] 1691 | b_pres = [] 1692 | for pre in prediction: 1693 | b_pres.append(pre[:,0] - pre[:,1]) 1694 | props = np.concatenate(b_pres) 1695 | props = np.sort(props)[::-1] 1696 | idx = int(bias*len(props)) 1697 | if idx == len(props): 1698 | idx -= 1 1699 | th = props[idx] 1700 | print 'threshold: ', th, 1 / (1 + np.exp(-th)) 1701 | for pre in b_pres: 1702 | pre[pre >= th] = 0 1703 | pre[pre != 0] = 1 1704 | out.append(pre) 1705 | return out 1706 | 1707 | 1708 | def to_one_hot(y, nb_classes=None): 1709 | '''Convert class vector (integers from 0 to nb_classes) to binary class matrix, for use with categorical_crossentropy. 1710 | # Arguments 1711 | y: class vector to be converted into a matrix 1712 | nb_classes: total number of classes 1713 | # Returns 1714 | A binary matrix representation of the input. 
1715 | ''' 1716 | if not nb_classes: 1717 | nb_classes = np.max(y)+1 1718 | Y = np.zeros((len(y), nb_classes)) 1719 | for i in range(len(y)): 1720 | Y[i, y[i]] = 1. 1721 | return Y 1722 | 1723 | 1724 | def validator(raw, generated): 1725 | raw_l = ''.join(raw) 1726 | raw_l = ''.join(raw_l.split()) 1727 | for g in generated: 1728 | g_tokens = g.split(' ') 1729 | j = 0 1730 | while j < len(g_tokens): 1731 | if '!#!' in g_tokens[j]: 1732 | segs = g_tokens[j].split('!#!') 1733 | c_t = int(segs[1]) 1734 | r_seg = ''.join(segs[0].split()) 1735 | l_w = len(r_seg) 1736 | if r_seg == raw_l[:l_w]: 1737 | raw_l = raw_l[l_w:] 1738 | raw_l = raw_l.strip() 1739 | else: 1740 | raise Exception('Error: unmatch...') 1741 | j += c_t 1742 | else: 1743 | r_seg = ''.join(g_tokens[j].split()) 1744 | l_w = len(r_seg) 1745 | if r_seg == raw_l[:l_w]: 1746 | raw_l = raw_l[l_w:] 1747 | raw_l = raw_l.strip() 1748 | else: 1749 | print r_seg 1750 | print raw_l[:l_w] 1751 | print '' 1752 | raise Exception('Error: unmatch...') 1753 | j += 1 1754 | 1755 | 1756 | def mlp_post(raw, prediction, is_space=False, form='mlp1'): 1757 | assert len(raw) == len(prediction) 1758 | out = [] 1759 | for r_l, p_l in zip(raw, prediction): 1760 | st = '' 1761 | rtokens = r_l.split() 1762 | ptokens = p_l.split(' ') 1763 | purged = [] 1764 | for pt in ptokens: 1765 | purged.append(pt.strip()) 1766 | ptokens = purged 1767 | ptokens_str = ''.join(ptokens) 1768 | assert ''.join(rtokens) == ''.join(ptokens_str.split()) 1769 | if form == 'mlp1': 1770 | if is_space == 'sea': 1771 | for p_t in ptokens: 1772 | st += p_t.replace(' ', '_') + ' ' 1773 | else: 1774 | while rtokens and ptokens: 1775 | if rtokens[0] == ptokens[0]: 1776 | st += ptokens[0] + ' ' 1777 | rtokens.pop(0) 1778 | ptokens.pop(0) 1779 | else: 1780 | if len(rtokens[0]) <= len(ptokens[0]): 1781 | assert ptokens[0][:len(rtokens[0])] == rtokens[0] 1782 | st += rtokens[0] + ' ' 1783 | ptokens[0] = ptokens[0][len(rtokens[0]):].strip() 1784 | rtokens.pop(0) 1785 | 
else: 1786 | can = '' 1787 | while can != rtokens[0] and ptokens: 1788 | can += ptokens[0] 1789 | st += ptokens[0] + '\\\\' 1790 | ptokens.pop(0) 1791 | st = st[:-2] + ' ' 1792 | rtokens.pop(0) 1793 | else: 1794 | for p_t in ptokens: 1795 | st += p_t + ' ' 1796 | out.append(st.strip()) 1797 | return out -------------------------------------------------------------------------------- /transducer_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import tensorflow.contrib.legacy_seq2seq as seq2seq 4 | import toolbox 5 | import batch as Batch 6 | import numpy as np 7 | import cPickle as pickle 8 | import evaluation 9 | 10 | import os 11 | 12 | class Seq2seq(object): 13 | 14 | def __init__(self, trained_model): 15 | self.en_vec = None 16 | self.de_vec = None 17 | self.trans_output = None 18 | self.trans_labels = None 19 | self.feed_previouse = None 20 | self.trans_l_rate = None 21 | self.trained = trained_model 22 | self.decode_step = None 23 | self.encode_step = None 24 | 25 | def define(self, char_num, rnn_dim, emb_dim, max_x, max_y, write_trans_model=True): 26 | self.decode_step = max_y 27 | self.encode_step = max_x 28 | self.en_vec = [tf.placeholder(tf.int32, [None], name='en_input' + str(i)) for i in range(max_x)] 29 | self.trans_labels = [tf.placeholder(tf.int32, [None], name='de_input' + str(i)) for i in range(max_y)] 30 | weights = [tf.cast(tf.sign(ot_t), tf.float32) for ot_t in self.trans_labels] 31 | self.de_vec = [tf.zeros_like(self.trans_labels[0], tf.int32)] + self.trans_labels[:-1] 32 | self.feed_previous = tf.placeholder(tf.bool) 33 | self.trans_l_rate = tf.placeholder(tf.float32, [], name='learning_rate') 34 | seq_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, state_is_tuple=True) 35 | self.trans_output, states = seq2seq.embedding_attention_seq2seq(self.en_vec, self.de_vec, seq_cell, char_num, 36 | char_num, emb_dim, feed_previous=self.feed_previous) 37 | 
38 | loss = seq2seq.sequence_loss(self.trans_output, self.trans_labels, weights) 39 | optimizer = tf.train.AdagradOptimizer(learning_rate=self.trans_l_rate) 40 | 41 | params = tf.trainable_variables() 42 | gradients = tf.gradients(loss, params) 43 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 5.0) 44 | self.trans_train = optimizer.apply_gradients(zip(clipped_gradients, params)) 45 | 46 | self.saver = tf.train.Saver() 47 | 48 | if write_trans_model: 49 | param_dic = {} 50 | param_dic['char_num'] = char_num 51 | param_dic['rnn_dim'] = rnn_dim 52 | param_dic['emb_dim'] = emb_dim 53 | param_dic['max_x'] = max_x 54 | param_dic['max_y'] = max_y 55 | # print param_dic 56 | f_model = open(self.trained + '_model', 'w') 57 | pickle.dump(param_dic, f_model) 58 | f_model.close() 59 | 60 | def train(self, t_x, t_y, v_x, v_y, lrv, char2idx, sess, epochs, batch_size=10, reset=True): 61 | 62 | idx2char = {k: v for v, k in char2idx.items()} 63 | v_y_g = [np.trim_zeros(v_y_t) for v_y_t in v_y] 64 | gold_out = [toolbox.generate_trans_out(v_y_t, idx2char) for v_y_t in v_y_g] 65 | 66 | best_score = 0 67 | 68 | if reset or not os.path.isfile(self.trained + '_weights.index'): 69 | for epoch in range(epochs): 70 | Batch.train_seq2seq(sess, model=self.en_vec + self.trans_labels, decoding=self.feed_previous, 71 | batch_size=batch_size, config=self.trans_train, lr=self.trans_l_rate, lrv=lrv, 72 | data=[t_x] + [t_y]) 73 | pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output, 74 | decoding=self.feed_previous, decode_len=self.decode_step, 75 | data=[v_x], argmax=True, batch_size=100) 76 | pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred] 77 | 78 | c_scores = evaluation.trans_evaluator(gold_out, pred_out) 79 | 80 | print 'epoch: %d' % (epoch + 1) 81 | 82 | print 'ACC: %f' % c_scores[0] 83 | print 'Token F score: %f' % c_scores[1] 84 | 85 | if c_scores[1] > best_score: 86 | best_score = c_scores[1] 87 | 
self.saver.save(sess, self.trained + '_weights', write_meta_graph=False) 88 | 89 | if best_score > 0 or not reset: 90 | self.saver.restore(sess, self.trained + '_weights') 91 | 92 | def tag(self, t_x, char2idx, sess, batch_size=100): 93 | 94 | t_x = [t_x_t[:self.encode_step] for t_x_t in t_x] 95 | t_x = toolbox.pad_zeros(t_x, self.encode_step) 96 | 97 | idx2char = {k: v for v, k in char2idx.items()} 98 | 99 | pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output, decoding=self.feed_previous, 100 | decode_len=self.decode_step, data=[t_x], argmax=True, batch_size=batch_size) 101 | pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred] 102 | 103 | return pred_out 104 | 105 | 106 | --------------------------------------------------------------------------------