├── Dependencies.md ├── LICENSE.txt ├── README.md ├── mrs ├── eval_edm.py ├── extract_ace_mrs.py ├── extract_data_lexicon.py ├── extract_dmrs_lines.py ├── extract_eds_lines.py ├── extract_erg_lexicon.py ├── extract_sdp_dmrs_lines.py ├── extract_sdp_eds_lines.py ├── graph.py ├── linear_to_mrs.py ├── read_mrs.py ├── sentence.py ├── stanford_to_linear.py └── util.py ├── rnn ├── data_utils.py ├── parser.py ├── seq2seq.py ├── seq2seq_decoders.py ├── seq2seq_helpers.py └── seq2seq_model.py └── scripts ├── export-deepbank.sh ├── extract-deepbank-sdp.sh ├── extract-deepbank.sh ├── find_bucket_sizes.py ├── preprocess.sh └── run-ace-deepbank.sh /Dependencies.md: -------------------------------------------------------------------------------- 1 | This file specifies the external software required for running this code and obtaining and processing the data. 2 | Note that environment variables have to be set. 3 | The implementation is in Python 2.7. 4 | 5 | ## Tensorflow 0.11 or 0.12. 6 | Earlier or later versions may not be compatible. 7 | https://www.tensorflow.org/ 8 | 9 | ## Stanford CoreNLP 3.5.2. 10 | Requires Java 1.8. 11 | http://nlp.stanford.edu/software/stanford-corenlp-full-2015-04-20.zip. 12 | Set 13 | 14 | JAVA=java 15 | STANFORD_NLP=stanford-corenlp-full-2015-04-20 16 | 17 | ## ERG 1214 18 | Includes Redwoods and DeepBank treebanks. 19 | 20 | ERG_DIR=erg1214 21 | svn checkout http://svn.delph-in.net/erg/tags/1214 $ERG_DIR 22 | 23 | ## LOGON 24 | Contains code to extract graph representations from the ERG treebanks. 25 | 26 | LOGONROOT=logon 27 | svn checkout http://svn.emmtee.net/trunk $LOGONROOT 28 | 29 | Include in your .bashrc: 30 | 31 | $LOGONROOT=logon 32 | if [ -f ${LOGONROOT}/dot.bashrc ]; then 33 | . ${LOGONROOT}/dot.bashrc 34 | fi 35 | 36 | ## PyDelphin 37 | (D)MRS conversion tools. 38 | https://github.com/delph-in/pydelphin 39 | 40 | ## ACE 41 | ERG parser. 42 | http://sweaglesw.org/linguistics/ace/ 43 | 44 | Download ACE: 45 | http://sweaglesw.org/linguistics/ace/download/ace-0.9.25-x86-64.tar.gz 46 | 47 | as well as the ERG 1214 grammar image (unzip and place in $ERG_DIR): 48 | http://sweaglesw.org/linguistics/ace/download/erg-1214-x86-64-0.9.25.dat.bz2 49 | 50 | ## Smatch 51 | Graph parser evaluation. 52 | https://github.com/snowblink14/smatch 53 | 54 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepDeepParser 2 | 3 | Code and data preparation scripts for the paper [Robust Incremental Neural Semantic Graph Parsing](https://arxiv.org/abs/1704.07092), Jan Buys and Phil Blunsom, ACL 2017. 4 | 5 | ## Prerequisites 6 | See Dependencies.md 7 | 8 | ## Data preparation 9 | 10 | To extract DMRS and EDS graphs from DeepBank (requires the LOGON environment and full original data): 11 | 12 | scripts/extract-deepbank.sh 13 | 14 | To extract DMRS and EDS graphs from the SDP release of DeepBank (does not require the LOGON environment): 15 | 16 | scripts/extract-deepbank-sdp.sh 17 | 18 | Pre-processing (constructs lexicon, runs Stanford CoreNLP, constructs graph linearizations/oracle transition sequences): 19 | 20 | scripts/preprocess.sh 21 | 22 | ## Training 23 | 24 | Train the transition-based parser: 25 | 26 | python rnn/parser.py --decode_dev --decode_train --use_hard_attention_arc_eager_decoder --predict_span_end_pointers --data_dir [data_dir] --train_dir [working_dir] --embedding_vectors [embedding_file] --train_name train --dev_name dev --singleton_keep_prob 0.5 --size 256 --input_embedding_size 256 --output_embedding_size 128 --tag_embedding_size 32 --use_encoder_tags --input_drop_prob 0.3 --output_drop_prob 0.3 --initialize_word_vectors 27 | 28 | where `data_dir` contains the pre-processed files for training. 29 | 30 | Word embeddings are initialized with pre-trained structured skip-gram embeddings: [sskip.100.vectors](https://drive.google.com/file/d/0B8nESzOdPhLsdWF2S1Ayb1RkTXc/view?usp=sharing) 31 | 32 | ## Decoding 33 | 34 | A pre-trained EDS model is available [here](https://drive.google.com/open?id=0BzlDJzogHw4fdGMtazJqb1RHWmc) 35 | 36 | Decode with the parser (transition-based model): 37 | 38 | python rnn/parser.py --decode --decode_dev --use_hard_attention_arc_eager_decoder --predict_span_end_pointers --data_dir [data_dir] --train_dir [working_dir] --dev_name [filename] --size 256 --input_embedding_size 256 --output_embedding_size 128 --tag_embedding_size 32 --use_encoder_tags --input_drop_prob 0.3 --output_drop_prob 0.3 --checkpoint_file model.ckpt 39 | 40 | where `data_dir` contains the pre-processed files for decoding (`filename.en`, `filename.ne`, `filename.pos`) as well as a `buckets` file, and `working_dir` contains the model (checkpoint) file. 41 | 42 | Suggested `buckets`: 43 | 44 | 24 77 45 | 37 133 46 | 52 201 47 | 48 | ### Post-processing 49 | 50 | Restore lemmas and constants and convert to output graph formats. 51 | 52 | python mrs/linear_to_mrs.py [data_dir] [filename] [working_dir] output -arceagerbuffershift -unlex -withendspan 53 | 54 | -------------------------------------------------------------------------------- /mrs/eval_edm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Computes EDM F1 scores.""" 17 | 18 | import sys 19 | 20 | def strip_span_ends(triple): 21 | ts = triple.split(' ') 22 | ts[0] = ts[0].split(':')[0] 23 | if len(ts) >= 3 and ':' in ts[2]: 24 | ts[2] = ts[2].split(':')[0] 25 | return ' '.join(ts) 26 | 27 | def list_predicate_spans(triples): 28 | spans = [] 29 | for triple in triples: 30 | if len(triple.split(' ')) > 1 and triple.split(' ')[1] == 'NAME': 31 | spans.append(triple.split(' ')[0]) 32 | return spans 33 | 34 | def inc_end_spans(spans): 35 | new_spans = [span.split(':')[0] + ':' + str(int(span.split(':')[1])+1) 36 | for span in spans] 37 | return new_spans 38 | 39 | def inc_start_spans(spans): 40 | new_spans = [str(int(span.split(':')[0]) + 1) + ':' + span.split(':')[1] 41 | for span in spans] 42 | return new_spans 43 | 44 | def dec_end_spans(spans): 45 | new_spans = [span.split(':')[0] + ':' + str(int(span.split(':')[1])-1) 46 | for span in spans] 47 | return new_spans 48 | 49 | def compute_f1(gold_set, predicted_set, inref, verbose=False, 50 | span_starts_only=False, exclude_nones=False, no_cargs=False, 51 | no_predicates=False, only_predicates=False, predicates_no_span=False): 52 | total_gold = 0.0 53 | total_predicted = 0.0 54 | total_correct = 0.0 55 | none_count = 0.0 56 | 57 | for k, line1 in enumerate(gold_set): 58 | line2 = predicted_set[k] 59 | triples1 = [t.strip() for t in line1.split(';')] 60 | triples2 = [] if line2.strip() == 'NONE' else [t.strip() for t in line2.split(';')] 61 | if triples2 == []: 62 | none_count += 1 63 | if inref is not None and inref[k].strip() == 'NONE': 64 | triples1 = [] 65 | triples2 = [] 66 | 67 | gold_spans = set(list_predicate_spans(triples1)) 68 | predicted_spans = list_predicate_spans(triples2) 69 | 70 | def replace_new_spans(new_spans): 71 | for i, new_span in enumerate(new_spans): 72 | old_span = predicted_spans[i] 73 | if old_span not in gold_spans and new_span in gold_spans: 74 | for j, triple in enumerate(triples2): 75 | triples2[j] = triple.replace(old_span, new_span) 76 | 77 | replace_new_spans(inc_end_spans(predicted_spans)) 78 | replace_new_spans(dec_end_spans(predicted_spans)) 79 | 80 | if span_starts_only: 81 | triples1 = [strip_span_ends(t) for t in triples1] 82 | 83 | if no_cargs: 84 | triples1 = filter(lambda x: (x.split(' ')[1] <> 'CARG' 85 | if len(x.split(' ')) > 2 else True), 86 | triples1) 87 | 88 | if only_predicates: 89 | triples1 = filter(lambda x: (x.split(' ')[1] == 'NAME' 90 | or x.split(' ')[1] == 'CARG' if len(x.split(' ')) > 2 else False), 91 | triples1) 92 | if predicates_no_span: 93 | triples1 = map(lambda x: x.split(' ')[2], triples1) 94 | elif no_predicates: 95 | triples1 = filter(lambda x: (x.split(' ')[1] <> 'NAME' 96 | and x.split(' ')[1] <> 'CARG' if len(x.split(' ')) > 2 else False), 97 | triples1) 98 | triples1 = set(triples1) 99 | if line2.strip() == 'NONE': 100 | if exclude_nones: 101 | triples1 = set() 102 | triples2 = set() 103 | else: 104 | if span_starts_only: 105 | triples2 = [strip_span_ends(t) for t in triples2] 106 | 107 | if no_cargs: 108 | triples2 = filter(lambda x: (x.split(' ')[1] <> 'CARG' 109 | if len(x.split(' ')) > 2 else True), 110 | triples2) 111 | 112 | if only_predicates: 113 | triples2 = filter(lambda x: (x.split(' ')[1] == 'NAME' 114 | or x.split(' ')[1] == 'CARG' if len(x.split(' ')) > 2 else False), 115 | triples2) 
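        # The predicted triples get the same predicates_no_span / no_predicates
        # handling as the gold triples above.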
116 | if predicates_no_span: 117 | triples2 = map(lambda x: x.split(' ')[2], triples2) 118 | elif no_predicates: 119 | triples2 = filter(lambda x: (x.split(' ')[1] <> 'NAME' 120 | and x.split(' ')[1] <> 'CARG' if len(x.split(' ')) > 2 else False), 121 | triples2) 122 | triples2 = set(triples2) 123 | 124 | correct_triples = triples1.intersection(triples2) 125 | incorrect_predicted = triples2 - correct_triples 126 | missed_predicted = triples1 - correct_triples 127 | 128 | total_gold += len(triples1) 129 | total_predicted += len(triples2) 130 | total_correct += len(correct_triples) 131 | 132 | if total_predicted == 0 or total_gold == 0: 133 | print "F1: 0.0" 134 | return 135 | precision = total_correct/total_predicted 136 | recall = total_correct/total_gold 137 | f1 = 2*precision*recall/(precision+recall) 138 | 139 | if verbose: 140 | print 'Precision', precision 141 | print 'Recall', recall 142 | print 'F1-score: %.2f ' % (f1*100) 143 | 144 | 145 | if __name__=='__main__': 146 | assert len(sys.argv) >= 3 147 | # Assumes in2 may contain NONE. 148 | in1 = open(sys.argv[1], 'r').read().split('\n') # Gold 149 | in2 = open(sys.argv[2], 'r').read().split('\n') # Predicted 150 | if len(sys.argv) >= 4 and not sys.argv[3].startswith('-'): 151 | inref = open(sys.argv[3], 'r').read().split('\n') # Reference 152 | else: 153 | inref = None 154 | no_cargs = len(sys.argv) >= 4 and sys.argv[3] == "-nocarg" 155 | verbose = len(sys.argv) >= 4 and sys.argv[3] == "-verbose" 156 | 157 | print 'All' 158 | compute_f1(in1, in2, inref, verbose) 159 | 160 | print 'All, start spans only' 161 | compute_f1(in1, in2, inref, verbose, span_starts_only=True) 162 | 163 | print 'All, predicates only' 164 | compute_f1(in1, in2, inref, verbose, only_predicates=True) 165 | 166 | print 'All, predicates only, start spans only' 167 | compute_f1(in1, in2, inref, verbose, span_starts_only=True, only_predicates=True) 168 | 169 | print 'All, predicates only - without spans' 170 | compute_f1(in1, in2, inref, verbose, only_predicates=True, predicates_no_span=True) 171 | 172 | print 'All, relations only' 173 | compute_f1(in1, in2, inref, verbose, no_predicates=True) 174 | 175 | print 'All, relations only, start spans only' 176 | compute_f1(in1, in2, inref, verbose, span_starts_only=True, no_predicates=True) 177 | 178 | -------------------------------------------------------------------------------- /mrs/extract_ace_mrs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import sys 17 | 18 | import delphin.mrs as mrs 19 | import delphin.mrs.simplemrs as simplemrs 20 | import delphin.mrs.simpledmrs as simpledmrs 21 | import delphin.mrs.eds as eds 22 | 23 | import graph as mrs_graph 24 | 25 | if __name__=='__main__': 26 | assert len(sys.argv) >= 4, 'Invalid number of arguments (at least 3 required).' 
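  # Arguments: argv[1] = set name, argv[2] = working directory containing <set>.mrs;
  # the remaining flags select DMRS or EDS output and optional extra files.
  # Hypothetical invocation (placeholder paths, not taken from the original scripts):
  #   python mrs/extract_ace_mrs.py dev [working_dir] -eds -extract_mrs_files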
27 | set_name = sys.argv[1] 28 | working_dir = sys.argv[2] + '/' 29 | extract_dmrs = sys.argv >= 4 and '-dmrs' in sys.argv[3:] 30 | extract_eds = sys.argv >= 4 and '-eds' in sys.argv[3:] 31 | assert extract_dmrs or extract_eds 32 | if extract_dmrs: 33 | mrs_type = 'dmrs' 34 | else: 35 | mrs_type = 'eds' 36 | 37 | extract_dmrs_files = sys.argv >= 4 and '-extract_dmrs_files' in sys.argv[3:] 38 | extract_mrs_files = sys.argv >= 4 and '-extract_mrs_files' in sys.argv[3:] 39 | ignore_unparsed = sys.argv >= 4 and '-ignore_unparsed' in sys.argv[3:] 40 | 41 | filename = working_dir + set_name + '.mrs' 42 | in_file = open(filename, 'r') 43 | 44 | mrs_strs = [] 45 | sentences = [] 46 | 47 | sentence_str = '' 48 | simple_mrs_str = '' 49 | state = 0 50 | 51 | for line in in_file: 52 | line = line.strip() 53 | if state == 0: 54 | if line.startswith('[ LTOP:'): 55 | simple_mrs_str = line + ' ' 56 | state = 1 57 | else: 58 | if line.startswith('SKIP:'): 59 | sentence_str = line[line.index(':') + 2:] 60 | if not ignore_unparsed: 61 | sentences.append(sentence_str) 62 | mrs_strs.append('') 63 | elif line.startswith('SENT:'): 64 | sentence_str = line[line.index(':') + 2:] 65 | elif state == 1: 66 | if line: 67 | simple_mrs_str += line + ' ' 68 | if line.startswith('HCONS:'): 69 | state = 0 70 | sentences.append(sentence_str) 71 | mrs_strs.append(simple_mrs_str) 72 | else: 73 | state = 0 74 | sentences.append(sentence_str) 75 | mrs_strs.append(simple_mrs_str) 76 | 77 | print len(sentences), "sentences" 78 | assert len(mrs_strs) == len(sentences) 79 | graphs = [] 80 | simple_strs = [] 81 | const_strs = [] 82 | hyphen_sent_strs = [] 83 | nohyphen_sent_strs = [] 84 | 85 | for i, mrs_str in enumerate(mrs_strs): 86 | token_inds = {} 87 | token_starts = [] 88 | token_ends = [] 89 | token_start = 0 90 | 91 | state = True 92 | counter = 0 93 | for k, char in enumerate(sentences[i]): 94 | if state: 95 | if char.isspace(): 96 | token_end = k 97 | token_inds[str(token_start) + ':' + str(token_end)] = counter 98 | token_starts.append(token_start) 99 | token_ends.append(token_end) 100 | counter += 1 101 | state = False 102 | elif not char.isspace(): 103 | state = True 104 | token_start = k 105 | if state: 106 | token_end = len(sentences[i]) 107 | token_inds[str(token_start) + ':' + str(token_end)] = counter 108 | token_starts.append(token_start) 109 | token_ends.append(token_end) 110 | 111 | if mrs_str: 112 | # Writes out mrs string to its own file. 113 | # This is required so that we can convert MRS to EDS using LOGON. 114 | if extract_mrs_files: 115 | mrs_file = open(working_dir + 'mrs/' + set_name + str(i) + '.mrs', 'w') 116 | mrs_file.write(mrs_str) 117 | mrs_file.close() 118 | 119 | # TODO Delphin passes utf-8 as encoding parameter, but somewhere it changes to ascii... 
120 | # pydelphin/delphin/mrs/util.py:94 121 | # /usr/lib/python2.7/xml/etree/ElementTree.py 122 | simple_mrs_code = mrs_str.decode('utf-8', 'replace') 123 | mrs_str = simple_mrs_code.encode('ascii', 'replace') 124 | if extract_eds: 125 | mrx_str = mrs.convert(mrs_str, 'simplemrs', 'mrx') 126 | eds_object = mrs.mrx.loads(mrx_str) 127 | 128 | try: 129 | eds_str = eds.dumps(eds_object) 130 | except (KeyError, TypeError, IndexError) as err: 131 | if not ignore_unparsed: 132 | graphs.append(None) 133 | simple_strs.append('') 134 | const_strs.append('') 135 | hyphen_sent_strs.append('') 136 | nohyphen_sent_strs.append('') 137 | continue 138 | 139 | eds_lines = eds_str.split('\n') 140 | single_eds_str = eds_lines[0][1:eds_lines[0].index(':')] 141 | for line in eds_lines[1:]: 142 | if line.strip() <> '}': 143 | single_eds_str += ' ; ' + line.strip() 144 | 145 | graph = mrs_graph.parse_eds(single_eds_str) 146 | simple_dmrs_str = single_eds_str 147 | dmrs_const_str = '' 148 | else: 149 | dmrs_xml_str = mrs.convert(mrs_str, 'simplemrs', 'dmrx') 150 | dmrs_object = mrs.dmrx.loads(dmrs_xml_str) 151 | 152 | try: 153 | simple_dmrs_str = simpledmrs.dumps(dmrs_object) 154 | except (KeyError, TypeError) as err: 155 | if not ignore_unparsed: 156 | graphs.append(None) 157 | simple_strs.append('') 158 | const_strs.append('') 159 | hyphen_sent_strs.append('') 160 | nohyphen_sent_strs.append('') 161 | continue 162 | 163 | simple_mrs = simplemrs.loads_one(mrs_str) 164 | graph = mrs_graph.parse_dmrs(simple_dmrs_str) 165 | 166 | if graph.root_index == -1: 167 | graph.root_index = 0 168 | 169 | dmrs_const_str = '' 170 | # Add constants 171 | for ep in simple_mrs.eps(): 172 | if ep.args.has_key('CARG'): 173 | dmrs_const_str += (str(ep.lnk)[1:-1] + ' ' + str(ep.pred) + ' ' 174 | + ep.args['CARG'] + ' ') 175 | 176 | # Find head node 177 | found = False 178 | pred = str(ep.pred) 179 | for j, node in enumerate(graph.nodes): 180 | if node.ind == str(ep.lnk)[1:-1] and node.concept == pred: 181 | graph.nodes[j].constant = ep.args['CARG'] 182 | found = True 183 | if not found: 184 | print pred, str(ep.lnk)[1:-1] 185 | 186 | nohyphen_sent_str = sentences[i] 187 | for j in xrange(1, len(nohyphen_sent_str)-1): 188 | if (nohyphen_sent_str[j] == '-' and nohyphen_sent_str[j-1] <> ' ' and 189 | nohyphen_sent_str[j+1] <> ' '): 190 | nohyphen_sent_str = nohyphen_sent_str[:j] + ' ' + nohyphen_sent_str[j+1:] 191 | 192 | graph.find_span_tree(graph.root_index) 193 | graphs.append(graph) 194 | simple_strs.append(simple_dmrs_str) 195 | const_strs.append(dmrs_const_str) 196 | hyphen_sent_strs.append(sentences[i]) 197 | nohyphen_sent_strs.append(nohyphen_sent_str) 198 | else: 199 | if not ignore_unparsed: 200 | graphs.append(None) 201 | simple_strs.append('') 202 | const_strs.append('') 203 | hyphen_sent_strs.append('') 204 | nohyphen_sent_strs.append('') 205 | 206 | if extract_dmrs_files: 207 | # Writes out the dmrs's so that it can be used for training data. 
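  # One line per sentence is written to each of: <set>.hraw (original sentence),
  # <set>.raw (in-word hyphens split), <set>.sdmrs (simple DMRS/EDS graph string)
  # and <set>.carg (CARG constant triples), for use as training data.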
208 | sent_out_file = open(working_dir + set_name + '.hraw', 'w') 209 | print working_dir + set_name + '.hraw' 210 | print len(hyphen_sent_strs) 211 | for sentence_str in hyphen_sent_strs: 212 | sent_out_file.write(sentence_str + '\n') 213 | sent_out_file.close() 214 | 215 | sent_out_file = open(working_dir + set_name + '.raw', 'w') 216 | for sentence_str in nohyphen_sent_strs: 217 | sent_out_file.write(sentence_str + '\n') 218 | sent_out_file.close() 219 | 220 | lin_out_file = open(working_dir + set_name + '.sdmrs', 'w') 221 | print len(simple_strs) 222 | for simple_dmrs_str in simple_strs: 223 | lin_out_file.write(simple_dmrs_str + '\n') 224 | lin_out_file.close() 225 | 226 | lin_out_file = open(working_dir + set_name + '.carg', 'w') 227 | for dmrs_const_str in const_strs: 228 | lin_out_file.write(dmrs_const_str + '\n') 229 | lin_out_file.close() 230 | 231 | # Writes out char-level EDM. 232 | edm_out_file = open(working_dir + set_name + '.' + mrs_type + '.edm', 'w') 233 | for graph in graphs: 234 | if graph is None: 235 | edm_out_file.write('NONE\n') 236 | else: 237 | edm_out_file.write(graph.edm_ch_str() + '\n') 238 | edm_out_file.close() 239 | 240 | # Writes out AMR for Smatch evaluation (modifies graph). 241 | amr_out_file = open(working_dir + set_name + '.' + mrs_type + '.amr', 'w') 242 | for graph in graphs: 243 | if graph is None or len(graph.nodes) == 0: 244 | amr_out_file.write('( n1 / _UNK )\n\n') 245 | else: 246 | graph.correct_concept_names() 247 | graph.correct_node_names() 248 | amr_out_file.write(graph.amr_graph_str(graph.root_index, 1) + '\n\n') 249 | amr_out_file.close() 250 | 251 | 252 | -------------------------------------------------------------------------------- /mrs/extract_data_lexicon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import sys 17 | 18 | import delphin.mrs as mrs 19 | import delphin.mrs.simplemrs as simplemrs 20 | import delphin.mrs.simpledmrs as simpledmrs 21 | 22 | import graph as mrs_graph 23 | import util as mrs_util 24 | 25 | if __name__=='__main__': 26 | assert len(sys.argv) >= 2 27 | input_dir = sys.argv[1] + '/' 28 | working_dir = sys.argv[2] + '/' 29 | out_dir = working_dir 30 | 31 | set_name = 'train' 32 | suffix = '.raw' 33 | 34 | erg_dir = erg_name + '/' 35 | erg_pred_map = mrs_util.read_lexicon(working_dir + 'predicates.erg.lexicon') 36 | erg_const_map = mrs_util.read_lexicon(working_dir + 'constants.erg.lexicon') 37 | 38 | sent_file = open(input_dir + set_name + suffix, 'r') 39 | sentences_raw = [sent for sent in sent_file.read().split('\n')[:-1]] 40 | 41 | sdmrs_file = open(input_dir + set_name + '.sdmrs', 'r') 42 | simple_dmrs_strs = [sent for sent in sdmrs_file.read().split('\n')[:-1]] 43 | carg_file = open(input_dir + set_name + '.carg', 'r') 44 | dmrs_carg_strs = [sent for sent in carg_file.read().split('\n')[:-1]] 45 | 46 | pred_dict = {} 47 | const_dict = {} 48 | 49 | for i, simple_dmrs_str in enumerate(simple_dmrs_strs): 50 | graph = mrs_graph.parse_dmrs(simple_dmrs_str, sentence_str=sentences_raw[i]) 51 | 52 | # Adds constants. 53 | carg_list = dmrs_carg_strs[i].split() 54 | carg_inds = carg_list[::3] 55 | carg_preds = carg_list[1::3] 56 | carg_values = carg_list[2::3] 57 | 58 | for ind, pred, const in zip(carg_inds, carg_preds, carg_values): 59 | # Finds head node of CARG. 60 | found = False 61 | if (pred[0] == '"' and pred[-1] == '"') or pred[0] == '_': 62 | continue 63 | for j, node in enumerate(graph.nodes): 64 | if node.ind == ind and node.concept == pred: 65 | graph.nodes[j].constant = const 66 | if const[0] =='"' and const[-1] == '"': 67 | const = const[1:-1] 68 | const = mrs_util.clean_punct(const) 69 | ind_start, ind_end = int(ind.split(':')[0]), int(ind.split(':')[1]) 70 | const_raw = sentences_raw[i][ind_start:ind_end] 71 | const_raw = mrs_util.clean_punct(const_raw) 72 | if const_raw == '': 73 | continue 74 | if const_dict.has_key(const_raw): 75 | if const_dict[const_raw].has_key(const): 76 | const_dict[const_raw][const] += 1 77 | else: 78 | const_dict[const_raw][const] = 1 79 | else: 80 | const_dict[const_raw] = {} 81 | const_dict[const_raw][const] = 1 82 | found = True 83 | 84 | # Extracts lexical dict. 85 | for node in graph.nodes: 86 | if (node.alignment and node.concept.startswith('_') 87 | and '/' not in node.concept): 88 | ind_start = int(node.ind.split(':')[0]) 89 | ind_end = int(node.ind.split(':')[1]) 90 | pred = node.concept[:node.concept.index('_', 1)] 91 | pred_raw = sentences_raw[i][ind_start:ind_end] 92 | pred_raw = mrs_util.clean_punct(pred_raw) 93 | if pred_raw == '': 94 | continue 95 | if pred_raw[0].isupper(): # lowercase if only first letter is upper 96 | if pred_raw[0].lower() + pred_raw[1:] == pred_raw.lower(): 97 | pred_raw = pred_raw.lower() 98 | if pred_dict.has_key(pred_raw): 99 | if pred_dict[pred_raw].has_key(pred): 100 | pred_dict[pred_raw][pred] += 1 101 | else: 102 | pred_dict[pred_raw][pred] = 1 103 | else: 104 | pred_dict[pred_raw] = {} 105 | pred_dict[pred_raw][pred] = 1 106 | 107 | # Extracts 1-best map, disambiguate with the ERG. 
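  # For each surface form the most frequent predicate is kept; a tie is broken in
  # favour of the ERG lexicon entry, and forms that occur only in the ERG lexicon
  # are added afterwards as a fallback. The same scheme is used for constants below.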
108 | pred_map = {} 109 | for pred_raw, dic in pred_dict.iteritems(): 110 | max_count = max(dic.values()) 111 | found_pred = False 112 | if erg_pred_map.has_key(pred_raw): 113 | erg_pred = erg_pred_map[pred_raw] 114 | if dic.has_key(erg_pred) and dic[erg_pred] == max_count: 115 | pred_map[pred_raw] = erg_pred 116 | found_pred = True 117 | for pred, count in dic.iteritems(): 118 | if found_pred: 119 | break 120 | if count == max_count: 121 | pred_map[pred_raw] = pred 122 | found_pred = True 123 | 124 | for pred_raw, pred in erg_pred_map.iteritems(): 125 | if not pred_map.has_key(pred_raw): 126 | pred_map[pred_raw] = pred 127 | 128 | const_map = {} 129 | for const_raw, dic in const_dict.iteritems(): 130 | max_count = max(dic.values()) 131 | found_const = False 132 | if erg_const_map.has_key(const_raw): 133 | erg_const = erg_const_map[const_raw] 134 | if dic.has_key(erg_const) and dic[erg_const] == max_count: 135 | const_map[const_raw] = erg_const 136 | found_const = True 137 | for const, count in dic.iteritems(): 138 | if found_const: 139 | break 140 | if count == max_count: 141 | const_map[const_raw] = const 142 | found_const = True 143 | 144 | for const_raw, const in erg_const_map.iteritems(): 145 | if not const_map.has_key(const_raw): 146 | const_map[const_raw] = const 147 | 148 | pred_out_file = open(out_dir + 'predicates.lexicon', 'w') 149 | for orth, pred in pred_map.iteritems(): 150 | pred_out_file.write(orth + '\n' + pred + '\n') 151 | pred_out_file.close() 152 | 153 | const_out_file = open(out_dir + 'constants.lexicon', 'w') 154 | for orth, const in const_map.iteritems(): 155 | const_out_file.write(orth + '\n' + const + '\n') 156 | const_out_file.close() 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /mrs/extract_dmrs_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import gzip 17 | import sys 18 | 19 | import delphin.mrs as mrs 20 | import delphin.mrs.simplemrs as simplemrs 21 | import delphin.mrs.simpledmrs as simpledmrs 22 | 23 | """Reads in deepbank's export format and converts it to DMRS.""" 24 | 25 | if __name__=='__main__': 26 | assert len(sys.argv) >= 4, 'Invalid number of arguments (3 or 4 required).' 27 | output_dir = sys.argv[1] + '/' 28 | filename = sys.argv[2] 29 | set_name = sys.argv[3] 30 | 31 | # Default export format includes tokenization, mrs and eds. 32 | # Alternatively, only mrs can be included. 33 | mrs_only = len(sys.argv) >= 5 and sys.argv[4] == '--mrs_only' 34 | 35 | sentence_str = '' 36 | simple_mrs_str = '' 37 | 38 | in_file = gzip.open(filename, 'r') 39 | state = 0 40 | for line in in_file: 41 | if state == 0 and line[0] == '[' and '`' in line and '\'' in line: 42 | # State 0: Input sentence. 
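      # (Scan order: back-quoted sentence, '<'...'>' PTB token block, simple MRS
      # starting at '[ TOP:', then the EDS block at '{' -- or a blank line when
      # --mrs_only is given.)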
43 | sentence_str = line[line.index('`')+1:line.rindex('\'')] 44 | state = 1 45 | elif state == 1 and line[0] == '<': 46 | state = 2 47 | elif state == 2: 48 | # State 2: PTB tokenization 49 | if line[0] == '>': 50 | state = 3 51 | elif ((line.strip().startswith('[ TOP:') and state == 3) 52 | or (line.strip().startswith('[ TOP:') and state == 1 and mrs_only)): 53 | # State 4: MRS 54 | simple_mrs_str = line.strip() + ' ' 55 | state = 4 56 | elif state == 4: 57 | if line[0] == '{' or (line.strip() == '' and mrs_only): 58 | # EDS 59 | state = 5 60 | else: 61 | simple_mrs_str += line.strip() + ' ' 62 | 63 | simple_mrs_code = simple_mrs_str.decode('utf-8', 'replace') 64 | simple_mrs_str = simple_mrs_code.encode('ascii', 'replace') 65 | 66 | dmrs_xml_str = mrs.convert(simple_mrs_str, 'simplemrs', 'dmrx') 67 | dmrs_object = mrs.dmrx.loads(dmrs_xml_str) 68 | simple_dmrs_str = simpledmrs.dumps(dmrs_object) # TODO this can give an error 69 | mrs_object = simplemrs.loads_one(simple_mrs_str) 70 | 71 | dmrs_const_str = '' 72 | 73 | # Adds constants. 74 | for ep in mrs_object.eps(): 75 | if ep.args.has_key('CARG'): 76 | dmrs_const_str += (str(ep.lnk)[1:-1] + ' ' + str(ep.pred) + ' ' 77 | + ep.args['CARG'] + ' ') 78 | 79 | hyphen_sentence_str = sentence_str 80 | # Removes in-word hyphens in sentence. 81 | for i in xrange(1, len(sentence_str)-1): 82 | if (sentence_str[i] == '-' and sentence_str[i-1] <> ' ' and 83 | sentence_str[i+1] <> ' '): 84 | sentence_str = sentence_str[:i] + ' ' + sentence_str[i+1:] 85 | 86 | sent_out_file = open(output_dir + set_name + '.hraw', 'a') 87 | sent_out_file.write(hyphen_sentence_str + '\n') 88 | sent_out_file.close() 89 | 90 | sent_out_file = open(output_dir + set_name + '.raw', 'a') 91 | sent_out_file.write(sentence_str + '\n') 92 | sent_out_file.close() 93 | 94 | lin_out_file = open(output_dir + set_name + '.sdmrs', 'a') 95 | lin_out_file.write(simple_dmrs_str + '\n') 96 | lin_out_file.close() 97 | 98 | lin_out_file = open(output_dir + set_name + '.carg', 'a') 99 | lin_out_file.write(dmrs_const_str + '\n') 100 | lin_out_file.close() 101 | 102 | -------------------------------------------------------------------------------- /mrs/extract_eds_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | """Reads in deepbank's export format and converts it to various outputs.""" 18 | 19 | import gzip 20 | import sys 21 | 22 | if __name__=='__main__': 23 | assert len(sys.argv) == 4, 'Invalid number of arguments (3 required).' 
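  # Arguments: output directory, gzipped export file, set name. Hypothetical
  # invocation (placeholder paths):
  #   python mrs/extract_eds_lines.py [output_dir] [export_file].gz train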
24 | output_dir = sys.argv[1] + '/' 25 | 26 | filename = sys.argv[2] 27 | set_name = sys.argv[3] 28 | in_file = gzip.open(filename, 'r') 29 | 30 | sentence_str = '' 31 | eds_str = '' 32 | 33 | tokens = [] 34 | token_inds = {} 35 | token_start = [] 36 | token_end = [] 37 | 38 | state = 0 39 | for line in in_file: 40 | if state == 0 and line[0] == '[' and '`' in line and '\'' in line: 41 | sentence_str = line[line.index('`')+1:line.rindex('\'')] 42 | state = 1 43 | elif state == 1 and line[0] == '<': 44 | state = 2 45 | elif state == 2: 46 | if line[0] == '>': 47 | state = 3 48 | else: 49 | items = line.strip().split(', ') 50 | tokens.append(items[5].strip()[1:-1]) 51 | ind = items[3].strip()[1:-1] 52 | token_inds[ind] = len(token_inds) 53 | token_start.append(int(ind.split(':')[0])) 54 | token_end.append(int(ind.split(':')[1])) 55 | elif state == 3 and line[0] == '{' and ':' in line: 56 | eds_str = line[1:line.index(':')] 57 | state = 4 58 | elif state == 4: 59 | if line[0] == '}': 60 | state = 5 61 | else: 62 | line = line.strip() 63 | eds_str += ' ; ' + line 64 | 65 | sent_out_file = open(output_dir + set_name + '.raw', 'a') 66 | sent_out_file.write(sentence_str + '\n') 67 | sent_out_file.close() 68 | 69 | lin_out_file = open(output_dir + set_name + '.eds', 'a') 70 | lin_out_file.write(eds_str + '\n') 71 | lin_out_file.close() 72 | 73 | -------------------------------------------------------------------------------- /mrs/extract_erg_lexicon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import sys 18 | 19 | if __name__ == '__main__': 20 | assert len(sys.argv) >= 3 21 | lexicon_dir = sys.argv[1] + '/' 22 | lexicon_file = open(lexicon_dir + '/lexicon.tdl', 'r') 23 | out_dir = sys.argv[2] + '/' 24 | 25 | state = False 26 | entry = [] 27 | pred_dict = {} 28 | const_dict = {} 29 | 30 | for line in lexicon_file: 31 | line = line.strip() 32 | if state: 33 | if line <> '': 34 | entry += line.split(' ') 35 | elif entry: 36 | # Parses entry. 
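      # A lexicon.tdl entry is a whitespace-tokenized TDL definition, roughly of the
      # form (illustrative sketch only, not copied from the ERG):
      #   abandon_v1 := v_np*_le &
      #     [ ORTH < "abandon" >,
      #       SYNSEM [ LKEYS.KEYREL.PRED "_abandon_v_1_rel" ] ].
      # The tokens between ORTH and '>' / '>,' give the orthography; KEYREL.PRED /
      # KEYREL.CARG (or bare PRED / CARG) give the predicate lemma and the constant.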
37 | assert 'ORTH' in entry, entry 38 | orth_start = entry.index('ORTH') + 2 39 | assert '>,' in entry[orth_start:] or '>' in entry[orth_start:], entry 40 | if '>' in entry[orth_start:]: 41 | orth_end = entry.index('>', orth_start) 42 | else: 43 | orth_end = entry.index('>,', orth_start) 44 | orth = '' 45 | for i in xrange(orth_start, orth_end): 46 | if entry[i][-1] == ',': 47 | if entry[i][-3] == '-': 48 | orth += entry[i][1:-3] + ' ' 49 | else: 50 | orth += entry[i][1:-2] + ' ' 51 | else: 52 | orth += entry[i][1:-1] + ' ' 53 | orth = orth[:-1] 54 | pred = '' 55 | const = '' 56 | 57 | const_keys = filter(lambda x: 'KEYREL.CARG' in x or x == 'CARG', entry) 58 | if const_keys: 59 | assert len(const_keys) == 1 60 | const = entry[entry.index(const_keys[0]) + 1] 61 | assert const[0] == '"' 62 | if const[-1] == ',': 63 | assert const[-2] == '"' 64 | const = const[1:-2] 65 | elif const[-1] == '"': 66 | const = const[1:-1] 67 | else: 68 | const = const[1:] 69 | pred_keys = filter(lambda x: 'KEYREL.PRED' in x or x == 'PRED', entry) 70 | if pred_keys: 71 | pred = entry[entry.index(pred_keys[0]) + 1] 72 | start_ind = 1 if pred[0] == '"' else 0 73 | if pred[start_ind] == '_': 74 | pred = pred[start_ind:pred.index('_', 2)] 75 | else: 76 | pred = '' 77 | 78 | if pred: 79 | if orth in pred_dict: 80 | if pred <> pred_dict[orth]: 81 | if pred[1:] == orth or (pred_dict[orth][1:] <> orth and 82 | abs(len(orth) - len(pred)) < abs(len(orth) - len(pred_dict[orth])-1)): 83 | pred_dict[orth] = pred 84 | else: 85 | pred_dict[orth] = pred 86 | if const: 87 | const_dict[orth] = const 88 | state = False 89 | else: 90 | if line <> '': 91 | entry = line.split(' ') 92 | state = True 93 | 94 | pred_out_file = open(out_dir + 'predicates.erg.lexicon', 'w') 95 | for orth, pred in pred_dict.iteritems(): 96 | pred_out_file.write(orth + '\n' + pred + '\n') 97 | pred_out_file.close() 98 | 99 | const_out_file = open(out_dir + 'constants.erg.lexicon', 'w') 100 | for orth, const in const_dict.iteritems(): 101 | const_out_file.write(orth + '\n' + const + '\n') 102 | const_out_file.close() 103 | 104 | -------------------------------------------------------------------------------- /mrs/extract_sdp_dmrs_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Reads in deepbank's export format and converts it to DMRS.""" 17 | 18 | import gzip 19 | import sys 20 | 21 | import delphin.mrs as mrs 22 | import delphin.mrs.simplemrs as simplemrs 23 | import delphin.mrs.simpledmrs as simpledmrs 24 | 25 | if __name__=='__main__': 26 | assert len(sys.argv) == 4, 'Invalid number of arguments (3 required).' 27 | output_dir = sys.argv[1] + '/' 28 | filename = sys.argv[2] 29 | set_name = sys.argv[3] 30 | 31 | # Reads text file. 
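  # In the SDP release each .mrs.gz item is paired with a .txt.gz file sharing the
  # same prefix, which holds the raw sentence.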
32 | assert filename.endswith('mrs.gz'), filename 33 | file_prefix = filename[:-len('mrs.gz')] 34 | txt_filename = file_prefix + 'txt.gz' 35 | sentence_str = gzip.open(txt_filename, 'r').read().strip() 36 | 37 | in_file = gzip.open(filename, 'r') 38 | simple_mrs_str = '' 39 | 40 | tokens = [] 41 | token_inds = {} 42 | token_start = [] 43 | token_end = [] 44 | 45 | state = 0 46 | for line in in_file: 47 | if state == 0 and line.strip().startswith('[ LTOP:'): 48 | # State 4: MRS 49 | simple_mrs_str = line.strip() + ' ' 50 | state = 4 51 | elif state == 4: 52 | if line.strip() == '': 53 | state = 5 54 | else: 55 | simple_mrs_str += line.strip() + ' ' 56 | 57 | simple_mrs_code = simple_mrs_str.decode('utf-8', 'replace') 58 | simple_mrs_str = simple_mrs_code.encode('ascii', 'replace') 59 | 60 | dmrs_xml_str = mrs.convert(simple_mrs_str, 'simplemrs', 'dmrx') 61 | dmrs_object = mrs.dmrx.loads(dmrs_xml_str) 62 | simple_dmrs_str = simpledmrs.dumps(dmrs_object) #TODO this can give error 63 | mrs_object = simplemrs.loads_one(simple_mrs_str) 64 | 65 | dmrs_const_str = '' 66 | 67 | # Adds constants. 68 | for ep in mrs_object.eps(): 69 | if ep.args.has_key('CARG'): 70 | dmrs_const_str += (str(ep.lnk)[1:-1] + ' ' + str(ep.pred) + ' ' 71 | + ep.args['CARG'] + ' ') 72 | 73 | hyphen_sentence_str = sentence_str 74 | # Removes in-word hyphens in sentence. 75 | for i in xrange(1, len(sentence_str)-1): 76 | if (sentence_str[i] == '-' and sentence_str[i-1] <> ' ' and 77 | sentence_str[i+1] <> ' '): 78 | sentence_str = sentence_str[:i] + ' ' + sentence_str[i+1:] 79 | 80 | sent_out_file = open(output_dir + set_name + '.raw', 'a') 81 | sent_out_file.write(sentence_str + '\n') 82 | sent_out_file.close() 83 | 84 | lin_out_file = open(output_dir + set_name + '.sdmrs', 'a') 85 | lin_out_file.write(simple_dmrs_str + '\n') 86 | lin_out_file.close() 87 | 88 | lin_out_file = open(output_dir + set_name + '.carg', 'a') 89 | lin_out_file.write(dmrs_const_str + '\n') 90 | lin_out_file.close() 91 | 92 | -------------------------------------------------------------------------------- /mrs/extract_sdp_eds_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Reads in deepbank's export format and extracts EDS.""" 17 | 18 | import gzip 19 | import sys 20 | 21 | if __name__=='__main__': 22 | assert len(sys.argv) == 4, 'Invalid number of arguments (3 required).' 23 | output_dir = sys.argv[1] + '/' 24 | filename = sys.argv[2] 25 | set_name = sys.argv[3] 26 | 27 | # Reads text file. 
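  # As in extract_sdp_dmrs_lines.py, the raw sentence comes from the .txt.gz file
  # sharing the .eds.gz file's prefix.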
28 | assert filename.endswith('eds.gz'), filename 29 | file_prefix = filename[:-len('eds.gz')] 30 | txt_filename = file_prefix + 'txt.gz' 31 | sentence_str = gzip.open(txt_filename, 'r').read().strip() 32 | 33 | in_file = gzip.open(filename, 'r') 34 | eds_str = '' 35 | 36 | tokens = [] 37 | token_inds = {} 38 | token_start = [] 39 | token_end = [] 40 | 41 | state = 0 42 | for line in in_file: 43 | if state == 0 and line[0] == '{' and ':' in line: 44 | eds_str = line[1:line.index(':')] 45 | state = 1 46 | elif state == 1: 47 | if line[0] == '}': 48 | state = 2 49 | else: 50 | line = line.strip() 51 | eds_str += ' ; ' + line 52 | 53 | sent_out_file = open(output_dir + set_name + '.raw', 'a') 54 | sent_out_file.write(sentence_str + '\n') 55 | sent_out_file.close() 56 | 57 | lin_out_file = open(output_dir + set_name + '.eds', 'a') 58 | lin_out_file.write(eds_str + '\n') 59 | lin_out_file.close() 60 | 61 | -------------------------------------------------------------------------------- /mrs/linear_to_mrs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import re 18 | import string 19 | import sys 20 | import codecs 21 | 22 | import graph as mrs_graph 23 | 24 | def clean_quoted(word): 25 | constant = word.replace('"', '') 26 | constant = constant.replace('\'', '') 27 | constant = constant.replace(':', '') 28 | constant = constant.replace('(', '_') 29 | constant = constant.replace(')', '_') 30 | constant = constant.strip() 31 | constant = constant.replace(' ', '_') 32 | return constant 33 | 34 | 35 | def read_span(filename): 36 | token_starts = [] 37 | token_ends = [] 38 | inds_file = open(filename, 'r') 39 | for line in inds_file: 40 | inds = line.strip().split(' ') 41 | token_starts.append(map(int, [ind.split(':')[0] for ind in inds])) 42 | token_ends.append(map(int, [ind.split(':')[1] for ind in inds])) 43 | return token_starts, token_ends 44 | 45 | 46 | def read_ints(filename): 47 | sentences = [] 48 | sentence_file = open(filename, 'r') 49 | for line in sentence_file: 50 | sentences.append(int(line.strip())) 51 | return sentences 52 | 53 | 54 | def read_ids(filename): 55 | sentences = [] 56 | sentence_file = open(filename, 'r') 57 | for line in sentence_file: 58 | sentences.append(line[:line.index('.txt')]) 59 | return sentences 60 | 61 | 62 | def read_tokens(filename): 63 | sentences = [] 64 | sentence_file = codecs.open(filename, 'r', 'utf-8') 65 | for line in sentence_file: 66 | sentences.append(line.strip().split(' ')) 67 | return sentences 68 | 69 | 70 | def linear_to_mrs(base_path, data_path, mrs_path, set_name, convert_train, 71 | is_inorder=False, is_arceager=False, 72 | is_arceager_buffer_shift=False, 73 | is_lexicalized=True, is_no_span=False, 74 | recover_end_spans=False, is_amr=False, is_epe=False, 75 | domain=''): 76 | 
graphs = mrs_graph.read_linear_dmrs_graphs(mrs_path + '.lin', True, 77 | is_inorder, is_arceager, is_arceager_buffer_shift, is_no_span) 78 | 79 | for i, graph in enumerate(graphs): 80 | if graph.nodes: 81 | graphs[i].spanned = [False for _ in graph.nodes] 82 | graphs[i].find_span_tree(graph.root_index) 83 | 84 | if recover_end_spans: 85 | for i, graph in enumerate(graphs): 86 | for j in xrange(len(graphs[i].nodes)): # clear any end alignments 87 | graphs[i].nodes[j].alignment_end = graphs[i].nodes[j].alignment 88 | graphs[i].recover_end_spans(graph.root_index, -1) 89 | 90 | if convert_train: 91 | # Keeps token-level spans. 92 | for i, graph in enumerate(graphs): 93 | for j, node in enumerate(graph.nodes): 94 | if node.alignment_end < node.alignment: 95 | node.alignment_end = node.alignment # correct negative spans 96 | graphs[i].nodes[j].ind = (str(node.alignment) + ':' 97 | + str(node.alignment_end)) 98 | if node.concept.endswith('_CARG'): 99 | concept = node.concept[:node.concept.index('_CARG')] 100 | graphs[i].nodes[j].concept = concept 101 | graphs[i].nodes[j].constant = 'CARG' 102 | 103 | # Reads untokenized sentences. 104 | sentences = [] 105 | sentences_unstripped = [] 106 | if is_epe: 107 | sentence_file = codecs.open(base_path + '.txt', 'r', 'utf-8') 108 | else: 109 | sentence_file = open(data_path + '.raw', 'r') 110 | 111 | for line in sentence_file: 112 | sentences_unstripped.append(line.replace('\n', ' ')) 113 | sentences.append(line.strip()) 114 | 115 | # Reads const and pred candidate tokens. 116 | const_tokens = read_tokens(base_path + '.lex.const') 117 | pred_tokens = read_tokens(base_path + '.lex.pred') 118 | pos_tokens = read_tokens(base_path + '.pos') 119 | ne_tokens = read_tokens(base_path + '.ne') 120 | if is_amr: 121 | nom_tokens = read_tokens(amr_base_path + '.lex.nom') 122 | if is_epe: 123 | offsets = read_ints(base_path + '.off') 124 | file_ids = read_ids(base_path + '.ids') 125 | 126 | # Reads token span indexes. 
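  # Each .span* file holds one line per sentence of space-separated start:end
  # character offsets: .span for the input tokens, .span.const and .span.pred for
  # the constant and predicate candidate tokens (parallel to .lex.const / .lex.pred).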
127 | token_starts, token_ends = read_span(base_path + '.span') 128 | const_token_starts, const_token_ends = read_span(base_path + '.span.const') 129 | pred_token_starts, pred_token_ends = read_span(base_path + '.span.pred') 130 | 131 | for i, graph in enumerate(graphs): 132 | for j, node in enumerate(graph.nodes): 133 | if node.alignment_end < node.alignment: 134 | node.alignment_end = node.alignment # correct negative sized spans 135 | align = min(len(token_starts[i]) -1, node.alignment) 136 | set_span = False 137 | if node.concept.endswith('_CARG'): 138 | concept = node.concept[:node.concept.index('_CARG')] 139 | graphs[i].nodes[j].concept = concept 140 | constant = const_tokens[i][align] 141 | if constant: 142 | graphs[i].nodes[j].constant = '"' + constant + '"' 143 | span_start = const_token_starts[i][align] 144 | span_end = const_token_ends[i][align] 145 | elif is_amr and node.concept == 'CONST': 146 | constant = const_tokens[i][align] 147 | if re.match(r'^\d+(\.\d+)?$', constant): # match numbers 148 | graphs[i].nodes[j].constant = constant 149 | else: 150 | graphs[i].nodes[j].constant = '"' + constant + '"' 151 | span_start = const_token_starts[i][align] 152 | span_end = const_token_ends[i][align] 153 | elif is_amr and node.concept[0] == '"' and node.concept[-1] == '"': 154 | constant = node.concept[1:-1] 155 | if re.match(r'^\d+(\.\d+)?$', constant): # match numbers 156 | graphs[i].nodes[j].constant = constant 157 | else: 158 | graphs[i].nodes[j].constant = '"' + constant + '"' 159 | span_start = 0 160 | span_end = 0 161 | elif is_amr and not is_lexicalized and node.concept.startswith('_'): 162 | if node.concept.startswith('_p_'): 163 | pred = pred_tokens[i][align][1:] 164 | concept = pred + '-' + node.concept[3:] 165 | else: 166 | pred = nom_tokens[i][align][1:] 167 | concept = pred 168 | span_start = pred_token_starts[i][align] 169 | span_end = pred_token_ends[i][align] 170 | graphs[i].nodes[j].concept = concept 171 | elif is_amr and node.concept.startswith('_'): 172 | sense_index = node.concept.index('_', 1) 173 | pred = node.concept[1:sense_index] 174 | suffix = node.concept[sense_index:] 175 | if suffix.startswith('_p_'): 176 | concept = pred + '-' + suffix[3:] 177 | else: 178 | concept = pred 179 | graphs[i].nodes[j].concept = concept 180 | elif ((not is_lexicalized and node.concept.startswith('_')) 181 | or (is_lexicalized and 'u_unknown' in node.concept)): 182 | pred = pred_tokens[i][align] 183 | if pred.startswith('_'): 184 | if node.concept.startswith('_+') or node.concept.startswith('_-'): 185 | concept = pred + node.concept[1:] 186 | else: 187 | concept = pred + node.concept 188 | else: 189 | # Dictionary overrules prediction. 190 | concept = '_' + pred + '_u_unknown' 191 | span_start = pred_token_starts[i][align] 192 | span_end = pred_token_ends[i][align] 193 | graphs[i].nodes[j].concept = concept 194 | if not set_span: 195 | if node.alignment >= len(token_starts[i]): 196 | span_start = token_starts[i][-1] 197 | else: 198 | span_start = token_starts[i][node.alignment] 199 | if node.alignment_end >= len(token_ends[i]): 200 | span_end = token_ends[i][-1] 201 | else: 202 | span_end = token_ends[i][node.alignment_end] 203 | 204 | if node.alignment >= len(pos_tokens[i]): 205 | graphs[i].nodes[j].pos = pos_tokens[i][-1] 206 | graphs[i].nodes[j].ne = ne_tokens[i][-1] 207 | else: 208 | graphs[i].nodes[j].pos = pos_tokens[i][node.alignment] 209 | graphs[i].nodes[j].ne = ne_tokens[i][node.alignment] 210 | # Post-process span for punctuation. 
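      # If the character just past the predicted span is a punctuation mark
      # followed by whitespace, extend the span to cover it: e.g. for
      # "The dog barked. It ran." a node aligned to "barked" (span 8:14)
      # ends up with span 8:15, which also covers the full stop.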
211 | if (span_end + 1 < len(sentences[i]) 212 | and sentences[i][span_end] in string.punctuation 213 | and sentences[i][span_end+1].isspace()): 214 | span_end += 1 215 | graphs[i].nodes[j].ind = str(span_start) + ':' + str(span_end) 216 | 217 | # Write out char-level EDM. 218 | if not is_no_span: 219 | edm_out_file = open(mrs_path + '.edm', 'w') 220 | for graph in graphs: 221 | if graph is None or len(graph.nodes) == 0: 222 | edm_out_file.write('NONE\n') 223 | else: 224 | str_enc = graph.edm_ch_str().encode('utf-8', 'replace') 225 | edm_out_file.write(str_enc + '\n') 226 | edm_out_file.close() 227 | 228 | epe_out_file = open(mrs_path + '.epe', 'w') 229 | offset = 0 230 | for i, graph in enumerate(graphs): 231 | if is_epe: 232 | offset = offsets[i] 233 | if not (graph is None or len(graph.nodes) == 0): 234 | epe_out_file.write((graph.epe_str(i, sentences_unstripped[i], offset) + '\n').encode('utf-8', 'replace')) 235 | if not is_epe: 236 | offset += len(sentences_unstripped[i]) 237 | epe_out_file.close() 238 | 239 | if is_epe: 240 | # write out to seperate files: 241 | type_map = {'train': 'training', 'dev': 'development', 'test': 'evaluation'} 242 | file_id = '' 243 | file_i = 0 244 | out_file = None 245 | for i, graph in enumerate(graphs): 246 | offset = offsets[i] 247 | if not (graph is None or len(graph.nodes) == 0): 248 | if file_ids[i] == file_id: 249 | file_i += 1 250 | else: 251 | file_id = file_ids[i] 252 | file_i = 0 253 | filename = ('epe-results/' + domain + '/' + type_map[set_name] 254 | + '/' + file_id + '.epe') 255 | out_file = open(filename, 'w') 256 | enc_str = (graph.epe_str(file_i, sentences_unstripped[i], offset) + '\n').encode('utf-8', 'replace') 257 | out_file.write(enc_str) 258 | 259 | # Writes out AMR for Smatch evaluation. 260 | amr_out_file = open(mrs_path + '.amr', 'w') 261 | for i, graph in enumerate(graphs): 262 | if graph is None or len(graph.nodes) == 0: 263 | amr_out_file.write('( n1 / _UNK )\n\n') 264 | else: 265 | graph.correct_concept_names() 266 | if is_amr: 267 | graph.restore_op_indexes() 268 | graph.restore_original_constants(graph.root_index, 'focus') 269 | if is_amr and graph.nodes[graph.root_index].constant <> '': 270 | concept = graph.nodes[graph.root_index].constant 271 | if concept[0] == '"' and concept[-1] == '"': 272 | concept = concept[1:-1] 273 | amr_out_file.write('( n1 / ' + concept + ' )\n\n') 274 | else: 275 | amr_out_file.write(graph.amr_graph_str(graph.root_index, 1, is_amr).encode('ascii', 'replace') + '\n\n') 276 | amr_out_file.close() 277 | 278 | 279 | def point_to_linear(amr_path, copy_only, shift_pointer, no_pointer, no_endspan, 280 | with_endspan): 281 | if no_pointer or copy_only: 282 | input_names = ['parse'] 283 | elif with_endspan: 284 | input_names = ['parse', 'att', 'endatt'] 285 | else: 286 | input_names = ['parse', 'att'] 287 | 288 | input_dmrs = {} 289 | for name in input_names: 290 | mrs_file = open(amr_path + '.' 
+ name, 'r') 291 | lines = [] 292 | for line in mrs_file: 293 | lines.append(line.strip().split(' ')) 294 | input_dmrs[name] = lines 295 | 296 | out_file = open(amr_path + '.lin', 'w') 297 | for i, parse_line in enumerate(input_dmrs['parse']): 298 | if copy_only: 299 | out_file.write(parse_line) 300 | continue 301 | dmrs = [] 302 | start_ind = 0 303 | if not no_pointer: 304 | assert len(input_dmrs['att'][i]) == len(parse_line) 305 | if with_endspan: 306 | assert len(input_dmrs['endatt'][i]) == len(parse_line) 307 | for j, parse_symbol in enumerate(parse_line): 308 | if no_pointer: 309 | ind = 0 310 | elif shift_pointer: 311 | if (j + 1 >= len(input_dmrs['att'][i]) or 312 | len(input_dmrs['att'][i][j+1]) == 0): 313 | ind = 0 314 | else: 315 | ind = max(0, int(input_dmrs['att'][i][j+1])) 316 | else: 317 | if len(input_dmrs['att'][i][j]) == 0: 318 | ind = 0 319 | else: 320 | ind = max(0, int(input_dmrs['att'][i][j])) 321 | if parse_symbol == ')' or parse_symbol == 'RE': 322 | dmrs.append(parse_symbol) 323 | if with_endspan: 324 | if len(input_dmrs['endatt'][i][j]) == 0: 325 | end_ind = 0 326 | else: 327 | end_ind = max(0, int(input_dmrs['endatt'][i][j])) 328 | dmrs.append(str(end_ind) + '>') 329 | elif not no_endspan: 330 | dmrs.append(str(ind) + '>') 331 | elif (parse_symbol.startswith(':') or parse_symbol.startswith('LA:') 332 | or parse_symbol.startswith('RA:') or parse_symbol.startswith('UA:') 333 | or parse_symbol.startswith('STACK*') or parse_symbol == 'ROOT'): 334 | dmrs.append(parse_symbol) 335 | else: # shift 336 | dmrs.append('<' + str(ind)) 337 | dmrs.append(parse_symbol) 338 | out_file.write(' '.join(dmrs) + '\n') 339 | out_file.close() 340 | 341 | 342 | if __name__=='__main__': 343 | assert len(sys.argv) >= 5 344 | data_name = sys.argv[1] 345 | set_name = sys.argv[2] 346 | 347 | convert_train = len(sys.argv) >= 6 and sys.argv[5] == '-t' 348 | 349 | is_inorder = '-inorder' in sys.argv[5:] 350 | is_arceager = '-arceager' in sys.argv[5:] 351 | is_arceager_buffer_shift = '-arceagerbuffershift' in sys.argv[5:] 352 | is_lexicalized = '-unlex' not in sys.argv[5:] 353 | is_no_span = '-nospan' in sys.argv[5:] 354 | 355 | copy_only = len(sys.argv) >= 6 and '-copy' in sys.argv[5:] 356 | shift_pointer = len(sys.argv) >= 6 and '-shift' in sys.argv[5:] 357 | no_pointer = len(sys.argv) >= 6 and '-nopointer' in sys.argv[5:] 358 | no_endspan = len(sys.argv) >= 6 and '-noendspan' in sys.argv[5:] 359 | with_endspan = len(sys.argv) >= 6 and '-withendspan' in sys.argv[5:] 360 | 361 | is_amr = '-amr' in sys.argv[5:] 362 | is_epe = '-epe' in sys.argv[5:] 363 | recover_end_spans = False 364 | 365 | amr_dir = data_name + '-working/' 366 | if is_epe: 367 | domain = data_name[4:] 368 | else: 369 | domain = '' 370 | amr_file_name = set_name 371 | amr_base_path = amr_dir + amr_file_name 372 | amr_data_path = data_name + '/' + amr_file_name 373 | 374 | working_dir = sys.argv[3] + '/' 375 | if sys.argv[4] == '-': 376 | amr_path = working_dir + amr_file_name 377 | else: 378 | amr_path = working_dir + amr_file_name + '.' 
+ sys.argv[4] 379 | 380 | point_to_linear(amr_path, copy_only, shift_pointer, no_pointer, no_endspan, with_endspan) 381 | 382 | linear_to_mrs(amr_base_path, amr_data_path, amr_path, set_name, convert_train, 383 | is_inorder, is_arceager, is_arceager_buffer_shift, 384 | is_lexicalized, is_no_span, recover_end_spans, is_amr, 385 | is_epe, domain) 386 | 387 | -------------------------------------------------------------------------------- /mrs/read_mrs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import sys 17 | 18 | import graph as mrs_graph 19 | 20 | if __name__=='__main__': 21 | assert len(sys.argv) >= 4 22 | input_dir = sys.argv[1] + '/' 23 | working_dir = sys.argv[2] + '/' 24 | data_type = sys.argv[3] # dmrs or eds 25 | 26 | set_names = ['train', 'dev', 'test'] 27 | suffix = '.raw' 28 | 29 | for set_name in set_names: 30 | sent_file = open(input_dir + set_name + suffix, 'r') 31 | sentences_raw = [sent for sent in sent_file.read().split('\n')[:-1]] 32 | 33 | # Reads token indexes. 34 | inds_filename = working_dir + set_name + '.span' 35 | inds_file = open(inds_filename, 'r') 36 | token_inds = [] 37 | token_starts = [] 38 | token_ends = [] 39 | for line in inds_file: 40 | inds = line.strip().split(' ') 41 | token_ind = {} 42 | for i, ind in enumerate(inds): 43 | token_ind[ind] = i 44 | token_inds.append(token_ind) 45 | token_starts.append(map(int, [ind.split(':')[0] for ind in inds])) 46 | token_ends.append(map(int, [ind.split(':')[1] for ind in inds])) 47 | 48 | if data_type == 'dmrs': 49 | sdmrs_file = open(input_dir + set_name + '.sdmrs', 'r') 50 | simple_strs = [sent for sent in sdmrs_file.read().split('\n')[:-1]] 51 | carg_file = open(input_dir + set_name + '.carg', 'r') 52 | dmrs_carg_strs = [sent for sent in carg_file.read().split('\n')[:-1]] 53 | else: 54 | eds_file = open(input_dir + set_name + '.eds', 'r') 55 | simple_strs = [sent for sent in eds_file.read().split('\n')[:-1]] 56 | 57 | # Files to write output to. 
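    # One output stream per linearisation variant.  Suffix conventions, as far
    # as they can be inferred from the oracle calls below: .lin = linearised
    # graph, .unlex = unlexicalised predicates, .nospan = spans dropped,
    # .point/.endpoint = start/end alignment pointers only, .ae = arc-eager
    # oracle, .io/.ioc/.ao = in-order / clean in-order / alignment-order
    # traversal, .preds = predicates only, .edm/.edmu = EDM triples in two
    # variants, .amr = bracketed graph strings for Smatch evaluation.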
58 | lin_amr_out_file = open(working_dir + set_name + '.amr.lin', 'w') 59 | lin_dmrs_out_file = open(working_dir + set_name + '.dmrs.lin', 'w') 60 | unlex_dmrs_out_file = open(working_dir + set_name + '.dmrs.unlex.lin', 'w') 61 | nospan_dmrs_out_file = open(working_dir + set_name + '.dmrs.nospan.lin', 'w') 62 | nospan_unlex_dmrs_out_file = open(working_dir + set_name + '.dmrs.nospan.unlex.lin', 'w') 63 | point_dmrs_out_file = open(working_dir + set_name + '.dmrs.point.lin', 'w') 64 | 65 | lin_dmrs_ae_out_file = open(working_dir + set_name + '.dmrs.ae.lin', 'w') 66 | lin_dmrs_ae_io_out_file = open(working_dir + set_name + '.dmrs.ae.io.lin', 'w') 67 | lin_dmrs_unlex_ae_io_out_file = open(working_dir + set_name + '.dmrs.ae.io.unlex.lin', 'w') 68 | lin_dmrs_nospan_unlex_ae_io_out_file = open(working_dir + set_name + '.dmrs.ae.io.nospan.unlex.lin', 'w') 69 | point_dmrs_ae_io_out_file = open(working_dir + set_name + '.dmrs.ae.io.point.lin', 'w') 70 | end_point_dmrs_ae_io_out_file = open(working_dir + set_name + '.dmrs.ae.io.endpoint.lin', 'w') 71 | 72 | lin_dmrs_unlex_ae_ioc_out_file = open(working_dir + set_name + '.dmrs.ae.ioc.unlex.lin', 'w') 73 | lin_dmrs_nospan_unlex_ae_ioc_out_file = open(working_dir + set_name + '.dmrs.ae.ioc.nospan.unlex.lin', 'w') 74 | point_dmrs_ae_ioc_out_file = open(working_dir + set_name + '.dmrs.ae.ioc.point.lin', 'w') 75 | end_point_dmrs_ae_ioc_out_file = open(working_dir + set_name + '.dmrs.ae.ioc.endpoint.lin', 'w') 76 | 77 | lin_dmrs_unlex_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.unlex.lin', 'w') 78 | lin_dmrs_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.lin', 'w') 79 | lin_dmrs_nospan_unlex_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.nospan.unlex.lin', 'w') 80 | lin_dmrs_action_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.action.lin', 'w') 81 | lin_dmrs_concept_unlex_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.concept.unlex.lin', 'w') 82 | lin_dmrs_morph_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.morph.lin', 'w') 83 | point_dmrs_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.point.lin', 'w') 84 | end_point_dmrs_ae_ao_out_file = open(working_dir + set_name + '.dmrs.ae.ao.endpoint.lin', 'w') 85 | 86 | lin_dmrs_unlex_ae_out_file = open(working_dir + set_name + '.dmrs.ae.unlex.lin', 'w') 87 | nospan_dmrs_ae_out_file = open(working_dir + set_name + '.dmrs.ae.nospan.lin', 'w') 88 | nospan_unlex_dmrs_ae_out_file = open(working_dir + set_name + '.dmrs.ae.nospan.unlex.lin', 'w') 89 | point_dmrs_ae_out_file = open(working_dir + set_name + '.dmrs.ae.point.lin', 'w') 90 | 91 | lin_preds_out_file = open(working_dir + set_name + '.preds.lin', 'w') 92 | lin_preds_unlex_out_file = open(working_dir + set_name + '.preds.unlex.lin', 'w') 93 | lin_preds_nospan_out_file = open(working_dir + set_name + '.preds.nospan.lin', 'w') 94 | lin_preds_nospan_unlex_out_file = open(working_dir + set_name + '.preds.nospan.unlex.lin', 'w') 95 | lin_preds_point_out_file = open(working_dir + set_name + '.preds.point.lin', 'w') 96 | 97 | 98 | lin_dmrs_io_out_file = open(working_dir + set_name + '.dmrs.io.lin', 'w') 99 | lin_dmrs_unlex_io_out_file = open(working_dir + set_name + '.dmrs.io.unlex.lin', 'w') 100 | lin_dmrs_ioha_out_file = open(working_dir + set_name + '.dmrs.ioha.lin', 'w') 101 | point_dmrs_ioha_out_file = open(working_dir + set_name + '.dmrs.ioha.ind', 'a') 102 | 103 | edm_out_file = open(working_dir + set_name + '.edm', 'w') 104 | edmu_out_file = open(working_dir + set_name 
+ '.edmu', 'w') 105 | amr_out_file = open(working_dir + set_name + '.amr', 'w') 106 | 107 | for i, simple_str in enumerate(simple_strs): 108 | if data_type == 'dmrs': 109 | graph = mrs_graph.parse_dmrs(simple_str, token_inds[i], 110 | token_starts[i], token_ends[i], 111 | sentences_raw[i]) 112 | 113 | # Adds constants. 114 | carg_list = dmrs_carg_strs[i].split() 115 | carg_inds = carg_list[::3] 116 | carg_preds = carg_list[1::3] 117 | carg_values = carg_list[2::3] 118 | 119 | for ind, pred, const in zip(carg_inds, carg_preds, carg_values): 120 | # Find head node of each CARG. 121 | found = False 122 | if (pred[0] == '"' and pred[-1] == '"') or pred[0] == '_': 123 | continue 124 | for i, node in enumerate(graph.nodes): 125 | if node.ind == ind and node.concept == pred: 126 | graph.nodes[i].constant = const 127 | found = True 128 | if not found: 129 | print "Constant not found:", pred, const 130 | else: 131 | graph = mrs_graph.parse_eds(simple_str, token_inds[i], 132 | token_starts[i], token_ends[i], 133 | sentences_raw[i]) 134 | 135 | 136 | graph.find_span_tree(graph.root_index) 137 | 138 | graph.find_alignment_spans(graph.root_index) 139 | graph.find_span_edge_directions() 140 | 141 | # Validate alignments. 142 | for i, node in enumerate(graph.nodes): 143 | if graph.spanned[i] and node.alignment >= 0: 144 | if node.constant: 145 | graph.nodes[i].is_aligned = True 146 | elif node.concept.startswith('_'): # lexical concepts 147 | graph.nodes[i].is_aligned = True 148 | 149 | # Writes output to files. 150 | lin_amr_out_file.write(':focus( ' 151 | + graph.linear_amr_str(graph.root_index) + ' )\n') 152 | 153 | span_start = str(graph.nodes[graph.root_index].alignment) 154 | span_end = str(graph.nodes[graph.root_index].alignment_end) 155 | lin_dmrs_out_file.write(':/H( <' + span_start + ' ') 156 | lin_dmrs_out_file.write(graph.dmrs_str(graph.root_index) + ' ) ' + span_end + '>\n') 157 | 158 | unlex_dmrs_out_file.write(':/H( <' + span_start + ' ') 159 | unlex_dmrs_out_file.write(graph.dmrs_str(graph.root_index, False) + ' ) ' + span_end + '>\n') 160 | 161 | nospan_dmrs_out_file.write(':/H( ' + graph.dmrs_str(graph.root_index, 162 | True, False) + ' )\n') 163 | nospan_unlex_dmrs_out_file.write(':/H( ' 164 | + graph.dmrs_str(graph.root_index, False, False) + ' )\n') 165 | 166 | point_dmrs_out_file.write(span_start + ' ' + 167 | graph.dmrs_point_str(graph.root_index) + span_end + '\n') 168 | 169 | lin_dmrs_ae_out_file.write(graph.dmrs_arceager_str(graph.root_index, '/H')) 170 | lin_dmrs_ae_out_file.write(') ' + span_end + '>\n') 171 | 172 | lin_dmrs_unlex_ae_out_file.write(graph.dmrs_arceager_str(graph.root_index, 173 | '/H', False)) 174 | lin_dmrs_unlex_ae_out_file.write(') ' + span_end + '>\n') 175 | 176 | ae_io_str, _, _, _, _, _, _ = graph.dmrs_arceager_oracle_str('inorder', True) 177 | lin_dmrs_ae_io_out_file.write(ae_io_str + '\n') 178 | 179 | unlex_ae_io_str, unlex_ae_io_nospan_str, _, _, _, ae_io_point_str, ae_io_end_point_str = graph.dmrs_arceager_oracle_str('inorder', False) 180 | lin_dmrs_unlex_ae_io_out_file.write(unlex_ae_io_str + '\n') 181 | lin_dmrs_nospan_unlex_ae_io_out_file.write(unlex_ae_io_nospan_str + '\n') 182 | point_dmrs_ae_io_out_file.write(ae_io_point_str + '\n') 183 | end_point_dmrs_ae_io_out_file.write(ae_io_end_point_str + '\n') 184 | 185 | unlex_ae_ioc_str, unlex_ae_ioc_nospan_str, _, _, _, ae_ioc_point_str, ae_ioc_end_point_str = graph.dmrs_arceager_oracle_str('cleaninorder', False) 186 | lin_dmrs_unlex_ae_ioc_out_file.write(unlex_ae_ioc_str + '\n') 187 | 
lin_dmrs_nospan_unlex_ae_ioc_out_file.write(unlex_ae_ioc_nospan_str + '\n') 188 | point_dmrs_ae_ioc_out_file.write(ae_ioc_point_str + '\n') 189 | end_point_dmrs_ae_ioc_out_file.write(ae_ioc_end_point_str + '\n') 190 | 191 | unlex_ae_ao_str, unlex_ae_ao_nospan_str, ae_ao_action_str, unlex_ae_ao_concept_str, ae_ao_morph_str, ae_ao_point_str, ae_ao_end_point_str = graph.dmrs_arceager_oracle_str('alignorder', False) 192 | lex_ae_ao_str, _, _, _, _, _, _ = graph.dmrs_arceager_oracle_str('alignorder', True) 193 | lin_dmrs_unlex_ae_ao_out_file.write(unlex_ae_ao_str + '\n') 194 | lin_dmrs_ae_ao_out_file.write(lex_ae_ao_str + '\n') 195 | lin_dmrs_nospan_unlex_ae_ao_out_file.write(unlex_ae_ao_nospan_str + '\n') 196 | lin_dmrs_concept_unlex_ae_ao_out_file.write(unlex_ae_ao_concept_str + '\n') 197 | lin_dmrs_action_ae_ao_out_file.write(ae_ao_action_str + '\n') 198 | lin_dmrs_morph_ae_ao_out_file.write(ae_ao_morph_str + '\n') 199 | point_dmrs_ae_ao_out_file.write(ae_ao_point_str + '\n') 200 | end_point_dmrs_ae_ao_out_file.write(ae_ao_end_point_str + '\n') 201 | 202 | nospan_dmrs_ae_out_file.write(graph.dmrs_arceager_str(graph.root_index, 203 | '/H', True, False) + ')\n') 204 | nospan_unlex_dmrs_ae_out_file.write(graph.dmrs_arceager_str( 205 | graph.root_index, '/H', False, False) + ')\n') 206 | 207 | point_dmrs_ae_out_file.write(graph.dmrs_arceager_point_str( 208 | graph.root_index, True) + span_end + '\n') 209 | 210 | io_lin_str, _ = graph.dmrs_inorder_str(graph.root_index) 211 | lin_dmrs_io_out_file.write(io_lin_str + '\n') 212 | io_unlex_lin_str, _ = graph.dmrs_inorder_str(graph.root_index, False) 213 | lin_dmrs_unlex_io_out_file.write(io_unlex_lin_str + '\n') 214 | 215 | ioha_lin_str, ioha_ind_str, _ = graph.dmrs_inorder_ha_str(graph.root_index, '0') 216 | lin_dmrs_ioha_out_file.write(ioha_lin_str + '\n') 217 | point_dmrs_ioha_out_file.write(ioha_ind_str + '\n') 218 | 219 | preds, preds_nospan, preds_point = graph.ordered_predicates_str() 220 | 221 | lin_preds_out_file.write(preds + '\n') 222 | lin_preds_nospan_out_file.write(preds_nospan + '\n') 223 | lin_preds_point_out_file.write(preds_point + '\n') 224 | 225 | preds_unlex, preds_nospan_unlex, _ = graph.ordered_predicates_str(False) 226 | 227 | lin_preds_unlex_out_file.write(preds_unlex + '\n') 228 | lin_preds_nospan_unlex_out_file.write(preds_nospan_unlex + '\n') 229 | 230 | edm_out_file.write(graph.edm_ch_str(True) + '\n') 231 | edmu_out_file.write(graph.edm_ch_str(False) + '\n') 232 | 233 | if len(graph.nodes) == 0: 234 | amr_out_file.write('( n1 / _UNK )\n\n') 235 | else: 236 | graph.correct_constants() 237 | graph.correct_concept_names() 238 | graph.correct_node_names() 239 | amr_out_file.write(graph.amr_graph_str(graph.root_index, 1) + '\n\n') 240 | 241 | -------------------------------------------------------------------------------- /mrs/sentence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import re 17 | import util 18 | from nltk.corpus import wordnet as wn 19 | 20 | import util as mrs_util 21 | 22 | class Token(): 23 | def __init__(self, word, original_word, pos, constant_label, is_ne, is_timex, 24 | ne_tag, normalized_ne_tag, edge = -1, relation = '', timex_attr = dict(), 25 | char_start=-1, char_end=-1): 26 | self.word = word 27 | self.original_word = original_word 28 | if ne_tag == '': 29 | self.pred_lexeme = original_word.lower() + u'/' + pos.lower() 30 | else: 31 | self.pred_lexeme = original_word + u'/' + pos.lower() 32 | self.const_lexeme = original_word 33 | self.wiki_lexeme = original_word 34 | self.nom_lexeme = original_word 35 | self.verb_lexeme = '' 36 | self.noun_lexeme = '' 37 | self.pos = pos 38 | self.is_const = False 39 | self.is_pred = False 40 | self.is_nom = False 41 | self.is_wiki = False 42 | self.is_ne = is_ne 43 | self.is_timex = is_timex 44 | self.ne_tag = ne_tag 45 | self.normalized_ne_tag = normalized_ne_tag 46 | self.constant_label = constant_label 47 | self.timex_attr = timex_attr 48 | self.edge = edge 49 | self.relation = relation 50 | self.char_start = char_start 51 | self.char_end = char_end 52 | self.pred_char_start = char_start 53 | self.pred_char_end = char_end 54 | self.const_char_start = char_start 55 | self.const_char_end = char_end 56 | 57 | self.children = [] 58 | 59 | '''Format:Text, CharacterOffsetBegin, CharacterOffsetEnd, PartOfSpeech, Lemma, NamedEntityTag, ''' 60 | @classmethod 61 | def parse_stanford_line(cls, line, name_normalize_dict): 62 | items = [] 63 | timex_attr = dict() 64 | is_timex = False 65 | attr_re = re.compile(r'(\w\w*)=\"(\w\w*)\"') 66 | for item in line.split(' '): 67 | if is_timex: 68 | m = attr_re.match(item) 69 | if m: 70 | timex_attr[m.group(1)] = m.group(2) 71 | else: 72 | break 73 | elif len(item) >= 3 and '=' in item[1:-1]: 74 | if 'Timex=' in item: 75 | is_timex = True 76 | else: 77 | items.append(item[item.index('=') + 1:]) 78 | else: 79 | items[-1] += '_' + item # no spaces in item 80 | word = items[0] 81 | char_start = int(items[1]) 82 | char_end = int(items[2]) 83 | pos = items[3] 84 | lemma = items[4] 85 | if name_normalize_dict.has_key(lemma): 86 | lemma = name_normalize_dict[lemma] 87 | 88 | is_ne = False 89 | ne_tag = '' 90 | normalized_ne_tag = '' 91 | if len(items) > 5 and items[5] <> 'O': 92 | is_ne = True 93 | ne_tag = items[5] 94 | if len(items) > 6: 95 | normalized_ne_tag = ' '.join(items[6].split('_')) 96 | return cls(lemma, word, pos, '', is_ne, is_timex, ne_tag, normalized_ne_tag, 97 | timex_attr=timex_attr, char_start=char_start, char_end=char_end) 98 | 99 | '''Format: Index, Word, Pos, constant_label, is_ne, is_timex, ne_tag, normalized_ne_tag.''' 100 | @classmethod 101 | def parse_conll_line(cls, conll_line): 102 | word = conll_line[1] 103 | original_word = conll_line[2] 104 | pos = conll_line[3] 105 | constant_label = '' if conll_line[4] == '_' else conll_line[4] 106 | is_ne = conll_line[5] == '1' 107 | is_timex = conll_line[6] == '1' 108 | ne_tag = '' if conll_line[7]=='_' else conll_line[7] 109 | normalized_ne_tag = '' if conll_line[8] == '_' else conll_line[8] 110 | return cls(word, original_word, pos, constant_label, is_ne, is_timex, 111 | ne_tag, normalized_ne_tag) 112 | 113 | def conll_line_str(self, i): 114 | conll_str = str(i+1) + '\t' + self.word + '\t' + 
self.original_word + '\t' 115 | conll_str += self.pos + '\t' 116 | conll_str += self.constant_label + '\t' if self.constant_label else '_\t' 117 | conll_str += '1\t' if self.is_ne else '0\t' 118 | conll_str += '1\t' if self.is_timex else '0\t' 119 | conll_str += self.ne_tag + '\t' if self.ne_tag else '_\t' 120 | if self.ne_tag and self.normalized_ne_tag: 121 | conll_str += self.normalized_ne_tag 122 | else: 123 | conll_str += '_' 124 | return conll_str + '\n' 125 | 126 | def col_line_str(self, i): 127 | col_str = str(i+1) + '\t' 128 | word = self.word.lower() 129 | if self.original_word == '_': 130 | col_str += '_\t' + self.word 131 | else: 132 | col_str += word + '\t' + self.pos 133 | return col_str + '\n' 134 | 135 | '''Format: Index, Word, Lemma, pos, pos, _, head_index, relation, _, _.''' 136 | @classmethod 137 | def parse_conll_dep_line(cls, conll_line): 138 | word = conll_line[2] 139 | original_word = conll_line[1] 140 | pos = conll_line[3] 141 | if conll_line[6] == '0': 142 | head_index = -1 143 | relation = '' 144 | else: 145 | head_index = int(conll_line[6]) - 1 146 | relation = conll_line[7] 147 | 148 | return cls(word, original_word, pos, '', False, False, '', '', head_index, 149 | relation) 150 | 151 | def conll_dep_line_str(self, i): 152 | # Assumes no dependencies. 153 | conll_str = str(i+1) + '\t' + self.word + '\t' 154 | conll_str += self.original_word if self.original_word else '_' 155 | conll_str += '\t' + self.pos + '\t' + self.pos + '\t_\t' 156 | if self.edge >= -1: 157 | conll_str += str(self.edge+1) + '\t' 158 | else: 159 | conll_str += '_\t' 160 | conll_str += self.relation if self.relation else '_' 161 | conll_str += '\t_\t_' 162 | return conll_str + '\n' 163 | 164 | def reset_char_spans(self, char_start, char_end): 165 | self.char_start = char_start 166 | self.char_end = char_end 167 | self.pred_char_start = char_start 168 | self.pred_char_end = char_end 169 | self.const_char_start = char_start 170 | self.const_char_end = char_end 171 | 172 | def find_wordnet_lemmas(self): 173 | ptb_tag_preffixes = ['J', 'R', 'V', 'N'] 174 | wordnet_tags = ['a', 'r', 'v', 'n'] 175 | wordnet_tag = [] 176 | for i, prefix in enumerate(ptb_tag_preffixes): 177 | if self.pos.startswith(prefix): 178 | wordnet_tag.append(wordnet_tags[i]) 179 | if prefix == 'J': 180 | wordnet_tag.append('s') 181 | derived_nouns = set() 182 | derived_verbs = set() 183 | for tag in wordnet_tag: 184 | for synset in wn.synsets(self.word, pos=tag): 185 | for wn_lemma in synset.lemmas(): 186 | for form in wn_lemma.derivationally_related_forms(): 187 | word = form.name() 188 | pos = form.synset().pos() 189 | if pos == 'n': 190 | derived_nouns.add(word) 191 | elif pos == 'v': 192 | derived_verbs.add(word) 193 | if derived_nouns and not self.pos.startswith('N'): 194 | derived_list = list(derived_nouns) 195 | derived_match = [mrs_util.prefix_sim_long(self.word.lower(), deriv) 196 | for deriv in derived_nouns] 197 | self.noun_lexeme = derived_list[derived_match.index(max(derived_match))] 198 | if derived_verbs and not self.pos.startswith('V'): 199 | derived_list = list(derived_verbs) 200 | derived_match = [mrs_util.prefix_sim_long(self.word.lower(), deriv) 201 | for deriv in derived_verbs] 202 | self.verb_lexeme = derived_list[derived_match.index(max(derived_match))] 203 | 204 | 205 | class Sentence(): 206 | def __init__(self, sentence, index_map=[], sent_ind=0): 207 | self.sentence = sentence # tokens 208 | self.const_vars = util.ConstantVars() 209 | self.sent_ind = sent_ind 210 | self.root_index = -1 211 | 
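    # index_map[k] is assumed to give the position in the original token
    # sequence of the k-th token of this sentence; it is only non-trivial for
    # sentences built by const_variable_sentence(), where multi-token
    # constants collapse to a single variable token.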
self.index_map = index_map 212 | 213 | def word_at(self, i): 214 | return self.sentence[i].word 215 | 216 | @classmethod 217 | def parse_conll(cls, sent_conll): 218 | sentence = [] 219 | for token_line in sent_conll: 220 | token = Token.parse_conll_line(token_line) 221 | sentence.append(token) 222 | return cls(sentence) 223 | 224 | @classmethod 225 | def parse_conll_dep(cls, sent_conll): 226 | sentence = [] 227 | for token_line in sent_conll: 228 | token = Token.parse_conll_dep_line(token_line) 229 | sentence.append(token) 230 | return cls(sentence) 231 | 232 | def original_sentence_str(self): 233 | words = [token.original_word for token in self.sentence] 234 | return ' '.join(words) + '\n' 235 | 236 | def pred_lexeme_str(self): 237 | words = [token.pred_lexeme for token in self.sentence] 238 | return ' '.join(words) + '\n' 239 | 240 | def pred_verb_lexeme_str(self): 241 | words = [] 242 | for token in self.sentence: 243 | if token.is_pred: 244 | word = token.pred_lexeme 245 | elif token.verb_lexeme: 246 | word = '_' + token.verb_lexeme 247 | else: 248 | word = token.pred_lexeme 249 | words.append(word) 250 | return ' '.join(words) + '\n' 251 | 252 | def nom_lexeme_str(self): 253 | words = [] 254 | for token in self.sentence: 255 | if token.is_nom: 256 | word = token.nom_lexeme 257 | elif token.noun_lexeme: 258 | word = '_' + token.noun_lexeme 259 | else: 260 | word = token.nom_lexeme 261 | words.append(word) 262 | return ' '.join(words) + '\n' 263 | 264 | def const_lexeme_str(self): 265 | conc = u'' 266 | for token in self.sentence: 267 | try: 268 | conc += u' ' + token.const_lexeme 269 | except UnicodeDecodeError: 270 | print 'Cannot write: ' + token.const_lexeme 271 | conc += u' X' 272 | return conc[1:] + u'\n' 273 | 274 | def wiki_lexeme_str(self): 275 | conc = u'' 276 | for token in self.sentence: 277 | #if token.is_wiki: 278 | # print token.wiki_lexeme 279 | try: 280 | conc += u' ' + token.wiki_lexeme 281 | except UnicodeDecodeError: 282 | print 'Cannot write: ' + token.wiki_lexeme 283 | conc += u' X' 284 | return conc[1:] + u'\n' 285 | 286 | 287 | def ch_span_str(self): 288 | words = [str(token.char_start) + ':' + str(token.char_end) 289 | for token in self.sentence] 290 | return ' '.join(words) + '\n' 291 | 292 | def pred_ch_span_str(self): 293 | words = [str(token.pred_char_start) + ':' + str(token.pred_char_end) 294 | for token in self.sentence] 295 | return ' '.join(words) + '\n' 296 | 297 | def const_ch_span_str(self): 298 | words = [str(token.const_char_start) + ':' + str(token.const_char_end) 299 | for token in self.sentence] 300 | return ' '.join(words) + '\n' 301 | 302 | def pos_str(self): 303 | words = [token.pos for token in self.sentence] 304 | words = ['_' + word if token.is_pred else word 305 | for word, token in zip(words, self.sentence)] 306 | return ' '.join(words) + '\n' 307 | 308 | def ne_tag_str(self): 309 | words = ['O' if token.ne_tag == '' else token.ne_tag 310 | for token in self.sentence] 311 | words = [word + '_C' if token.is_const else word 312 | for word, token in zip(words, self.sentence)] 313 | return ' '.join(words) + '\n' 314 | 315 | def raw_sentence_str(self, to_lower): 316 | if to_lower: 317 | words = [token.word.lower() for token in self.sentence] 318 | else: 319 | words = [token.word for token in self.sentence] 320 | return ' '.join(words) + '\n' 321 | 322 | def const_variable_sentence(self, varnames=[]): 323 | var_sent = [] 324 | const_vars = util.ConstantVars() 325 | index_map = [] 326 | 327 | constant_state = False 328 | constant_type = '' 329 
| var_value = '' 330 | for i, token in enumerate(self.sentence): 331 | if token.constant_label: 332 | if constant_state and token.constant_label == 'I': 333 | if constant_type == 'name': 334 | const_vars.names[-1].append(token.word) 335 | if constant_type == 'number': 336 | const_vars.numbers[-1].append(token.word) 337 | var_value += '_' + token.word 338 | else: 339 | if constant_state: 340 | var_sent[-1].normalized_ne_tag = var_value 341 | constant_state = True 342 | constant_type = token.constant_label 343 | var_name = '' 344 | var_value = '' 345 | if constant_type == 'name': 346 | const_vars.names.append([token.word]) 347 | var_name = 'name' + str(len(const_vars.names)) 348 | elif constant_type == 'number': 349 | const_vars.numbers.append([token.word]) 350 | var_name = 'number' + str(len(const_vars.numbers)) 351 | if var_name and varnames and varnames[i]: 352 | var_name = varnames[i] 353 | if var_name: 354 | index_map.append(i) 355 | var_sent.append(Token(var_name, '_', token.pos, constant_type, 356 | False, False, token.ne_tag, '')) 357 | var_value = token.word 358 | else: 359 | print 'unknown constant', constant_type 360 | else: 361 | if constant_state: 362 | constant_state = False 363 | constant_type = '' 364 | var_sent[-1].normalized_ne_tag = var_value 365 | var_value = '' 366 | index_map.append(i) 367 | var_sent.append(Token(token.word, token.original_word, token.pos, '', 368 | False, False, '', '')) 369 | return Sentence(var_sent, index_map) 370 | 371 | def record_children(self): 372 | total_children = 0 373 | for i, token in enumerate(self.sentence): 374 | self.sentence[i].children = [] 375 | for i, token in enumerate(self.sentence): 376 | head_index = token.edge 377 | if head_index == -1: 378 | self.root_index = i 379 | else: 380 | self.sentence[head_index].children.append(i) 381 | total_children += 1 382 | for i, token in enumerate(self.sentence): 383 | self.sentence[i].children.sort() 384 | 385 | def sentence_str(self, remove_punct): 386 | words = [token.word for token in self.sentence] 387 | s = '' 388 | indexes = [] 389 | if remove_punct: 390 | for i, w in enumerate(words): 391 | if not util.is_punct(w): 392 | s += w + ' ' 393 | if w == 'NULL': 394 | i = -1 395 | indexes.append(i) 396 | s = s[:-1] + '\n' 397 | else: 398 | s = ' '.join(words) + '\n' 399 | indexes = range(len(words)) 400 | s_ind = ' '.join(map(str, indexes)) + '\n' 401 | return s.lower(), s_ind 402 | 403 | def linear_dep_str(self, i): 404 | # Pre-order traversal of graph. 405 | # No constants. 
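    # Produces a bracketed string such as "( barked :nsubj ( dog :det ( the ) ) )"
    # (illustrative example only; the relation labels come from the dependency parse).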
406 | graph_str = '( ' + self.sentence[i].word 407 | for child_index in self.sentence[i].children: 408 | graph_str += ' :' + self.sentence[child_index].relation + ' ' \ 409 | + self.linear_dep_str(child_index) 410 | return graph_str + ' )' 411 | 412 | def read_conll_sentences(sent_file_name): 413 | sent_file = open(sent_file_name, 'r') 414 | sent_conll_in = [[line.split('\t') for line in sent.split('\n')] 415 | for sent in sent_file.read().split('\n\n')[:-1]] 416 | sentences = [] 417 | for sent_conll in sent_conll_in: 418 | sentences.append(Sentence.parse_conll(sent_conll)) 419 | return sentences 420 | 421 | def read_conll_dep_sentences(sent_file_name): 422 | sent_file = open(sent_file_name, 'r') 423 | sent_conll_in = [[line.split('\t') for line in sent.split('\n')] 424 | for sent in sent_file.read().split('\n\n')[:-1]] 425 | sentences = [] 426 | for sent_conll in sent_conll_in: 427 | sentences.append(Sentence.parse_conll_dep(sent_conll)) 428 | return sentences 429 | 430 | def read_raw_sentences(sent_file_name): 431 | sent_file = open(sent_file_name, 'r') 432 | sent_in = [sent.split(' ') for sent in sent_file.read().split('\n')[:-1]] 433 | return sent_in 434 | 435 | -------------------------------------------------------------------------------- /mrs/stanford_to_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import re 18 | import codecs 19 | import string 20 | import sys 21 | 22 | import sentence as asent 23 | import util as mrs_util 24 | 25 | def convert_number(num_str, has_unit): 26 | try: 27 | unit_ind = 1 if num_str[0] in ['>', '<', '~'] else 0 28 | if len(num_str) > 1 and num_str[1] == '=': 29 | unit_ind = 2 30 | if has_unit: 31 | num_str = num_str[unit_ind+1:] 32 | else: 33 | num_str = num_str[unit_ind:] 34 | 35 | # For now, ignore first value in ranges. 
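  # Examples of the resulting behaviour:
  #   '2-3'   -> '3'   (only the upper bound of a range is kept)
  #   '>=$50' -> '50'  (comparison prefix and unit stripped when has_unit is True)
  #   '2.5'   -> '2'   (the int() cast below truncates fractional values)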
36 | if '-' in num_str: 37 | num_str = num_str[(num_str.index('-') + 1):] 38 | value = float(num_str) 39 | try: 40 | value = int(value) 41 | except ValueError: 42 | value = value 43 | return str(value) 44 | except ValueError: 45 | print 'Cannot parse:', num_str 46 | return num_str 47 | 48 | def convert_period(period_str): 49 | rate_re = re.compile(r'P(\d\d*)([A-Z])') 50 | time_re = re.compile(r'PT(\d\d*)([A-Z])') 51 | period_m = rate_re.match(period_str) 52 | time_m = time_re.match(period_str) 53 | if period_m: 54 | period = int(period_m.group(1)) 55 | unit = period_m.group(2) 56 | if unit == 'Y': 57 | unit = 'year' 58 | elif unit == 'M': 59 | unit = 'month' 60 | elif unit == 'W': 61 | unit = 'week' 62 | elif unit == 'D': 63 | unit = 'day' 64 | return period, unit 65 | elif time_m: 66 | period = int(time_m.group(1)) 67 | unit = time_m.group(2) 68 | if unit == 'H': 69 | unit = 'hour' 70 | elif unit == 'M': 71 | unit = 'minute' 72 | elif unit == 'S': 73 | unit = 'second' 74 | return period, unit 75 | else: 76 | return 0, '' 77 | 78 | def read_sentences_normalize_ne(stanford_file_name): 79 | stanford_file = codecs.open(stanford_file_name, 'r', 'utf-8') 80 | 81 | sentences = [] 82 | tokens = [] 83 | 84 | token_alignments = [] 85 | text_line = '' 86 | 87 | state = False 88 | ne_state = False 89 | money_state = False 90 | percent_state = False 91 | number_state = False 92 | ordinal_state = False 93 | time_state = False 94 | date_state = False 95 | duration_state = False 96 | set_state = False 97 | last_ne_tag = '' 98 | token_counter = 0 99 | 100 | date_re = re.compile(r'^(\d\d\d\d|XXXX)-(\d\d|XX)-(\d\d|XX)$') 101 | date2_re = re.compile(r'^(\d\d\d\d|XXXX)-(\d\d|XX)$') 102 | date3_re = re.compile(r'^(\d\d\d\d|XXXX)$') 103 | 104 | for line in stanford_file: 105 | if line.startswith('Sentence #'): 106 | if state: 107 | sentences.append(asent.Sentence(tokens, token_alignments)) 108 | tokens = [] 109 | token_alignments = [] 110 | state = False 111 | ne_state = False 112 | money_state = False 113 | percent_state = False 114 | number_state = False 115 | ordinal_state = False 116 | time_state = False 117 | date_state = False 118 | duration_state = False 119 | set_state = False 120 | last_ne_tag = '' 121 | token_counter = 0 122 | elif line.startswith('[Text=') and line[-2]==']': 123 | token = asent.Token.parse_stanford_line(line[1:-2], {}) 124 | #For LOCATION, PERSON, ORGANIZATION, MISC. 125 | if ne_state and not (token.is_ne and token.ne_tag == last_ne_tag): 126 | ne_state = False 127 | if not ne_state and token.is_ne and token.ne_tag in \ 128 | ['LOCATION', 'PERSON', 'ORGANIZATION', 'MISC']: 129 | ne_state = True 130 | # Appends to the front. 131 | last_ne_tag = token.ne_tag 132 | token.constant_label = 'name' 133 | token.const_lexeme = token.word 134 | # For MONEY: 135 | if money_state and not (token.is_ne and token.ne_tag == 'MONEY'): 136 | money_state = False 137 | elif not money_state and token.is_ne and token.ne_tag == 'MONEY': 138 | money_state = True 139 | money_str = token.normalized_ne_tag 140 | if len(money_str) == 0: 141 | # Not treated as money. 142 | token.is_ne = False 143 | token.ne_tag = '' 144 | money_state = False 145 | elif len(money_str) > 1: # length 1 is for units 146 | unit_ind = 1 if money_str[0] in ['>', '<', '~'] else 0 147 | if money_str[1] == '=': 148 | unit_ind = 2 149 | token.const_lexeme = convert_number(money_str, True) 150 | # Percentage. 
151 | if percent_state and not (token.is_ne and token.ne_tag == 'PERCENT'): 152 | percent_state = False 153 | elif not percent_state and token.is_ne and token.ne_tag == 'PERCENT': 154 | percent_state = True 155 | percent_str = token.normalized_ne_tag 156 | if len(percent_str) > 1: 157 | token.normalized_ne_tag = convert_number(percent_str, True) 158 | if number_state and not (token.is_ne and token.ne_tag == 'NUMBER'): 159 | number_state = False 160 | elif not number_state and token.is_ne and token.ne_tag == 'NUMBER': 161 | number_state = True 162 | number_str = token.normalized_ne_tag 163 | if len(number_str) == 0: 164 | number_state = False 165 | token.is_ne = False 166 | token.ne_tag = '' 167 | else: 168 | token.const_lexeme = convert_number(number_str, False) 169 | if ordinal_state and not (token.is_ne and token.ne_tag == 'ORDINAL'): 170 | ordinal_state = False 171 | elif not ordinal_state and token.is_ne and token.ne_tag == 'ORDINAL': 172 | ordinal_state = True 173 | number_str = token.normalized_ne_tag 174 | if len(number_str) == 0: 175 | number_state = False 176 | token.is_ne = False 177 | token.ne_tag = '' 178 | else: 179 | token.const_lexeme = convert_number(number_str, False) 180 | if time_state and not (token.is_timex 181 | and token.ne_tag in ['DATE', 'TIME']): 182 | time_state = False 183 | elif not time_state and (token.is_timex 184 | and token.ne_tag in ['DATE', 'TIME']): 185 | # The same date and time expression and contain both DATE and TIME. 186 | time_state = True 187 | if time_state and not date_state and token.ne_tag == 'DATE': 188 | # Only match pure date expressions 189 | # - cannot convert compound expressions cleanly enough. 190 | date_str = token.normalized_ne_tag 191 | if len(date_str.split()) == 1: 192 | # Strip time from string. 193 | if 'T' in date_str: 194 | date_str = date_str[:date_str.index('T')] 195 | if re.match(r'^\d\d\dX$', date_str): 196 | date_str = date_str[:3] + '0' 197 | if re.match(r'^\d\dXX$', date_str): 198 | date_str = date_str[:2] + '00' 199 | m = date_re.match(date_str) 200 | m2 = date2_re.match(date_str) 201 | m3 = date3_re.match(date_str) 202 | if m or m2 or m3: 203 | date_state = True 204 | if m: 205 | date_list = list(m.groups()) 206 | elif m2: 207 | date_list = list(m2.groups()) 208 | elif m3: 209 | date_list = list(m3.groups()) 210 | date_list = filter(lambda d: 'X' not in d, date_list) 211 | date_list = [convert_number(date, False) for date in date_list] 212 | if date_list: 213 | token.const_lexeme = date_list[0] 214 | #else don't handle as a date. 
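        # The date block above keeps only components without 'X' placeholders
        # and uses the first remaining one as the constant, e.g. a normalised
        # value of '1987-06-XX' yields const_lexeme '1987'.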
215 | if date_state and token.ne_tag <> 'DATE': 216 | date_state = False 217 | # For Duration: 218 | if duration_state and not (token.is_timex and token.ne_tag == 'DURATION'): 219 | duration_state = False 220 | elif not duration_state and token.is_timex and token.ne_tag == 'DURATION': 221 | duration_state = True 222 | time_str = token.normalized_ne_tag 223 | period, unit = convert_period(time_str) 224 | if period == 0: 225 | duration_state = False 226 | else: 227 | token.const_lexeme = str(period) 228 | token.ne_tag += '_' + unit 229 | # For SET: 230 | if set_state and not (token.is_timex and token.ne_tag == 'SET'): 231 | set_state = False 232 | elif not set_state and token.is_timex and token.ne_tag == 'SET': 233 | set_state = True 234 | freq = 1 235 | period = 0 236 | unit = '' 237 | if token.timex_attr.has_key('freq'): 238 | rate_re = re.compile(r'P(\d\d*)([A-Z])') 239 | freq_m = rate_re.match(token.timex_attr['freq']) 240 | freq = int(freq_m.group(1)) 241 | if token.timex_attr.has_key('periodicity'): 242 | period, unit = convert_period(token.timex_attr['periodicity']) 243 | if period == 0: 244 | set_state = False 245 | token.ne_tag = '' 246 | else: 247 | if freq > 1: 248 | token_ne_tag += '_rate' 249 | token.const_lexeme = str(period) 250 | token.ne_tag += '_temporal_' + unit 251 | # Identify numbers: 252 | if re.match(r'^[+-]?\d+(\.\d+)?$', token.word): 253 | if token.const_lexeme == '': 254 | token.const_lexeme = convert_number(token.word, False) 255 | token.constant_label = 'number' 256 | token.pred_lexeme = token.word 257 | tokens.append(token) 258 | state = True 259 | if state: 260 | sentences.append(asent.Sentence(tokens)) 261 | return sentences 262 | 263 | 264 | def read_sentences(stanford_file_name, file_id): 265 | stanford_file = codecs.open(stanford_file_name, 'r', 'utf-8') 266 | 267 | sentences = [] 268 | raw_sentences = [] 269 | tokens = [] 270 | 271 | text_line = '' 272 | state_line = '' 273 | sent_offset = 0 274 | state = False 275 | state1 = False 276 | 277 | for line in stanford_file: 278 | if line.startswith('Sentence #'): 279 | if state: 280 | sentences.append(asent.Sentence(tokens)) 281 | sentences[-1].offset = sent_offset 282 | sentences[-1].raw_txt = text_line 283 | sentences[-1].file_id = file_id 284 | text_line = '' 285 | state_line = '' 286 | tokens = [] 287 | state = False 288 | state1 = False 289 | elif len(line) > 1 and line[-2]==']' and (state or line.startswith('[Text=')): 290 | if state_line: 291 | token = asent.Token.parse_stanford_line(state_line + ' ' + line[:-2], {}) 292 | else: 293 | token = asent.Token.parse_stanford_line(line[1:-2], {}) 294 | if not state1: 295 | sent_offset = token.char_start 296 | ind_start = token.char_start - sent_offset 297 | ind_end = token.char_end - sent_offset 298 | token.reset_char_spans(ind_start, ind_end) 299 | 300 | word = token.original_word 301 | word = word.replace(u"\u00A0", "_") 302 | if '_' in word: 303 | split_word = word.split('_') 304 | split_inds = filter(lambda x: word[x] == '_', 305 | range(len(word))) 306 | first_word = word[:split_inds[0]] 307 | token.original_word = first_word 308 | token.word = first_word 309 | if normalize_ne: 310 | token.pred_lexeme = first_word.lower() 311 | else: 312 | token.pred_lexeme = first_word.lower() + u'/' + token.pos.lower() 313 | token.const_lexeme = first_word 314 | token.char_end = token.char_start + split_inds[0] 315 | tokens.append(token) 316 | for j, w in enumerate(split_word[1:]): 317 | char_start = token.char_start + split_inds[j] + 1 318 | if j + 1 < len(split_inds): 
319 | char_end = token.char_start + split_inds[j+1] 320 | else: 321 | char_end = token.char_start + len(word) 322 | new_token = asent.Token(w, w, token.pos, token.constant_label, 323 | token.is_ne, token.is_timex, token.ne_tag, 324 | token.normalized_ne_tag, char_start=char_start, char_end=char_end) 325 | tokens.append(new_token) 326 | else: 327 | tokens.append(token) 328 | state = True 329 | state1 = True 330 | elif line.startswith('[Text='): 331 | state_line = line[1:].strip() 332 | state = True 333 | else: #if line.strip(): 334 | if state: 335 | state_line += ' ' + line.strip() 336 | else: 337 | text_line += line.replace('\n', ' ') 338 | if state: 339 | sentences.append(asent.Sentence(tokens)) 340 | sentences[-1].offset = sent_offset 341 | sentences[-1].raw_txt = text_line 342 | sentences[-1].file_id = file_id 343 | return sentences 344 | 345 | 346 | def process_stanford(input_dir, working_dir, erg_dir, set_name, 347 | use_pred_lexicon=True, use_const_lexicon=True, normalize_ne=False, 348 | read_epe=False): 349 | nom_map = {} 350 | wiki_map = {} 351 | if use_pred_lexicon: 352 | pred_map = mrs_util.read_lexicon(erg_dir + 'predicates.lexicon') 353 | if normalize_ne: 354 | nom_map = mrs_util.read_lexicon(erg_dir + 'nominals.lexicon') 355 | else: 356 | pred_map = {} 357 | if use_const_lexicon: 358 | const_map = mrs_util.read_lexicon(erg_dir + 'constants.lexicon') 359 | if normalize_ne: 360 | wiki_map = mrs_util.read_lexicon(erg_dir + 'wiki.lexicon') 361 | else: 362 | const_map = {} 363 | 364 | if read_epe: 365 | file_ids = [] 366 | in_type = input_dir[4:-1] 367 | file_list = open(in_type + '.' + set_name + '.list', 'r').read().split('\n')[:-1] 368 | file_ids = [name[name.rindex('/')+1:] for name in file_list] 369 | sentences = [] 370 | for file_id in file_ids: 371 | sentences.extend(read_sentences( 372 | (working_dir + '/raw-' + set_name + '/' + file_id + '.out'), 373 | file_id)) 374 | else: 375 | suffix = '.raw' 376 | if normalize_ne: 377 | sentences = read_sentences_normalize_ne((working_dir + set_name + suffix + '.out')) 378 | else: 379 | sentences = read_sentences((working_dir + set_name + suffix + '.out'), '0') 380 | 381 | max_token_span_length = 5 382 | for i, sent in enumerate(sentences): 383 | for j, token in enumerate(sent.sentence): 384 | if normalize_ne: 385 | sentences[i].sentence[j].find_wordnet_lemmas() 386 | 387 | # Matches lexemes. 
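      # Lexicon lookups below try the original surface form first, then its
      # lowercased form, then the lemma; a hit in constants.lexicon marks the
      # token as a constant candidate and a hit in predicates.lexicon marks it
      # as a predicate candidate.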
388 | lexeme = '' 389 | if token.original_word in const_map: 390 | lexeme = const_map[token.original_word] 391 | elif token.original_word.lower() in const_map: 392 | lexeme = const_map[token.original_word.lower()] 393 | elif token.word in const_map: 394 | lexeme = const_map[token.word] 395 | if lexeme <> '': 396 | sentences[i].sentence[j].const_lexeme = lexeme 397 | sentences[i].sentence[j].is_const = True 398 | 399 | lexeme = '' 400 | if token.original_word in pred_map: 401 | lexeme = pred_map[token.original_word] 402 | elif token.original_word.lower() in pred_map: 403 | lexeme = pred_map[token.original_word.lower()] 404 | elif token.word in pred_map: # lemma 405 | lexeme = pred_map[token.word] 406 | if normalize_ne: 407 | nom_lexeme = '' 408 | if token.original_word in nom_map: 409 | nom_lexeme = nom_map[token.original_word] 410 | elif token.original_word.lower() in nom_map: 411 | nom_lexeme = nom_map[token.original_word.lower()] 412 | elif token.word in nom_map: # lemma 413 | nom_lexeme = nom_map[token.word] 414 | if nom_lexeme == '': 415 | sentences[i].sentence[j].nom_lexeme = '_' + token.word 416 | else: 417 | sentences[i].sentence[j].nom_lexeme = nom_lexeme 418 | sentences[i].sentence[j].is_nom = True 419 | 420 | if not normalize_ne: 421 | if len(lexeme) > 2 and '+' in lexeme[:-1]: 422 | lexeme = lexeme[:lexeme.index('+')] 423 | elif len(lexeme) > 2 and '-' in lexeme[:-1]: 424 | lexeme = lexeme[:lexeme.index('-')] 425 | 426 | if lexeme <> '': 427 | sentences[i].sentence[j].is_pred = True 428 | if normalize_ne and lexeme == '': # for AMR 429 | lexeme = '_' + token.word # lemma 430 | if lexeme <> '': 431 | sentences[i].sentence[j].pred_lexeme = lexeme 432 | 433 | # Matches multi-token expressions. 434 | orth = token.original_word 435 | for k in range(j+1, min(j+max_token_span_length-1, len(sent.sentence))): 436 | orth += ' ' + sent.sentence[k].original_word 437 | if orth in const_map: 438 | sentences[i].sentence[j].const_lexeme = const_map[orth] 439 | sentences[i].sentence[j].const_char_end = sentences[i].sentence[k].char_end 440 | sentences[i].sentence[j].is_const = True 441 | if orth in pred_map: 442 | if normalize_ne: 443 | first_pred = pred_map[orth] 444 | elif len(pred_map[orth]) > 2 and '+' in pred_map[orth][:-1]: 445 | first_pred = pred_map[orth][:pred_map[orth].index('+')] 446 | elif len(pred_map[orth]) > 2 and '-' in pred_map[orth][:-1]: 447 | first_pred = pred_map[orth][:pred_map[orth].index('-')] 448 | else: 449 | first_pred = pred_map[orth] 450 | sentences[i].sentence[j].pred_lexeme = first_pred 451 | sentences[i].sentence[j].pred_char_end = sentences[i].sentence[k].pred_char_end 452 | sentences[i].sentence[j].is_pred = True 453 | 454 | if normalize_ne: 455 | wiki_lexeme = '' 456 | if token.original_word in wiki_map: 457 | wiki_lexeme = wiki_map[token.original_word] 458 | elif token.original_word.lower() in wiki_map: 459 | wiki_lexeme = wiki_map[token.original_word.lower()] 460 | elif token.word in wiki_map: # lemma 461 | wiki_lexeme = wiki_map[token.word] 462 | elif token.word.lower() in wiki_map: 463 | wiki_lexeme = wiki_map[token.word.lower()] 464 | if wiki_lexeme == '': 465 | sentences[i].sentence[j].wiki_lexeme = token.const_lexeme 466 | else: 467 | sentences[i].sentence[j].wiki_lexeme = wiki_lexeme 468 | sentences[i].sentence[j].is_wiki = True 469 | 470 | return sentences 471 | 472 | ''' 473 | Processing performed: Tokenize, lemmize, normalize numbers and time 474 | expressions, insert variable tokens for named entities etc. 
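Per data split this writes (see __main__ below): .en tokenised sentences,
.pos and .ne tag sequences, .span/.span.pred/.span.const character offsets,
.lex.pred/.lex.const/.lex.wiki (plus .lex.nom with -n) lexicon matches, and
.off/.ids/.txt sentence offsets, file ids and raw text.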
475 | ''' 476 | if __name__=='__main__': 477 | assert len(sys.argv) >= 4 478 | input_dir = sys.argv[1] + '/' 479 | working_dir = sys.argv[2] + '/' 480 | erg_dir = sys.argv[3] + '/' 481 | 482 | read_epe = len(sys.argv) > 4 and '-epe' in sys.argv[4:] 483 | 484 | set_list = ['train', 'dev', 'test'] 485 | normalize_ne = len(sys.argv) > 4 and '-n' in sys.argv[4:] 486 | 487 | use_pred_lexicon = True 488 | use_const_lexicon = True 489 | 490 | for set_name in set_list: 491 | sentences = process_stanford(input_dir, working_dir, erg_dir, set_name, 492 | use_pred_lexicon, use_const_lexicon, normalize_ne, read_epe) 493 | 494 | sent_output_file = open(working_dir + set_name + '.en', 'w') 495 | sent_offsets_file = open(working_dir + set_name + '.off', 'w') 496 | sent_ids_file = open(working_dir + set_name + '.ids', 'w') 497 | sent_txt_file = open(working_dir + set_name + '.txt', 'w') 498 | pred_output_file = open(working_dir + set_name + '.lex.pred', 'w') 499 | const_output_file = open(working_dir + set_name + '.lex.const', 'w') 500 | wiki_output_file = open(working_dir + set_name + '.lex.wiki', 'w') 501 | pos_output_file = open(working_dir + set_name + '.pos', 'w') 502 | ne_output_file = open(working_dir + set_name + '.ne', 'w') 503 | span_output_file = open(working_dir + set_name + '.span', 'w') 504 | pred_span_output_file = open(working_dir + set_name + '.span.pred', 'w') 505 | const_span_output_file = open(working_dir + set_name + '.span.const', 'w') 506 | if normalize_ne: 507 | nom_output_file = open(working_dir + set_name + '.lex.nom', 'w') 508 | 509 | for sent in sentences: 510 | out_str = sent.original_sentence_str() 511 | sent_output_file.write(out_str.encode('utf-8', 'replace')) 512 | if normalize_ne: 513 | lex_str = sent.pred_verb_lexeme_str() 514 | pred_output_file.write(lex_str.encode('utf-8', 'replace')) 515 | lex_str = sent.nom_lexeme_str() 516 | nom_output_file.write(lex_str.encode('utf-8', 'replace')) 517 | else: 518 | lex_str = sent.pred_lexeme_str() 519 | pred_output_file.write(lex_str.encode('utf-8', 'replace')) 520 | lex_str = sent.const_lexeme_str() 521 | lex_enc = lex_str.encode('utf-8', 'replace') 522 | const_output_file.write(lex_enc) 523 | lex_str = sent.wiki_lexeme_str() 524 | lex_enc = lex_str.encode('utf-8', 'replace') 525 | sent_offsets_file.write(str(sent.offset) + '\n') 526 | sent_ids_file.write(str(sent.file_id) + '\n') 527 | txt_enc = sent.raw_txt.encode('utf-8', 'replace') 528 | sent_txt_file.write(txt_enc + '\n') 529 | wiki_output_file.write(lex_enc) 530 | pos_output_file.write(sent.pos_str()) 531 | ne_output_file.write(sent.ne_tag_str()) 532 | span_output_file.write(sent.ch_span_str()) 533 | const_span_output_file.write(sent.const_ch_span_str()) 534 | pred_span_output_file.write(sent.pred_ch_span_str()) 535 | 536 | -------------------------------------------------------------------------------- /mrs/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Jan Buys. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import re 17 | import os 18 | import string 19 | 20 | class ConstantVars(): 21 | def __init__(self): 22 | self.names = [] 23 | self.numbers = [] 24 | 25 | def prefix_sim(a, b): 26 | delta = 0.1 27 | l = min(len(a), len(b)) 28 | prefix_length = len(os.path.commonprefix([a, b])) 29 | return (prefix_length + delta)/(min(len(a), len(b)) + delta) 30 | 31 | def prefix_sim_long(a, b): 32 | delta = 0.1 33 | prefix_length = len(os.path.commonprefix([a, b])) 34 | return (prefix_length + delta)/(max(len(a), len(b)) + delta) 35 | 36 | def is_quoted(item): 37 | return len(item) > 2 and item[0] == '"' and item[-1] == '"' 38 | 39 | def remove_quotes(s): 40 | return re.sub(r'^\"(\S+)\"$', r'\1', s) 41 | 42 | ''' Remove concept id's and string quotes. ''' 43 | def clean_concept(s, remove_ids): 44 | if remove_ids: 45 | s = re.sub(r'^([^\s\d]+)-\d\d$', r'\1', s) 46 | s = remove_quotes(s) 47 | return s 48 | 49 | ''' Non-content-bearing punctuation. ''' 50 | def is_punct(w): 51 | punct = ['\'', '\"', '.', ',', ':', '-'] 52 | return w in punct 53 | 54 | def clean_punct(w): 55 | if w[0] == '"' or w[0] == "'": 56 | w = w[1:] 57 | if w[-1] == "'" or w[-1] == '"': 58 | w = w[:-1] 59 | if w <> '' and w[0] in string.punctuation: 60 | w = w[1:] 61 | while w <> '' and w[-1] <> '.' and w[-1] in string.punctuation: 62 | w = w[:-1] 63 | if w <> '' and w[-1] == '.' and '.' not in w[:-1]: 64 | w = w[:-1] 65 | return w 66 | 67 | 68 | def index_sort(align_ind): 69 | new_ind = range(len(align_ind)) 70 | new_ind = sorted(new_ind, key = lambda i: align_ind[i]) 71 | return new_ind 72 | 73 | 74 | def read_lexicon(lex_filename): 75 | lex_map = {} 76 | lex_file = open(lex_filename, 'r') 77 | state = False 78 | orth = '' 79 | for line in lex_file: 80 | if state: 81 | if line.strip() <> '_': 82 | lex_map[orth] = line.strip() 83 | else: 84 | orth = line.strip() 85 | state = not state 86 | return lex_map 87 | 88 | 89 | def separate_brackets(amr_str): 90 | new_str = '' 91 | open_quotes = False 92 | for ch in amr_str: 93 | if ch == '\"': 94 | open_quotes = not open_quotes 95 | if ch == '(' and not open_quotes: 96 | new_str += ch + ' ' 97 | elif ch == ')' and not open_quotes: 98 | new_str += ' ' + ch 99 | elif ch == ' ' and open_quotes: # replaces spaces inside quotes 100 | new_str += '_' 101 | else: 102 | new_str += ch 103 | return new_str 104 | 105 | -------------------------------------------------------------------------------- /rnn/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Modifications copyright 2017 Jan Buys. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """Utilities for preprocessing data and vocabularies.""" 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gzip 23 | import os 24 | import random 25 | import re 26 | import tarfile 27 | 28 | from six.moves import urllib 29 | 30 | from tensorflow.python.platform import gfile 31 | 32 | # Special vocabulary symbols - we always put them at the start. 33 | _PAD = b"_PAD" 34 | _GO = b"_GO" 35 | _EOS = b"_EOS" 36 | _UNK = b"_UNK" 37 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 38 | 39 | PAD_ID = 0 40 | GO_ID = 1 41 | EOS_ID = 2 42 | UNK_ID = 3 43 | 44 | GEN_STATE = 0 45 | PAD_STATE = 1 46 | RE_STATE = 2 47 | ARC_STATE = 3 48 | ROOT_STATE = 4 49 | 50 | NUM_TR_STATES = 5 51 | 52 | NO_ATTENTION_DECODER_STATE = 0 53 | ATTENTION_DECODER_STATE = 1 54 | LINEAR_POINTER_DECODER_STATE = 2 55 | HARD_ATTENTION_DECODER_STATE = 3 56 | LINEAR_FEED_POINTER_DECODER_STATE = 4 57 | HARD_ATTENTION_ARC_EAGER_DECODER_STATE = 5 58 | STACK_DECODER_STATE = 6 59 | PURE_STACK_DECODER_STATE = 7 60 | MEMORY_STACK_DECODER_STATE = 8 61 | 62 | MAX_OUTPUT_SIZE = 300 #TODO parameterize 63 | 64 | # Regular expressions used to tokenize. 65 | _WORD_SPLIT = re.compile(b"([.,!?\"':;)(])") 66 | _DIGIT_RE = re.compile(br"\d") 67 | 68 | def basic_tokenizer(sentence): 69 | """Very basic tokenizer: split the sentence into a list of tokens.""" 70 | words = [] 71 | for space_separated_fragment in sentence.strip().split(): 72 | words.extend(re.split(_WORD_SPLIT, space_separated_fragment)) 73 | return [w for w in words if w] 74 | 75 | def space_tokenizer(sentence): 76 | """Tokenize only on whitespace.""" 77 | words = sentence.strip().split() 78 | return words 79 | 80 | def create_vocabulary(vocabulary_path, singleton_vocabulary_path, data_path, 81 | max_vocabulary_size, singleton_keep_prob, col_format, 82 | tokenizer=None, normalize_digits=False, dic=None, col=1): 83 | """Create vocabulary file (if it does not exist yet) from data file. 84 | 85 | Data file is assumed to contain one sentence per line. Each sentence is 86 | tokenized and digits are normalized (if normalize_digits is set). 87 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 88 | We write it to vocabulary_path in a one-token-per-line format, so that later 89 | token in the first line gets id=0, second line gets id=1, and so on. 90 | 91 | Args: 92 | vocabulary_path: path where the vocabulary will be created. 93 | data_path: data file that will be used to create vocabulary. 94 | max_vocabulary_size: limit on the size of the created vocabulary. 95 | tokenizer: a function to use to tokenize each data sentence; 96 | if None, space_tokenizer will be used. 97 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 98 | singleton_keep_prob: probability to include singletons in the vocabulary. 
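
  Example (a sketch mirroring the call prepare_mrs_data makes later in this
  file; max_vocabulary_size=0 means no size limit, and singleton_keep_prob=1.0
  keeps every singleton):
    create_vocabulary(vocabulary_path, singleton_vocabulary_path, data_path,
                      max_vocabulary_size=0, singleton_keep_prob=1.0,
                      col_format=False, tokenizer=space_tokenizer)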
99 | """ 100 | if not gfile.Exists(vocabulary_path): 101 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 102 | if dic is not None: 103 | print("Dic is not None.") 104 | vocab = {} 105 | singleton_vocab = set() 106 | with gfile.GFile(data_path, mode="rb") as f: 107 | counter = 0 108 | for line in f: 109 | if not line.strip(): 110 | continue 111 | counter += 1 112 | if counter % 100000 == 0: 113 | print(" processing line %d" % counter) 114 | if col_format: 115 | tokens = [line.strip().split()[col]] 116 | else: 117 | tokens = tokenizer(line) if tokenizer else space_tokenizer(line) 118 | for w in tokens: 119 | word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w 120 | if word in vocab: 121 | vocab[word] += 1 122 | elif not word in _START_VOCAB: 123 | if dic is not None and dic.has_key(word): 124 | vocab[word] = 1 125 | elif word in singleton_vocab: 126 | vocab[word] = 2 127 | elif not col_format or word <> b"_": 128 | singleton_vocab.add(word) 129 | for word in singleton_vocab: 130 | if word not in vocab and random.random() < singleton_keep_prob: 131 | vocab[word] = 1 132 | singletons = set() 133 | for word in vocab: 134 | if vocab[word] == 1 and map_restricted_state(word) == GEN_STATE: 135 | singletons.add(word) 136 | vocab_list = [w for w in _START_VOCAB] 137 | if col_format: 138 | vocab_list.append(b"_") 139 | vocab_list += sorted(vocab, key=vocab.get, reverse=True) 140 | if max_vocabulary_size > 0 and len(vocab_list) > max_vocabulary_size: 141 | vocab_list = vocab_list[:max_vocabulary_size] 142 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 143 | for w in vocab_list: 144 | vocab_file.write(w + b"\n") 145 | with gfile.GFile(singleton_vocabulary_path, mode="wb") as vocab_file: 146 | if not singletons: 147 | vocab_file.write(b"") 148 | for w in singletons: 149 | vocab_file.write(w + b"\n") 150 | 151 | def initialize_vocabulary(vocabulary_path): 152 | """Initialize vocabulary from file. 153 | 154 | We assume the vocabulary is stored one-item-per-line, so a file: 155 | dog 156 | cat 157 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 158 | also return the reversed-vocabulary ["dog", "cat"]. 159 | 160 | Args: 161 | vocabulary_path: path to the file containing the vocabulary. 162 | 163 | Returns: 164 | a pair: the vocabulary (a dictionary mapping string to integers), and 165 | the reversed vocabulary (a list, which reverses the vocabulary mapping). 166 | 167 | Raises: 168 | ValueError: if the provided vocabulary_path does not exist. 
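
  Example (for the two-line file sketched above):
    vocab, rev_vocab = initialize_vocabulary(vocabulary_path)
    # vocab == {"dog": 0, "cat": 1} and rev_vocab == ["dog", "cat"]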
169 | """ 170 | if gfile.Exists(vocabulary_path): 171 | rev_vocab = [] 172 | with gfile.GFile(vocabulary_path, mode="rb") as f: 173 | rev_vocab.extend(f.readlines()) 174 | rev_vocab = [line.strip() for line in rev_vocab] 175 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 176 | return vocab, rev_vocab 177 | else: 178 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 179 | 180 | 181 | def read_buckets(bucket_path): 182 | """Read bucket sizes from file.""" 183 | if gfile.Exists(bucket_path): 184 | lines = [] 185 | with gfile.GFile(bucket_path, mode="rb") as f: 186 | lines.extend(f.readlines()) 187 | buckets = [] 188 | for line in lines: 189 | entry = line.strip().split(b" ") 190 | buckets.append((int(entry[0]), int(entry[1]))) 191 | return buckets 192 | else: 193 | raise ValueError("Bucket file %s not found.", bucket_path) 194 | 195 | 196 | def read_word_vectors(word_vector_path, vocabulary_path): 197 | """Read word vectors from file.""" 198 | if gfile.Exists(word_vector_path): 199 | vocab_list = [] 200 | lines = [] 201 | with gfile.GFile(word_vector_path, mode="rb") as f: 202 | lines.extend(f.readlines()) 203 | vector_dict = dict() 204 | for line in lines: 205 | entry = line.strip().split(b" ") 206 | word = entry[0] # preserve case 207 | vocab_list.append(word) 208 | vector_dict[word] = map(float, entry[1:]) 209 | # Add unk entry 210 | assert not vector_dict.has_key('_UNK') 211 | vocab_list.append('_UNK') 212 | vector_dict['_UNK'] = [0.0 for _ in range(len(vector_dict['']))] 213 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 214 | for w in vocab_list: 215 | vocab_file.write(w + b"\n") 216 | return vector_dict 217 | else: 218 | raise ValueError("Vocabulary file %s not found.", word_vector_path) 219 | 220 | def sentence_to_token_ids(sentence, vocabulary, 221 | tokenizer=None, normalize_digits=False, has_unk=True): 222 | """Convert a string to list of integers representing token-ids. 223 | 224 | For example, a sentence "I have a dog" may become tokenized into 225 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 226 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 227 | 228 | Args: 229 | sentence: a string, the sentence to convert to token-ids. 230 | vocabulary: a dictionary mapping tokens to integers. 231 | tokenizer: a function to use to tokenize each sentence; 232 | if None, space_tokenizer will be used. 233 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 234 | 235 | Returns: 236 | a list of integers, the token-ids for the sentence. 237 | """ 238 | if tokenizer: 239 | words = tokenizer(sentence) 240 | else: 241 | words = space_tokenizer(sentence) 242 | tokens = [] 243 | for w in words: 244 | if normalize_digits: 245 | w = _DIGIT_RE.sub(b"0", w) 246 | if vocabulary.has_key(w): 247 | tokens.append(vocabulary[w]) 248 | elif w.startswith(b":") and vocabulary.has_key(b":op("): 249 | # Special UNK handling for single sequence model. 
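      # (Added note: an out-of-vocabulary relation token, e.g. b":ARG12(", is
      # backed off to the generic open-relation symbol b":op(", and closed
      # forms such as b":ARG12()" or b":ARG12(*)" to b":op()", so that unseen
      # relation types still map to a relation symbol rather than to _UNK.)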
250 | if w.endswith(b"("): 251 | tokens.append(vocabulary[b":op("]) 252 | elif w.endswith(b"()") or w.endswith(b"(*)"): 253 | tokens.append(vocabulary[b":op()"]) 254 | elif has_unk: 255 | tokens.append(UNK_ID) 256 | else: 257 | tokens.append(len(vocabulary)-1) 258 | elif has_unk: 259 | tokens.append(UNK_ID) 260 | else: 261 | tokens.append(len(vocabulary)-1) 262 | return tokens 263 | 264 | 265 | def dict_to_token_ids(en_fr_dict, en_vocab, fr_vocab): 266 | ids_dict = dict() 267 | for word, words in en_fr_dict.iteritems(): 268 | en_id = en_vocab.get(word, UNK_ID) 269 | if en_id <> UNK_ID: 270 | ids_dict[en_id] = [fr_vocab.get(w, UNK_ID) for w in words] 271 | 272 | return ids_dict 273 | 274 | 275 | def dict_data_to_token_ids(dict_path, null_path, en_vocab, fr_vocab): 276 | with gfile.GFile(dict_path, mode="rb") as dict_file: 277 | en_fr_dict = dict() 278 | for line in dict_file: 279 | entry = line.strip().split(b"\t") 280 | if len(entry) > 1: 281 | en_fr_dict[entry[0]] = entry[1:] 282 | vocab_dict = dict_to_token_ids(en_fr_dict, en_vocab, fr_vocab) 283 | with gfile.GFile(null_path, mode="rb") as null_vocab_file: 284 | null_vocab = [] 285 | for line in null_vocab_file: 286 | if line.strip(): 287 | null_vocab.append(line.strip()) 288 | null_ids = [fr_vocab.get(w, UNK_ID) for w in null_vocab] 289 | null_ids.append(UNK_ID) 290 | return null_ids, vocab_dict 291 | 292 | 293 | def id_vocab_sets(fr_vocab): 294 | id_sets = [set() for _ in xrange(NUM_TR_STATES)] 295 | for word, ind in fr_vocab.iteritems(): 296 | state = map_restricted_state(word) 297 | id_sets[state].add(ind) 298 | return id_sets 299 | 300 | 301 | def map_restricted_state(word): 302 | state = GEN_STATE 303 | if ((word.startswith(b":") and word.endswith(b"(")) 304 | or word.startswith(b"LA:") or word.startswith(b"RA:") 305 | or word.startswith(b"UA:") or word.startswith(b"STACK*")): 306 | state = ARC_STATE 307 | elif word == b"ROOT": 308 | state = ROOT_STATE 309 | elif word == b")" or word == b"RE": 310 | state = RE_STATE 311 | elif word in [_PAD, _EOS]: 312 | state = PAD_STATE 313 | return state 314 | 315 | 316 | def construct_transition_map(vocab_sets, restrict_vocab): 317 | transitions = [range(0, 5) for _ in xrange(5)] 318 | 319 | restrictions = [] 320 | for indexes in transitions: 321 | restr = vocab_sets[indexes[0]] 322 | for k in indexes[1:]: 323 | restr = restr.union(vocab_sets[k]) 324 | restr_list = list(restr) 325 | restr_list.sort() 326 | restrictions.append(restr_list) 327 | return restrictions 328 | 329 | 330 | def extract_decoder_vocab(sent, vocab_dict): 331 | decoder_vocab = set() 332 | for en_id in sent: 333 | if type(en_id) == tuple: 334 | en_id = en_id[0] 335 | if vocab_dict.has_key(en_id): 336 | for fr_id in vocab_dict[en_id]: 337 | if fr_id <> UNK_ID: 338 | decoder_vocab.add(fr_id) 339 | return decoder_vocab 340 | 341 | 342 | def encoder_decoder_vocab_map_to_token_ids(map_path, source_vocab, target_vocab): 343 | with gfile.GFile(map_path, mode="rb") as map_file: 344 | concept_map = [-1 for _ in source_vocab] 345 | for line in map_file: 346 | word, concept = line.strip().split(b"\t")[0], line.strip().split(b"\t")[1] 347 | if source_vocab.has_key(word): 348 | word_id = source_vocab[word] 349 | if target_vocab.has_key(concept): 350 | concept_id = target_vocab[concept] 351 | else: 352 | concept_id = UNK_ID 353 | concept_map[word_id] = concept_id 354 | return concept_map 355 | 356 | 357 | def data_to_token_ids(data_path, target_path, vocabulary_path, 358 | tokenizer=None, normalize_digits=False, has_unk=True): 359 | 
"""Tokenize data file and turn into token-ids using given vocabulary file. 360 | 361 | This function loads data line-by-line from data_path, calls the above 362 | sentence_to_token_ids, and saves the result to target_path. See comment 363 | for sentence_to_token_ids on the details of token-ids format. 364 | 365 | Args: 366 | data_path: path to the data file in one-sentence-per-line format. 367 | target_path: path where the file with token-ids will be created. 368 | vocabulary_path: path to the vocabulary file. 369 | tokenizer: a function to use to tokenize each sentence; 370 | if None, space_tokenizer will be used. 371 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 372 | """ 373 | if not gfile.Exists(target_path): 374 | print("Tokenizing data in %s" % data_path) 375 | vocab, _ = initialize_vocabulary(vocabulary_path) 376 | with gfile.GFile(data_path, mode="rb") as data_file: 377 | with gfile.GFile(target_path, mode="w") as tokens_file: 378 | counter = 0 379 | for line in data_file: 380 | counter += 1 381 | if counter % 100000 == 0: 382 | print(" tokenizing line %d" % counter) 383 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 384 | normalize_digits, has_unk) 385 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 386 | 387 | 388 | def prepare_mrs_data(data_dir, source_dir, data_type, set_name, 389 | create_vocab): 390 | """Get data into data_dir, create vocabulary, convert to ids.""" 391 | if data_type == 'em': 392 | path = os.path.join(source_dir, set_name + ".en") 393 | else: 394 | path = os.path.join(source_dir, set_name + "." + data_type) 395 | ids_path = os.path.join(data_dir, set_name + "." + data_type + ".ids") 396 | vocab_path = os.path.join(data_dir, "vocab." + data_type) 397 | sing_vocab_path = os.path.join(data_dir, "singleton-vocab." + data_type) 398 | 399 | if create_vocab and data_type <> 'em': 400 | create_vocabulary(vocab_path, sing_vocab_path, path, 0, 1.0, False, 401 | space_tokenizer) 402 | 403 | data_to_token_ids(path, ids_path, vocab_path, space_tokenizer, False, 404 | data_type <> 'em') 405 | return ids_path, vocab_path, sing_vocab_path 406 | 407 | 408 | def copy_mrs_data(data_dir, source_dir, data_type, set_name): 409 | """Create copy of data file.""" 410 | source_path = os.path.join(source_dir, set_name + "." + data_type) 411 | target_path = os.path.join(data_dir, set_name + "." 
+ data_type + ".pnt") 412 | 413 | if not gfile.Exists(target_path): 414 | with gfile.GFile(source_path, mode="rb") as source_file: 415 | with gfile.GFile(target_path, mode="w") as target_file: 416 | for line in source_file: 417 | line = line.strip() 418 | target_file.write(line + "\n") 419 | return target_path 420 | 421 | 422 | def read_ids_file(source_path, max_size): 423 | source_input = [] 424 | print("Reading ids file:", source_path) 425 | with gfile.GFile(source_path, mode="r") as source_file: 426 | source = source_file.readline() 427 | counter = 0 428 | while source and (not max_size or counter < max_size): 429 | counter += 1 430 | source_ids = [int(x) for x in source.split()] 431 | source_input.append(source_ids) 432 | source = source_file.readline() 433 | return source_input 434 | 435 | 436 | def write_output_file(output_path, output_lines): 437 | with gfile.GFile(output_path, mode="w") as output_file: 438 | for line in output_lines: 439 | output_file.write(line + "\n") 440 | 441 | 442 | -------------------------------------------------------------------------------- /rnn/seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Modifications copyright 2017 Jan Buys. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Library for creating sequence-to-sequence models in TensorFlow. 18 | 19 | Sequence-to-sequence recurrent neural networks can learn complex functions 20 | that map input sequences to output sequences. These models yield very good 21 | results on a number of tasks, such as speech recognition, parsing, machine 22 | translation, or even constructing automated replies to emails. 23 | 24 | Before using this module, it is recommended to read the TensorFlow tutorial 25 | on sequence-to-sequence models. It explains the basic concepts of this module 26 | and shows an end-to-end example of how to build a translation model. 27 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 28 | 29 | Here is an overview of functions available in this module. They all use 30 | a very similar interface, so after reading the above tutorial and using 31 | one of them, others should be easy to substitute. 32 | 33 | * Full sequence-to-sequence models. 34 | - embedding_attention_seq2seq: Advanced model with input embedding and 35 | the neural attention mechanism; recommended for complex tasks. 36 | Attention optional. 37 | 38 | * Decoders (when you write your own encoder, you can use these to decode; 39 | e.g., if you want to write a model that generates captions for images). 40 | - rnn_decoder: The basic decoder based on a pure RNN. 41 | - attention_decoder: A decoder that uses the attention mechanism. 42 | 43 | * Losses. 44 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 
45 | - sequence_loss_by_example: As above, but not averaging over all examples. 46 | 47 | * model_with_buckets: A convenience function to create models with bucketing 48 | (see the tutorial above for an explanation of why and how to use it). 49 | """ 50 | 51 | from __future__ import absolute_import 52 | from __future__ import division 53 | from __future__ import print_function 54 | 55 | import numpy as np 56 | 57 | # We disable pylint because we need python3 compatibility. 58 | from six.moves import xrange # pylint: disable=redefined-builtin 59 | from six.moves import zip # pylint: disable=redefined-builtin 60 | 61 | import tensorflow as tf 62 | 63 | import data_utils 64 | import seq2seq_helpers 65 | import seq2seq_decoders 66 | 67 | linear = tf.nn.rnn_cell._linear # pylint: disable=protected-access 68 | 69 | def embedding_attention_decoder(decoder_type, decoder_inputs, 70 | encoder_inputs, 71 | initial_state, attention_states, 72 | cell, aux_cell, mem_cell, 73 | decoder_vocab_sizes, 74 | decoder_embedding_sizes, 75 | decoder_restrictions=None, 76 | num_heads=1, 77 | predict_span_end_pointers=False, 78 | output_projections=None, 79 | transition_state_map=None, 80 | encoder_decoder_vocab_map=None, 81 | feed_previous=False, 82 | update_embedding_for_previous=True, 83 | dtype=tf.float32, scope=None, 84 | initial_state_attention=False): 85 | """RNN decoder with embedding and attention and multiple decoder models. 86 | 87 | Args: 88 | decoder_type: int indicating decoder to be used for encoder-decoder models 89 | - constants defined in data_utils.py. 90 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 91 | encoder_label_inputs: A list of 1D int32 Tensors of shape [batch_size]. 92 | initial_state: 2D Tensor [batch_size x cell.state_size]. 93 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 94 | cell: rnn_cell.RNNCell defining the cell function. 95 | aux_cell: rnn_cell.RNNCell defining the cell function and size. Auxiliary 96 | decoder LSTM cell. 97 | num_symbols: Integer, how many symbols come into the embedding. 98 | embedding_size: Integer, the length of the embedding vector for each symbol. 99 | decoder_restrictions: List of (dense) 1D int32 Tensors of allowed output 100 | symbols for each decoder transition state. 101 | num_heads: Number of attention heads that read from attention_states. 102 | output_size: Size of the output vectors; if None, use output_size. 103 | output_projection: None or a pair (W, B) of output projection weights and 104 | biases; W has shape [output_size x num_symbols] and B has shape 105 | [num_symbols]; if provided and feed_previous=True, each fed previous 106 | output will first be multiplied by W and added B. 107 | transition_state_map: Constant 1D int Tensor size output_vocab_size. Maps 108 | each word to its transition state. 109 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 110 | used (the "GO" symbol), and all other decoder inputs will be generated by: 111 | next = embedding_lookup(embedding, argmax(previous_output)), 112 | In effect, this implements a greedy decoder. It can also be used 113 | during training to emulate http://arxiv.org/abs/1506.03099. 114 | If False, decoder_inputs are used as given (the standard decoder case). 115 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 116 | only the embedding for the first symbol of decoder_inputs (the "GO" 117 | symbol) will be updated by back propagation. 
Embeddings for the symbols 118 | generated from the decoder itself remain unchanged. This parameter has 119 | no effect if feed_previous=False. 120 | dtype: The dtype to use for the RNN initial states (default: tf.float32). 121 | scope: VariableScope for the created subgraph; defaults to 122 | "embedding_attention_decoder". 123 | initial_state_attention: If False (default), initial attentions are zero. 124 | If True, initialize the attentions from the initial state and attention 125 | states -- useful when we wish to resume decoding from a previously 126 | stored decoder state and attention states. 127 | 128 | Returns: 129 | A tuple of the form (logits, outputs, pointer_logits, state), where: 130 | logits: A list of the same length as decoder_inputs of 2D Tensors with 131 | shape [batch_size x num_decoder_symbols] containing the generated 132 | output logits. 133 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 134 | shape [batch_size x output_cell_size] (if output_projection is not None) 135 | containing the outputs that are fed to the loss function. 136 | Also fed to loss function. 137 | label_logits: List of the same length as encoder_inputs for label logits. 138 | label_outputs: List of the same length as encoder_inputs for label output 139 | vectors for loss function. 140 | state: The state of each decoder cell at the final time-step. 141 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 142 | Raises: 143 | ValueError: When output_projection has the wrong shape. 144 | """ 145 | output_size = cell.output_size 146 | for key, output_projection in output_projections.iteritems(): 147 | proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype) 148 | proj_biases.get_shape().assert_is_compatible_with([decoder_vocab_sizes[key]]) 149 | 150 | with tf.variable_scope(scope or "embedding_attention_decoder"): 151 | embeddings = {} 152 | embed_functions = {} 153 | if feed_previous: 154 | loop_functions = {} 155 | else: 156 | loop_functions = None 157 | 158 | for key, vocab_size in decoder_vocab_sizes.iteritems(): 159 | embedding = tf.get_variable("decoder_input_embedding_{0}".format(key), 160 | [vocab_size, decoder_embedding_sizes[key]], 161 | initializer=tf.random_uniform_initializer(-np.sqrt(3), 162 | np.sqrt(3))) 163 | embeddings[key] = embedding 164 | embed_functions[key] = seq2seq_helpers._extract_embed(embedding) 165 | 166 | if feed_previous: 167 | loop_functions[key] = seq2seq_helpers._extract_argmax_and_embed( 168 | embedding) 169 | 170 | if decoder_type == data_utils.STACK_DECODER_STATE: 171 | return seq2seq_decoders.attention_stack_decoder( 172 | decoder_inputs, encoder_inputs, initial_state, 173 | attention_states, cell, aux_cell, mem_cell, 174 | use_aux_stack=True, output_size=output_size, 175 | num_heads=num_heads, 176 | embed_functions=embed_functions, 177 | loop_functions=loop_functions, 178 | output_projections=output_projections, 179 | decoder_restrictions=decoder_restrictions, 180 | transition_state_map=transition_state_map, 181 | initial_state_attention=initial_state_attention) 182 | elif decoder_type == data_utils.PURE_STACK_DECODER_STATE: 183 | return seq2seq_decoders.attention_stack_decoder( 184 | decoder_inputs, encoder_inputs, initial_state, 185 | attention_states, cell, 186 | aux_cell, mem_cell, use_aux_stack=False, output_size=output_size, 187 | num_heads=num_heads, 188 | embed_functions=embed_functions, 189 | loop_functions=loop_functions, output_projections=output_projections, 190 | decoder_restrictions=decoder_restrictions, 
191 | transition_state_map=transition_state_map, 192 | initial_state_attention=initial_state_attention) 193 | elif decoder_type == data_utils.MEMORY_STACK_DECODER_STATE: 194 | return seq2seq_decoders.attention_stack_decoder( 195 | decoder_inputs, encoder_inputs, initial_state, 196 | attention_states, cell, 197 | aux_cell, mem_cell, use_aux_stack=True, use_memory_stack=True, 198 | output_size=output_size, num_heads=num_heads, 199 | embed_functions=embed_functions, 200 | loop_functions=loop_functions, output_projections=output_projections, 201 | decoder_restrictions=decoder_restrictions, 202 | transition_state_map=transition_state_map, 203 | initial_state_attention=initial_state_attention) 204 | elif decoder_type == data_utils.LINEAR_POINTER_DECODER_STATE: 205 | return seq2seq_decoders.attention_pointer_decoder( 206 | decoder_inputs, encoder_inputs, 207 | initial_state, attention_states, cell, feed_alignment=False, 208 | feed_post_alignment=False, 209 | predict_end_attention=predict_span_end_pointers, 210 | output_size=output_size, 211 | num_heads=num_heads, 212 | embed_functions=embed_functions, 213 | loop_functions=loop_functions, 214 | output_projections=output_projections, 215 | decoder_restrictions=decoder_restrictions, 216 | transition_state_map=transition_state_map, 217 | decoder_embedding_sizes=decoder_embedding_sizes, 218 | initial_state_attention=initial_state_attention) 219 | elif decoder_type == data_utils.LINEAR_FEED_POINTER_DECODER_STATE: 220 | return seq2seq_decoders.attention_pointer_decoder( 221 | decoder_inputs, encoder_inputs, 222 | initial_state, attention_states, cell, feed_alignment=True, 223 | feed_post_alignment=False, 224 | predict_end_attention=predict_span_end_pointers, 225 | output_size=output_size, 226 | num_heads=num_heads, 227 | embed_functions=embed_functions, 228 | loop_functions=loop_functions, 229 | output_projections=output_projections, 230 | decoder_restrictions=decoder_restrictions, 231 | transition_state_map=transition_state_map, 232 | decoder_embedding_sizes=decoder_embedding_sizes, 233 | initial_state_attention=initial_state_attention) 234 | elif decoder_type == data_utils.HARD_ATTENTION_DECODER_STATE: 235 | return seq2seq_decoders.hard_attention_decoder( 236 | decoder_inputs, encoder_inputs, 237 | initial_state, attention_states, cell, 238 | predict_end_attention=predict_span_end_pointers, 239 | output_size=output_size, 240 | num_heads=num_heads, 241 | embed_functions=embed_functions, 242 | loop_functions=loop_functions, 243 | output_projections=output_projections, 244 | decoder_restrictions=decoder_restrictions, 245 | transition_state_map=transition_state_map, 246 | initial_state_attention=initial_state_attention) 247 | elif decoder_type == data_utils.HARD_ATTENTION_ARC_EAGER_DECODER_STATE: 248 | return seq2seq_decoders.hard_attention_arc_eager_decoder( 249 | decoder_inputs, encoder_inputs, 250 | initial_state, attention_states, cell, 251 | predict_end_attention=predict_span_end_pointers, 252 | output_size=output_size, 253 | num_heads=num_heads, 254 | embed_functions=embed_functions, 255 | loop_functions=loop_functions, 256 | output_projections=output_projections, 257 | transition_state_map=transition_state_map, 258 | initial_state_attention=initial_state_attention) 259 | elif decoder_type == data_utils.ATTENTION_DECODER_STATE: 260 | return seq2seq_decoders.attention_decoder( 261 | decoder_inputs, encoder_inputs, 262 | initial_state, attention_states, cell, 263 | output_size=output_size, 264 | num_heads=num_heads, 265 | 
embed_functions=embed_functions, 266 | loop_functions=loop_functions, 267 | output_projections=output_projections, 268 | decoder_restrictions=decoder_restrictions, 269 | transition_state_map=transition_state_map, 270 | initial_state_attention=initial_state_attention) 271 | else: 272 | return seq2seq_decoders.rnn_decoder(emb_inp, initial_state, cell, 273 | loop_function=loop_function, output_projection=output_projection) 274 | 275 | 276 | def embedding_attention_seq2seq(decoder_type, encoder_inputs, 277 | decoder_inputs, 278 | fw_cell, bw_cell, 279 | dec_cell, dec_aux_cell, dec_mem_cell, 280 | encoder_vocab_sizes, 281 | decoder_vocab_sizes, encoder_embedding_sizes, 282 | decoder_embedding_sizes, 283 | predict_span_end_pointers=False, 284 | decoder_restrictions=None, 285 | num_heads=1, output_projections=None, 286 | word_vectors=None, 287 | transition_state_map=None, 288 | encoder_decoder_vocab_map=None, 289 | use_bidirectional_encoder=False, 290 | feed_previous=False, dtype=tf.float32, 291 | scope=None, initial_state_attention=False): 292 | """Embedding sequence-to-sequence model with attention. 293 | 294 | This model first embeds encoder_inputs by a newly created embedding (of shape 295 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 296 | embedded encoder_inputs into a state vector. It keeps the outputs of this 297 | RNN at every step to use for attention later. Next, it embeds decoder_inputs 298 | by another newly created embedding (of shape [num_decoder_symbols x 299 | input_size]). Then it runs attention decoder, initialized with the last 300 | encoder state, on embedded decoder_inputs and attending to encoder outputs. 301 | 302 | Args: 303 | decoder_type: int indicating decoder to be used for encoder-decoder models 304 | - constants defined in data_utils.py. 305 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 306 | encoder_label_inputs: A list of 1D int32 Tensors of shape [batch_size]. 307 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 308 | decoder_pointer_inputs: A list of 1D int32 Tensors of shape [batch_size]. 309 | fw_cell: rnn_cell.RNNCell defining the cell function and size. Forward 310 | encoder cell. 311 | bw_cell: rnn_cell.RNNCell defining the cell function and size. Backward 312 | encoder cell. 313 | dec_cell: rnn_cell.RNNCell defining the cell function and size. Decoder 314 | cell. 315 | dec_aux_cell: rnn_cell.RNNCell defining the cell function and size. Decoder 316 | auxiliary LSTM cell. 317 | num_encoder_symbols: Integer; number of symbols on the encoder side. 318 | num_decoder_symbols: Integer; number of symbols on the decoder side. 319 | embedding_size: Integer; length of the embedding vector for each symbol. 320 | label_embedding_size: Integer; length of the label embedding vector for each symbol. 321 | use_input_labels: Boolean; encoder should include input labels in input. 322 | predict_input_labels: Boolean; encoder should predict label for each input 323 | token. 324 | decoder_restrictions: List of (dense) 1D int32 Tensors of allowed output 325 | symbols for each decoder transition state. 326 | num_heads: Number of attention heads that read from attention_states. 327 | output_projection: None or a pair (W, B) of output projection weights and 328 | biases; W has shape [output_size x num_decoder_symbols] and B has 329 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 330 | fed previous output will first be multiplied by W and added B. 
331 | label_output_projection; None or label prediction output weights and biases. 332 | word_vectors: 2D Tensor shape [source_vocab_size, embedding_size] of encoder 333 | embedding vectors. 334 | label_vectors: 2D Tensor shape [label_vocab_size, embedding_size] of encoder 335 | label embedding vectors. 336 | transition_state_map: Constant 1D int Tensor size output_vocab_size. Maps 337 | each word to its transition state. 338 | use_bidirectional_encoder: Boolean; predict output for each encoder input 339 | token, no seperate decoder. 340 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 341 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 342 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 343 | If False, decoder_inputs are used as given (the standard decoder case). 344 | dtype: The dtype of the initial RNN state (default: tf.float32). 345 | scope: VariableScope for the created subgraph; defaults to 346 | "embedding_attention_seq2seq". 347 | initial_state_attention: If False (default), initial attentions are zero. 348 | If True, initialize the attentions from the initial state and attention 349 | states. 350 | 351 | Returns: 352 | A tuple of the form (logits, outputs, pointer_logits, label_logits, 353 | label_outputs, state), where: 354 | logits: A list of the same length as decoder_inputs of 2D Tensors with 355 | shape [batch_size x num_decoder_symbols] containing the generated 356 | output logits. 357 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 358 | shape [batch_size x output_cell_size] (if output_projection is not None) 359 | containing the outputs that are fed to the loss function. 360 | Also fed to loss function. 361 | label_logits: List of the same length as encoder_inputs for label logits. 362 | label_outputs: List of the same length as encoder_inputs for label output 363 | vectors for loss function. 364 | state: The state of each decoder cell at the final time-step. 365 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 366 | """ 367 | with tf.variable_scope(scope or "embedding_attention_seq2seq"): 368 | assert word_vectors is not None 369 | encoder_input_size = fw_cell.output_size 370 | encoder_output_size = fw_cell.output_size 371 | 372 | b = tf.get_variable("input_proj_b", [encoder_input_size]) 373 | emb_layer = [b for _ in encoder_inputs] 374 | 375 | for input_type in encoder_inputs[0].iterkeys(): 376 | # Defines encoder input projection layer. 377 | w = tf.get_variable("input_proj_w_{0}".format(input_type), 378 | [encoder_embedding_sizes[input_type], encoder_input_size], 379 | initializer=tf.uniform_unit_scaling_initializer()) 380 | for i, encoder_input in enumerate(encoder_inputs): 381 | emb_inp = tf.nn.embedding_lookup(word_vectors[input_type], 382 | encoder_input[input_type]) 383 | # Linear combination of the inputs. 384 | emb_layer[i] = tf.add(emb_layer[i], tf.matmul(emb_inp, w)) 385 | 386 | if use_bidirectional_encoder: 387 | # Encoder state is final backward state. 388 | encoder_outputs, _, encoder_state = tf.nn.bidirectional_rnn(fw_cell, 389 | bw_cell, emb_layer, dtype=dtype, scope="embedding_encoder") 390 | encoder_output_size *= 2 391 | else: 392 | encoder_outputs, encoder_state = tf.nn.rnn( 393 | fw_cell, emb_layer, dtype=dtype) 394 | 395 | if decoder_type == data_utils.NO_ATTENTION_DECODER_STATE: 396 | attention_states = None 397 | else: 398 | # First calculate a concatenation of encoder outputs to put attention on. 
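      # (Added note: each encoder output has shape
      # [batch_size, encoder_output_size]; reshaping to
      # [batch_size, 1, encoder_output_size] and concatenating along axis 1
      # yields attention_states of shape [batch_size, attn_length, attn_size],
      # the layout the attention decoders expect.)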
399 | top_states = [tf.reshape(e, [-1, 1, encoder_output_size]) 400 | for e in encoder_outputs] 401 | attention_states = tf.concat(1, top_states) 402 | 403 | # Decoder. 404 | if isinstance(feed_previous, bool): 405 | logits, state = embedding_attention_decoder( 406 | decoder_type, decoder_inputs, encoder_inputs, 407 | encoder_state, attention_states, dec_cell, dec_aux_cell, dec_mem_cell, 408 | decoder_vocab_sizes, decoder_embedding_sizes, 409 | decoder_restrictions, num_heads=num_heads, 410 | predict_span_end_pointers=predict_span_end_pointers, 411 | output_projections=output_projections, 412 | transition_state_map=transition_state_map, 413 | encoder_decoder_vocab_map=encoder_decoder_vocab_map, 414 | feed_previous=feed_previous, 415 | initial_state_attention=initial_state_attention) 416 | return logits, state 417 | 418 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 419 | def decoder(feed_previous_bool): 420 | reuse = None if feed_previous_bool else True 421 | with tf.variable_scope(tf.get_variable_scope(), 422 | reuse=reuse): 423 | logits, state = embedding_attention_decoder( 424 | decoder_type, decoder_inputs, encoder_inputs, 425 | encoder_state, attention_states, dec_cell, dec_aux_cell, 426 | dec_mem_cell, decoder_vocab_sizes, decoder_embedding_sizes, 427 | decoder_restrictions, num_heads=num_heads, 428 | predict_span_end_pointers=predict_span_end_pointers, 429 | output_projections=output_projections, 430 | transition_state_map=transition_state_map, 431 | encoder_decoder_vocab_map=encoder_decoder_vocab_map, 432 | feed_previous=feed_previous_bool, 433 | update_embedding_for_previous=not feed_previous_bool, 434 | initial_state_attention=initial_state_attention) 435 | return [logits, state] 436 | 437 | outputs_and_state = tf.cond(feed_previous, lambda: decoder(True), 438 | lambda: decoder(False)) 439 | return outputs_and_state[0], outputs_and_state[1] 440 | 441 | 442 | def sequence_loss_by_example(key, logits, targets, weights, 443 | average_across_timesteps=True, 444 | softmax_loss_function=None, name=None): 445 | """Weighted cross-entropy loss for a sequence of logits (per example). 446 | 447 | Args: 448 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 449 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 450 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 451 | average_across_timesteps: If set, divide the returned cost by the total 452 | label weight. 453 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 454 | to be used instead of the standard softmax (the default if this is None). 455 | name: Optional name for this operation, default: "sequence_loss_by_example". 456 | 457 | Returns: 458 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 459 | 460 | Raises: 461 | ValueError: If len(logits) is different from len(targets) or len(weights). 462 | """ 463 | if len(targets) != len(logits) or len(weights) != len(logits): 464 | raise ValueError("Lengths of logits, weights, and targets must be the same " 465 | "%d, %d, %d." 
% (len(logits), len(weights), len(targets))) 466 | with tf.name_scope(name, "sequence_loss_by_example", logits + targets + weights): 467 | log_perp_list = [] 468 | weight_list = [] 469 | for logit, target, weight in zip(logits, targets, weights): 470 | if key == "parse" or key == "att" or key == "endatt": 471 | weight_key = "parse" 472 | elif key == "ind": 473 | weight_key = "ind" 474 | else: 475 | weight_key = "predicate" 476 | crossent = softmax_loss_function(logit[key], target[key]) 477 | log_perp_list.append(crossent * weight[weight_key]) 478 | weight_list.append(weight[weight_key]) 479 | log_perps = tf.add_n(log_perp_list) 480 | total_size = tf.add_n(weight_list) 481 | total_size += 1e-12 # just to avoid division by 0 for all-0 weights 482 | if average_across_timesteps: 483 | log_perps /= total_size 484 | return log_perps, total_size 485 | 486 | 487 | def sequence_loss(key, logits, targets, weights, 488 | average_across_timesteps=True, average_across_batch=True, 489 | softmax_loss_function=None, name=None): 490 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 491 | 492 | Args: 493 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 494 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 495 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 496 | average_across_timesteps: If set, divide the returned cost by the total 497 | label weight. 498 | average_across_batch: If set, divide the returned cost by the batch size. 499 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 500 | to be used instead of the standard softmax (the default if this is None). 501 | name: Optional name for this operation, defaults to "sequence_loss". 502 | 503 | Returns: 504 | A scalar float Tensor: The average log-perplexity per symbol (weighted). 505 | 506 | Raises: 507 | ValueError: If len(logits) is different from len(targets) or len(weights). 508 | """ 509 | with tf.name_scope(name, "sequence_loss", logits + targets + weights): 510 | cost_per_example, total_size = sequence_loss_by_example(key, 511 | logits, targets, weights, 512 | average_across_timesteps=average_across_timesteps, 513 | softmax_loss_function=softmax_loss_function) 514 | cost = tf.reduce_sum(cost_per_example) 515 | total_size = tf.reduce_sum(total_size) 516 | if average_across_batch and not average_across_timesteps: 517 | return cost / total_size 518 | elif average_across_batch: 519 | batch_size = tf.shape(next(targets[0].itervalues()))[0] 520 | return cost / tf.cast(batch_size, tf.float32) 521 | else: 522 | return cost 523 | 524 | 525 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, 526 | buckets, seq2seq, forward_only, 527 | softmax_loss_function=None, 528 | average_across_timesteps=True, 529 | name=None): 530 | """Create a sequence-to-sequence model with support for bucketing. 531 | 532 | The seq2seq argument is a function that defines a sequence-to-sequence model, 533 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) 534 | 535 | Args: 536 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 537 | encoder_label_inputs: A list of Tensors to feed the encoder; second 538 | seq2seq input. 539 | decoder_inputs: A list of Tensors to feed the decoder; 3rd seq2seq input. 540 | decoder_pointer_inputs: A list of Tensors to feed the decoder; 4th seq2seq 541 | input. 
542 | label_targets: A list of 1D batch-sized int32 Tensors (desired label sequence). 543 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 544 | weights: List of 1D batch-sized float-Tensors to weight the targets. 545 | pointer_targets: A list of 1D batch-sized int32 Tensors (desired pointer 546 | output sequence). 547 | pointer_weights: List of 1D batch-sized float-Tensors to weight the pointer 548 | targets. 549 | buckets: A list of pairs of (input size, output size) for each bucket. 550 | seq2seq: A sequence-to-sequence model function; it takes 2 input that 551 | agree with encoder_inputs and decoder_inputs, and returns a pair 552 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 553 | forward_only: boolean, set True for decoding. 554 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 555 | to be used instead of the standard softmax (the default if this is None). 556 | pointer_softmax_loss_function: Loss function for pointer prediction. 557 | label_softmax_loss_function: Loss function for label prediction. 558 | average_across_timesteps: Boolean, if True average the loss for each 559 | timestep. 560 | name: Optional name for this operation, defaults to "model_with_buckets". 561 | 562 | Returns: 563 | A tuple of the form (outputs, losses), where: 564 | outputs: The outputs for each bucket. Its j'th element consists of a list 565 | of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs). 566 | losses: List of scalar Tensors, representing losses for each bucket, or, 567 | if per_example_loss is set, a list of 1D batch-sized float Tensors. 568 | 569 | Raises: 570 | ValueError: If length of encoder_inputsut, targets, or weights is smaller 571 | than the largest (last) bucket. 572 | """ 573 | if len(encoder_inputs) < buckets[-1][0]: 574 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 575 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 576 | if len(targets) < buckets[-1][1]: 577 | raise ValueError("Length of targets (%d) must be at least that of last" 578 | "bucket (%d)." % (len(targets), buckets[-1][1])) 579 | if len(weights) < buckets[-1][1]: 580 | raise ValueError("Length of weights (%d) must be at least that of last" 581 | "bucket (%d)." % (len(weights), buckets[-1][1])) 582 | 583 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 584 | losses = [] 585 | outputs = [] 586 | with tf.name_scope(name, "model_with_buckets", all_inputs): 587 | for j, bucket in enumerate(buckets): 588 | with tf.variable_scope(tf.get_variable_scope(), 589 | reuse=True if j > 0 else None): 590 | bucket_logits, _ = seq2seq( 591 | encoder_inputs[:bucket[0]], decoder_inputs[:bucket[1]]) 592 | outputs.append(bucket_logits) 593 | 594 | bucket_targets = targets[:bucket[1]] 595 | bucket_weights = weights[:bucket[1]] 596 | 597 | loss = 0 598 | for key in bucket_targets[0].iterkeys(): 599 | loss += sequence_loss(key, 600 | bucket_logits, bucket_targets, bucket_weights, 601 | average_across_timesteps=average_across_timesteps, 602 | softmax_loss_function=softmax_loss_function) 603 | losses.append(loss) 604 | 605 | return outputs, losses 606 | 607 | -------------------------------------------------------------------------------- /rnn/seq2seq_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Modifications copyright 2017 Jan Buys. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | from tensorflow.python.util import nest 23 | 24 | # We disable pylint because we need python3 compatibility. 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | from six.moves import zip # pylint: disable=redefined-builtin 27 | 28 | import tensorflow as tf 29 | 30 | import data_utils 31 | 32 | linear = tf.nn.rnn_cell._linear # pylint: disable=protected-access 33 | 34 | #TODO rename: remove _ (not local method) 35 | def _extract_embed(embedding, update_embedding=True): 36 | """Get a loop_function that embeds symbols. 37 | 38 | Args: 39 | embedding: list of embedding tensors for symbols. 40 | update_embedding: Boolean; if False, the gradients will not propagate 41 | through the embeddings. 42 | 43 | Returns: 44 | A loop function. 45 | """ 46 | def embed_function(symbol): 47 | emb = tf.nn.embedding_lookup(embedding, symbol) 48 | # Note that gradients will not propagate through the second parameter of 49 | # embedding_lookup. 50 | if not update_embedding: 51 | emb = tf.stop_gradient(emb) 52 | return emb 53 | return embed_function 54 | 55 | 56 | def _extract_argmax_and_embed(embedding, output_projection=None, 57 | update_embedding=True): 58 | """Get a loop_function that extracts the previous symbol and embeds it. 59 | 60 | Args: 61 | embedding: embedding tensor for symbols. 62 | output_projection: None or a pair (W, B). If provided, each fed previous 63 | output will first be multiplied by W and added B. 64 | update_embedding: Boolean; if False, the gradients will not propagate 65 | through the embeddings. 66 | 67 | Returns: 68 | A loop function. 69 | """ 70 | def loop_function(prev, _): 71 | if output_projection is not None: 72 | prev = tf.matmul(prev, output_projection[0]) + output_projection[1] 73 | prev_symbol = tf.argmax(prev, 1) 74 | # Note that gradients will not propagate through the second parameter of 75 | # embedding_lookup. 
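    # (Added note: prev_symbol is an argmax index and is therefore
    # non-differentiable; the embedding matrix itself still receives
    # gradients here unless update_embedding is False, in which case
    # stop_gradient below blocks those as well.)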
76 | emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol) 77 | if not update_embedding: 78 | emb_prev = tf.stop_gradient(emb_prev) 79 | return emb_prev 80 | return loop_function 81 | 82 | 83 | def tile_embedding_attention(emb_inp, symbol_inp, initial_state, 84 | attention_states, beam_size, embedding_size): 85 | """Make beam_size copies of the attention states.""" 86 | tile_emb_inp = [] 87 | for inp in emb_inp: 88 | tile_emb = tf.tile(tf.reshape(inp, [1, -1]), 89 | tf.pack([beam_size, 1])) 90 | tile_emb = tf.reshape(tile_emb, [-1, embedding_size]) 91 | tile_emb_inp.append(tile_emb) 92 | 93 | tile_symbol_inp = [] 94 | for inp in symbol_inp: 95 | tile_sym = tf.tile(tf.reshape(inp, [1, 1]), 96 | tf.pack([beam_size, 1])) 97 | tile_emb = tf.reshape(tile_emb, [-1]) 98 | tile_symbol_inp.append(tile_emb) 99 | 100 | tile_initial_state = tf.tile(tf.reshape(initial_state, 101 | [1, -1]), tf.pack([beam_size, 1])) 102 | 103 | attn_length = attention_states.get_shape()[1].value 104 | attn_size = attention_states.get_shape()[2].value 105 | tile_attention_states = tf.tile(attention_states, 106 | tf.pack([beam_size, 1, 1])) 107 | tile_attention_states = tf.reshape(tile_attention_states, 108 | [-1, attn_length, attn_size]) 109 | 110 | return tile_emb_inp, tile_symbol_inp, tile_initial_state, tile_attention_states 111 | 112 | 113 | def attention(query, num_heads, y_w, v, hidden, hidden_features, attention_vec_size, 114 | attn_length, use_global_attention=False): 115 | """Puts attention masks on hidden using hidden_features and query. 116 | 117 | Args: 118 | query: vector, usually the current decoder state. 119 | 2D Tensor [batch_size x state_size]. 120 | num_heads: int. Currently always 1. 121 | v: attention model parameters. 122 | hidden: attention_states. 123 | hidden_features: same linear layer applied to all attention_states. 124 | attention_vec_size: attention embedding size. 125 | attn_length: number of inputs over which the attention spans. 126 | use_impatient_reader: make attention function dependent on previous 127 | attention vector. 128 | prev_ds: previous weighted averaged attention vector. 129 | 130 | Returns: 131 | atts: softmax over attention inputs. 132 | ds: attention-weighted averaged attention vector. 133 | """ 134 | at_logits = [] # result of attention logits 135 | at_probs = [] # result of attention probabilities 136 | ds = [] # results of attention reads will be stored here. 137 | if nest.is_sequence(query): # if the query is a tuple, flatten it. 138 | query_list = nest.flatten(query) 139 | for q in query_list: # check that ndims == 2 if specified. 140 | ndims = q.get_shape().ndims 141 | if ndims: 142 | assert ndims == 2 143 | query = tf.concat(1, query_list) 144 | for a in xrange(num_heads): 145 | with tf.variable_scope("Attention_%d" % a): 146 | y = tf.matmul(query, y_w[a][0]) + y_w[a][1] 147 | y = tf.reshape(y, [-1, 1, 1, attention_vec_size]) 148 | # Attention mask is a softmax of v^T * tanh(...). 149 | if use_global_attention: 150 | s = tf.reduce_sum(hidden_features[a] * y, [2, 3]) 151 | else: 152 | # Broadcast to add y (query vector) to all hidden_features. 153 | s = tf.reduce_sum( 154 | v[a] * tf.tanh(hidden_features[a] + y), [2, 3]) 155 | at_logits.append(s) 156 | att = tf.nn.softmax(s) 157 | at_probs.append(att) 158 | # Now calculate the attention-weighted vector d. 
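      # (Added note: att has shape [batch_size, attn_length]; broadcasting it
      # against hidden, the attention states, and summing over axes 1 and 2
      # gives d, the attention-weighted average of the encoder states for
      # this head.)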
159 | d = tf.reduce_sum( 160 | tf.reshape(att, [-1, attn_length, 1, 1]) * hidden, 161 | [1, 2]) 162 | ds.append(tf.reshape(d, [-1, attention_vec_size])) 163 | return at_logits, at_probs, ds 164 | 165 | 166 | def extend_outputs_to_labels(outputs, label_inputs, label_logits, 167 | label_vectors, feed_previous): 168 | """Include (predicted) input labels in encoder attention vectors.""" 169 | new_outputs = [] 170 | for i, cell_output in enumerate(outputs): 171 | input_label = label_inputs[i] 172 | if feed_previous: 173 | input_label = tf.argmax(label_logits[i], 1) 174 | label_emb = tf.nn.embedding_lookup(label_vectors, input_label) 175 | concat_emb = tf.concat(1, [cell_output, label_emb]) 176 | new_outputs.append(concat_emb) 177 | return new_outputs 178 | 179 | 180 | def gumbel_noise(batch_size, logit_size): 181 | """Computes Gumbel noise. 182 | 183 | When the output is added to a logit, taking the argmax will be 184 | approximately equivalent to sampling from the logit. 185 | """ 186 | size = tf.pack([batch_size, logit_size]) 187 | uniform_sample = tf.random_uniform(size, 0, 1, dtype=dtype, 188 | seed=None, name=None) 189 | noise = -tf.log(-tf.log(uniform_sample)) 190 | return noise 191 | 192 | 193 | def init_thin_stack(batch_size, max_num_concepts): 194 | """Initializes the thin stack. 195 | Returns: 196 | thin_stack: Tensor with the stack content. 197 | thin_stack_head_next: Index pointers to element after stack head. 198 | """ 199 | # Stack initialized to -1, points to initial state. 200 | thin_stack = -tf.ones(tf.pack([batch_size, max_num_concepts]), 201 | dtype=tf.int32) 202 | # Reshape to ensure dimension 1 is known. 203 | thin_stack = tf.reshape(thin_stack, [-1, max_num_concepts]) 204 | # Set to 0 at position 0. 205 | inds = tf.transpose(tf.to_int64(tf.pack( 206 | [tf.range(batch_size), tf.zeros(tf.pack([batch_size]), dtype=tf.int32)]))) 207 | delta = tf.SparseTensor(inds, tf.ones(tf.pack([batch_size]), dtype=tf.int32), 208 | tf.pack([tf.to_int64(batch_size), max_num_concepts])) 209 | new_thin_stack = thin_stack + tf.sparse_tensor_to_dense(delta) 210 | # Position 0 is for empty stack; position after head always >= 1. 211 | thin_stack_head_next = tf.ones(tf.pack([batch_size]), 212 | dtype=tf.int32) 213 | return new_thin_stack, thin_stack_head_next 214 | 215 | 216 | def write_thin_stack(thin_stack, stack_pointers, decoder_position, batch_size, 217 | max_num_concepts): 218 | """Writes to the thin stack at the given pointers the current decoder position.""" 219 | new_vals = tf.fill(tf.pack([batch_size]), decoder_position) 220 | return write_thin_stack_vals(thin_stack, stack_pointers, new_vals, batch_size, 221 | max_num_concepts) 222 | 223 | 224 | def write_thin_stack_vals(thin_stack, stack_pointers, new_vals, batch_size, 225 | max_num_concepts): 226 | """Writes to the thin stack at the given pointers the current decoder position.""" 227 | # SparseTensor requires type int64. 228 | stack_inds = tf.transpose(tf.to_int64(tf.pack( 229 | [tf.range(batch_size), stack_pointers]))) # nn_stack_pointers 230 | 231 | current_vals = tf.gather_nd(thin_stack, stack_inds) 232 | delta = tf.SparseTensor(stack_inds, new_vals - current_vals, 233 | tf.pack([tf.to_int64(batch_size), max_num_concepts])) 234 | new_thin_stack = thin_stack + tf.sparse_tensor_to_dense(delta) 235 | return new_thin_stack 236 | 237 | 238 | def pure_reduce_thin_stack(thin_stack_head_next, transition_state): 239 | """Applies reduce to the thin stack and its head if in reduce state.""" 240 | # Pop if current transition state is reduce. 
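  # (Added note: sparse_to_dense builds a length-NUM_TR_STATES vector that is
  # -1 at RE_STATE and 0 elsewhere; gathering it with each example's
  # transition state decrements the stack pointer only for examples that are
  # currently reducing.)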
241 | stack_head_updates = tf.sparse_to_dense(data_utils.RE_STATE, 242 | tf.pack([data_utils.NUM_TR_STATES]), -1) 243 | new_thin_stack_head_next = tf.add(thin_stack_head_next, 244 | tf.gather(stack_head_updates, transition_state)) 245 | return new_thin_stack_head_next 246 | 247 | 248 | def reduce_thin_stack(thin_stack, thin_stack_head_next, batch_size, 249 | max_num_concepts, decoder_position, transition_state): 250 | """Applies reduce to the thin stack and its head if in reduce state.""" 251 | # Pop if current transition state is reduce. 252 | stack_head_updates = tf.sparse_to_dense(data_utils.RE_STATE, 253 | tf.pack([data_utils.NUM_TR_STATES]), -1) 254 | new_thin_stack_head_next = tf.add(thin_stack_head_next, 255 | tf.gather(stack_head_updates, transition_state)) 256 | 257 | return new_thin_stack_head_next 258 | 259 | 260 | def update_buffer_head(buffer_head, predicted_attns, transition_state): 261 | updates = tf.sparse_to_dense(tf.pack([data_utils.GEN_STATE]), 262 | tf.pack([data_utils.NUM_TR_STATES]), 263 | True, default_value=False) 264 | is_gen_state = tf.gather(updates, transition_state) 265 | 266 | new_buffer_head = tf.select(is_gen_state, predicted_attns, buffer_head) 267 | return new_buffer_head 268 | 269 | 270 | def pure_shift_thin_stack(thin_stack_head_next, transition_state): 271 | """Applies shift to the thin stack and its head if in shift state.""" 272 | 273 | # Push if previous transition state is shift (or pointer shift). 274 | stack_head_updates = tf.sparse_to_dense(tf.pack( 275 | [data_utils.GEN_STATE]), 276 | tf.pack([data_utils.NUM_TR_STATES]), 1) 277 | new_thin_stack_head_next = tf.add(thin_stack_head_next, 278 | tf.gather(stack_head_updates, transition_state)) 279 | 280 | return new_thin_stack_head_next 281 | 282 | 283 | def shift_thin_stack(thin_stack, thin_stack_head_next, batch_size, 284 | max_num_concepts, decoder_position, 285 | prev_transition_state): 286 | """Applies shift to the thin stack and its head if in shift state.""" 287 | # Head points to item after stack top, so always update the stack entry. 288 | new_thin_stack = write_thin_stack(thin_stack, thin_stack_head_next, 289 | decoder_position, batch_size, max_num_concepts) 290 | 291 | # Push if previous transition state is shift (or pointer shift). 292 | stack_head_updates = tf.sparse_to_dense(tf.pack( 293 | [data_utils.GEN_STATE]), 294 | tf.pack([data_utils.NUM_TR_STATES]), 1) 295 | new_thin_stack_head_next = tf.add(thin_stack_head_next, 296 | tf.gather(stack_head_updates, prev_transition_state)) 297 | 298 | return new_thin_stack, new_thin_stack_head_next 299 | 300 | 301 | def update_reduce_thin_stack(thin_stack, thin_stack_head_next, batch_size, 302 | max_num_concepts, decoder_position, 303 | transition_state): 304 | """If in reduce state, replaces the stack top with current decoder_position.""" 305 | # Aim at head for reduce (update), head_next otherwise (no update). 306 | re_index_updates = tf.sparse_to_dense(data_utils.RE_STATE, 307 | tf.pack([data_utils.NUM_TR_STATES]), -1) 308 | re_stack_head = tf.add(thin_stack_head_next, 309 | tf.gather(re_index_updates, transition_state)) 310 | 311 | # Update the stack. 
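# write_thin_stack overwrites the targeted slot by adding a sparse delta
# (new value minus current value) to the dense stack tensor. For entries in
# the reduce state the slot is the current stack top, which is replaced with
# decoder_position; for all other entries the write lands on the unused slot
# just above the head, so their visible stack contents are unchanged.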
312 | new_thin_stack = write_thin_stack(thin_stack, re_stack_head, 313 | decoder_position, batch_size, max_num_concepts) 314 | return new_thin_stack 315 | 316 | 317 | def extract_stack_head_entries(thin_stack, thin_stack_head_next, batch_size): 318 | """Finds entries (indices) at stack head for every instance in batch.""" 319 | stack_head_inds = tf.sub(thin_stack_head_next, 320 | tf.ones(tf.pack([batch_size]), dtype=tf.int32)) 321 | 322 | # For every batch entry, get the thin stack head entry. 323 | stack_inds = tf.transpose(tf.pack( 324 | [tf.range(batch_size), stack_head_inds])) 325 | stack_heads = tf.gather_nd(thin_stack, stack_inds) 326 | return stack_heads 327 | 328 | 329 | def mask_decoder_restrictions(logit, logit_size, decoder_restrictions, 330 | transition_state): 331 | """Enforces decoder restrictions determined by the transition state.""" 332 | restrict_mask_list = [] 333 | with tf.device("/cpu:0"): # sparse-to-dense must be on CPU for now 334 | for restr in decoder_restrictions: 335 | restrict_mask_list.append(tf.sparse_to_dense(restr, 336 | tf.pack([logit_size]), np.inf, default_value=-np.inf)) 337 | mask = tf.gather(tf.pack(restrict_mask_list), transition_state) 338 | new_logit = tf.minimum(logit, mask) 339 | return new_logit 340 | 341 | def mask_decoder_reduce(logit, thin_stack_head_next, logit_size, batch_size): 342 | """Ensures that we can only reduce when the stack has at least 1 item. 343 | 344 | For each batch entry k: 345 | If thin_stack_head_next == 0, #alternatively, or 1. 346 | let logit[k][reduce_index] = -np.inf, 347 | else don't change. 348 | """ 349 | # Allow reduce only if at least 1 item on stack, i.e., pointer >= 2. 350 | update_vals = tf.pack([-np.inf, -np.inf, 0.0]) 351 | update_val = tf.gather(update_vals, 352 | tf.minimum(thin_stack_head_next, 353 | 2*tf.ones(tf.pack([batch_size]), dtype=tf.int32))) 354 | 355 | re_filled = tf.fill(tf.pack([batch_size]), 356 | tf.to_int64(data_utils.REDUCE_ID)) 357 | re_inds = tf.transpose(tf.pack( 358 | [tf.to_int64(tf.range(batch_size)), re_filled])) 359 | re_delta = tf.SparseTensor(re_inds, update_val, tf.to_int64( 360 | tf.pack([batch_size, logit_size]))) 361 | new_logit = logit + tf.sparse_tensor_to_dense(re_delta) 362 | return new_logit 363 | 364 | 365 | def mask_decoder_only_shift(logit, thin_stack_head_next, transition_state_map, 366 | logit_size, batch_size): 367 | """Ensures that if the stack is empty, has to GEN_STATE (shift transition) 368 | 369 | For each batch entry k: 370 | If thin_stack_head_next == 0, #alternatively, or 1. 371 | let logit[k][reduce_index] = -np.inf, 372 | else don't change. 373 | """ 374 | stack_is_empty_bool = tf.less_equal(thin_stack_head_next, 1) 375 | stack_is_empty = tf.select(stack_is_empty_bool, 376 | tf.ones(tf.pack([batch_size]), dtype=tf.int32), 377 | tf.zeros(tf.pack([batch_size]), dtype=tf.int32)) 378 | stack_is_empty = tf.reshape(stack_is_empty, [-1, 1]) 379 | 380 | # Sh and Re states are disallowed (but not root). 
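# The mask is built as an outer product: stack_is_empty is a [batch_size, 1]
# 0/1 column and state_is_disallowed (computed below) is a [1, logit_size]
# 0/1 row marking output symbols whose transition state is RE or ARC. Their
# matmul is a [batch_size, logit_size] indicator that is 1 exactly where an
# empty-stack entry meets a disallowed symbol; gathering from [0, -np.inf]
# with it and adding the result to the logits rules those predictions out.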
381 | state_is_disallowed_updates = tf.sparse_to_dense( 382 | tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE]), 383 | tf.pack([data_utils.NUM_TR_STATES]), 1) 384 | logit_states = tf.gather(transition_state_map, tf.range(logit_size)) 385 | state_is_disallowed = tf.gather(state_is_disallowed_updates, logit_states) 386 | state_is_disallowed = tf.reshape(state_is_disallowed, [1, -1]) 387 | 388 | index_delta = tf.matmul(stack_is_empty, state_is_disallowed) # 1 if disallowed 389 | values = tf.pack([0, -np.inf]) 390 | delta = tf.gather(values, index_delta) 391 | new_logit = logit + delta 392 | return new_logit 393 | 394 | def mask_decoder_only_reduce(logit, thin_stack_head_next, transition_state_map, 395 | max_stack_size, logit_size, batch_size): 396 | """Ensures that if the stack is empty, has to GEN_STATE (shift transition) 397 | 398 | For each batch entry k: 399 | If thin_stack_head_next == 0, #alternatively, or 1. 400 | let logit[k][reduce_index] = -np.inf, 401 | else don't change. 402 | """ 403 | # Allow reduce only if at least 1 item on stack, i.e., pointer >= 2. 404 | #stack_is_empty_updates = tf.pack([-np.inf, -np.inf, 0]) 405 | stack_is_full_bool = tf.greater_equal(thin_stack_head_next, max_stack_size - 1) 406 | stack_is_full = tf.select(stack_is_full_bool, 407 | tf.ones(tf.pack([batch_size]), dtype=tf.int32), 408 | tf.zeros(tf.pack([batch_size]), dtype=tf.int32)) 409 | stack_is_full = tf.reshape(stack_is_full, [-1, 1]) 410 | 411 | # Sh and Re states are allowed. 412 | state_is_disallowed_updates = tf.sparse_to_dense( 413 | tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE, data_utils.ROOT_STATE]), 414 | tf.pack([data_utils.NUM_TR_STATES]), 0, 1) 415 | logit_states = tf.gather(transition_state_map, tf.range(logit_size)) 416 | state_is_disallowed = tf.gather(state_is_disallowed_updates, logit_states) 417 | state_is_disallowed = tf.reshape(state_is_disallowed, [1, -1]) 418 | 419 | index_delta = tf.matmul(stack_is_full, state_is_disallowed) # 1 if disallowed 420 | values = tf.pack([0, -np.inf]) 421 | delta = tf.gather(values, index_delta) 422 | new_logit = logit + delta 423 | return new_logit 424 | 425 | 426 | def gather_nd_lstm_states(states_c, states_h, inds, batch_size, input_size, 427 | state_size): 428 | concat_states_c = tf.concat(1, states_c) 429 | concat_states_h = tf.concat(1, states_h) 430 | 431 | new_prev_state_c = gather_nd_states(concat_states_c, 432 | inds, batch_size, input_size, state_size) 433 | new_prev_state_h = gather_nd_states(concat_states_h, 434 | inds, batch_size, input_size, state_size) 435 | return tf.nn.rnn_cell.LSTMStateTuple(new_prev_state_c, new_prev_state_h) 436 | 437 | 438 | def gather_nd_states(inputs, inds, batch_size, input_size, state_size): 439 | """Gathers an embedding for each batch entry with index inds from inputs. 440 | 441 | Args: 442 | inputs: Tensor [batch_size, input_size, state_size]. 
443 | inds: Tensor [batch_size] 444 | 445 | Returns: 446 | output: Tensor [batch_size, embedding_size] 447 | """ 448 | sparse_inds = tf.transpose(tf.pack( 449 | [tf.range(batch_size), inds])) 450 | dense_inds = tf.sparse_to_dense(sparse_inds, 451 | tf.pack([batch_size, input_size]), 452 | tf.ones(tf.pack([batch_size]))) 453 | 454 | output_sum = tf.reduce_sum(tf.reshape(dense_inds, 455 | [-1, input_size, 1, 1]) * tf.reshape(inputs, 456 | [-1, input_size, 1, state_size]), [1, 2]) 457 | output = tf.reshape(output_sum, [-1, state_size]) 458 | return output 459 | 460 | 461 | def binary_select_state(state, updates, transition_state, batch_size): 462 | """Gathers state or zero for each batch entry.""" 463 | update_inds = tf.gather(updates, transition_state) 464 | sparse_diag = tf.transpose(tf.pack( 465 | [tf.range(batch_size), tf.range(batch_size)])) 466 | dense_inds = tf.sparse_to_dense(sparse_diag, 467 | tf.pack([batch_size, batch_size]), 468 | tf.to_float(update_inds)) 469 | new_state = tf.matmul(dense_inds, state) 470 | return new_state 471 | 472 | 473 | def hard_state_selection(attn_inds, hidden, batch_size, attn_length): 474 | batch_inds = tf.transpose(tf.pack( 475 | [tf.to_int64(tf.range(batch_size)), tf.to_int64(attn_inds)])) 476 | align_index = tf.to_float(tf.sparse_to_dense(batch_inds, 477 | tf.to_int64(tf.pack([batch_size, attn_length])), 1)) 478 | attns = tf.reduce_sum(hidden * 479 | tf.reshape(align_index, [-1, attn_length, 1, 1]), [1, 2]) 480 | return attns 481 | 482 | 483 | def gather_forced_att_logits(encoder_input_symbols, encoder_decoder_vocab_map, 484 | att_logit, batch_size, attn_length, 485 | target_vocab_size): 486 | """Gathers attention weights as logits for forced attention.""" 487 | flat_input_symbols = tf.reshape(encoder_input_symbols, [-1]) 488 | flat_label_symbols = tf.gather(encoder_decoder_vocab_map, 489 | flat_input_symbols) 490 | flat_att_logits = tf.reshape(att_logit, [-1]) 491 | 492 | flat_range = tf.to_int64(tf.range(tf.shape(flat_label_symbols)[0])) 493 | batch_inds = tf.floordiv(flat_range, attn_length) 494 | position_inds = tf.mod(flat_range, attn_length) 495 | attn_vocab_inds = tf.transpose(tf.pack( 496 | [batch_inds, position_inds, tf.to_int64(flat_label_symbols)])) 497 | 498 | # Exclude indexes of entries with flat_label_symbols[i] = -1. 499 | included_flat_indexes = tf.reshape(tf.where(tf.not_equal( 500 | flat_label_symbols, -1)), [-1]) 501 | included_attn_vocab_inds = tf.gather(attn_vocab_inds, 502 | included_flat_indexes) 503 | included_flat_att_logits = tf.gather(flat_att_logits, 504 | included_flat_indexes) 505 | 506 | sparse_shape = tf.to_int64(tf.pack( 507 | [batch_size, attn_length, target_vocab_size])) 508 | 509 | sparse_label_logits = tf.SparseTensor(included_attn_vocab_inds, 510 | included_flat_att_logits, sparse_shape) 511 | forced_att_logit_sum = tf.sparse_reduce_sum(sparse_label_logits, [1]) 512 | 513 | forced_att_logit = tf.reshape(forced_att_logit_sum, 514 | [-1, target_vocab_size]) 515 | 516 | return forced_att_logit 517 | 518 | 519 | def gather_prev_stack_state_index(pointer_vals, prev_index, transition_state, 520 | batch_size): 521 | """Gathers new previous state index.""" 522 | new_pointer_vals = tf.reshape(pointer_vals, [-1, 1]) 523 | 524 | # Helper tensors. 525 | prev_vals = tf.reshape(tf.fill( 526 | tf.pack([batch_size]), prev_index), [-1, 1]) 527 | trans_inds = tf.transpose(tf.pack( 528 | [tf.range(batch_size), transition_state])) 529 | 530 | # Gather new prev state for main tf.nn. Pointer vals if reduce, else prev. 
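# state_inds is a per-batch table with one candidate index per transition
# state: every column holds prev_index except the column for the reduce
# transition, which holds the pointer value, so gather_nd with
# (batch index, transition state) pairs selects the right previous state for
# each entry. The [prev_vals]*6 + [new_pointer_vals, prev_vals] layout relies
# on a fixed ordering of the NUM_TR_STATES states, with the reduce state in
# the second-to-last position.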
531 | # State inds dimension [batch_size, NUM_TR_STATES] 532 | state_inds = tf.concat(1, [prev_vals]*6 + [new_pointer_vals, prev_vals]) 533 | prev_state_index = tf.gather_nd(state_inds, trans_inds) 534 | return prev_state_index 535 | 536 | 537 | def gather_prev_stack_aux_state_index(pointer_vals, prev_index, transition_state, 538 | batch_size): 539 | """Gather new prev state index for aux rnn: as for main, but zero if shift.""" 540 | new_pointer_vals = tf.reshape(pointer_vals, [-1, 1]) 541 | 542 | # Helper tensors. 543 | prev_vals = tf.reshape(tf.fill( 544 | tf.pack([batch_size]), prev_index), [-1, 1]) 545 | trans_inds = tf.transpose(tf.pack( 546 | [tf.range(batch_size), transition_state])) 547 | batch_zeros = tf.reshape(tf.zeros( 548 | tf.pack([batch_size]), dtype=tf.int32), [-1, 1]) 549 | 550 | # Gather new prev state for aux tf.nn. 551 | # State inds dimension [batch_size, NUM_TR_STATES] 552 | state_inds = tf.concat(1, 553 | [prev_vals, batch_zeros] + [prev_vals]*4 + [new_pointer_vals, prev_vals]) 554 | prev_state_index = tf.gather_nd(state_inds, trans_inds) 555 | return prev_state_index 556 | 557 | 558 | -------------------------------------------------------------------------------- /rnn/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Modifications copyright 2017 Jan Buys. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Sequence-to-sequence model with an attention mechanism.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import random 24 | 25 | import numpy as np 26 | from six.moves import xrange # pylint: disable=redefined-builtin 27 | import tensorflow as tf 28 | 29 | import seq2seq 30 | import data_utils 31 | 32 | class Seq2SeqModel(object): 33 | """Sequence-to-sequence model with attention and for multiple buckets. 34 | 35 | This class implements a multi-layer recurrent neural network as encoder, 36 | and an attention-based decoder. This is the same as the model described in 37 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details, 38 | or into the seq2seq library for complete model implementation. 39 | This class also allows to use GRU cells in addition to LSTM cells, and 40 | sampled softmax to handle large output vocabulary size. A single-layer 41 | version of this model, but with bi-directional encoder, was presented in 42 | http://arxiv.org/abs/1409.0473 43 | and sampled softmax is described in Section 3 of the following paper. 
44 | http://arxiv.org/pdf/1412.2007v2.pdf 45 | """ 46 | 47 | def __init__(self, buckets, source_vocab_sizes, target_vocab_sizes, 48 | size, source_embedding_sizes, target_embedding_sizes, 49 | target_data_types, max_gradient_norm, batch_size, learning_rate, 50 | learning_rate_decay_factor, decoder_type, use_lstm=True, 51 | average_loss_across_timesteps=True, 52 | forward_only=False, feed_previous=False, 53 | predict_span_end_pointers=False, use_adam=False, 54 | restrict_decoder_structure=False, 55 | transition_vocab_sets=None, 56 | transition_state_map=None, encoder_decoder_vocab_map=None, 57 | use_bidirectional_encoder=False, 58 | pretrained_word_embeddings=None, word_embeddings=None, 59 | dtype=tf.float32): 60 | """Create the model. 61 | 62 | Args: 63 | source_vocab_size: size of the source vocabulary. 64 | target_vocab_size: size of the target vocabulary. 65 | buckets: a list of pairs (I, O), where I specifies maximum input length 66 | that will be processed in that bucket, and O specifies maximum output 67 | length. Training instances that have inputs longer than I or outputs 68 | longer than O will be pushed to the next bucket and padded accordingly. 69 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. 70 | size: number of units in each layer of the model. 71 | max_gradient_norm: gradients will be clipped to maximally this norm. 72 | batch_size: the size of the batches used during training; 73 | the model construction is independent of batch_size, so it can be 74 | changed after initialization if this is convenient, e.g., for decoding. 75 | learning_rate: learning rate to start with. 76 | learning_rate_decay_factor: decay learning rate by this much when needed. 77 | use_lstm: if true, we use LSTM cells instead of GRU cells. 78 | forward_only: if set, we do not construct the backward pass in the model. 79 | dtype: the data type to use to store internal variables. 
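decoder_type: which decoder variant to build (one of the decoder state
  constants defined in data_utils, e.g. the stack or memory-stack decoder).
source_embedding_sizes, target_embedding_sizes: dicts mapping each input,
  respectively output, data type to its embedding size.
target_data_types: list of decoder output types (e.g. "parse", "att") for
  which decoder inputs, targets and weights are created.
use_adam: if true, optimize with Adam instead of plain SGD.
feed_previous: if true, the decoder is fed its own previous predictions
  instead of the gold decoder inputs.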
80 | """ 81 | self.buckets = buckets 82 | self.batch_size = batch_size 83 | self.decoder_type = decoder_type 84 | 85 | self.transition_vocab_sets = transition_vocab_sets 86 | if transition_state_map is None: 87 | self.transition_state_map = None 88 | else: 89 | self.transition_state_map = tf.constant(transition_state_map) 90 | self.encoder_decoder_vocab_map = tf.constant(encoder_decoder_vocab_map) 91 | self.use_stack_decoder = decoder_type == data_utils.STACK_DECODER_STATE 92 | self.average_loss_across_timesteps = average_loss_across_timesteps 93 | self.input_keep_prob = tf.placeholder(tf.float32, 94 | name="input_keep_probability") 95 | self.output_keep_prob = tf.placeholder(tf.float32, 96 | name="output_keep_probability") 97 | 98 | if not use_adam: 99 | self.learning_rate = tf.Variable( 100 | float(learning_rate), trainable=False, dtype=dtype) 101 | self.learning_rate_decay_op = self.learning_rate.assign( 102 | self.learning_rate * learning_rate_decay_factor) 103 | self.global_step = tf.Variable(0, trainable=False) 104 | 105 | self.embedding_weights = {} 106 | for source_type in source_embedding_sizes.iterkeys(): 107 | self.embedding_weights[source_type] = tf.Variable( 108 | tf.constant(0.0, shape=[source_vocab_sizes[source_type], 109 | source_embedding_sizes[source_type]]), 110 | trainable=(source_type <> 'em'), 111 | name=source_type + "_encoder_embeddings") 112 | if source_type == 'en': 113 | assert word_embeddings is not None 114 | assert source_embedding_sizes['en'] == word_embeddings.shape[1] 115 | self.embedding_weights['en'].assign(word_embeddings) 116 | elif source_type == 'em': 117 | assert pretrained_word_embeddings is not None 118 | assert source_embedding_sizes['em'] == pretrained_word_embeddings.shape[1] 119 | self.embedding_weights['em'].assign(pretrained_word_embeddings) 120 | else: 121 | init_vectors = np.random.uniform(-np.sqrt(3), np.sqrt(3), 122 | (source_vocab_sizes[source_type], 123 | source_embedding_sizes[source_type])) 124 | self.embedding_weights[source_type].assign(init_vectors) 125 | 126 | output_projections = {} 127 | for target_type in target_vocab_sizes.iterkeys(): 128 | vocab_size = target_vocab_sizes[target_type] 129 | w = tf.get_variable(target_type + "_proj_w", [size, vocab_size], 130 | initializer=tf.uniform_unit_scaling_initializer(), dtype=dtype) 131 | w_t = tf.transpose(w) 132 | b = tf.get_variable(target_type + "_proj_b", [vocab_size], dtype=dtype) 133 | output_projections[target_type] = (w, b) 134 | 135 | def full_loss(logits, labels): 136 | labels = tf.reshape(labels, [-1]) 137 | return tf.nn.sparse_softmax_cross_entropy_with_logits( 138 | logits, labels) 139 | 140 | def full_output_loss(inputs, labels): 141 | logits = tf.nn.xw_plus_b(inputs, w, b) 142 | labels = tf.reshape(labels, [-1]) 143 | return tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels) 144 | 145 | softmax_loss_function = full_loss 146 | 147 | def create_cell(use_dropout=True): 148 | # Create the internal cell for our RNN. 
149 | if use_lstm: 150 | cell = tf.nn.rnn_cell.LSTMCell(size, use_peepholes=False, 151 | state_is_tuple=True, 152 | initializer=tf.uniform_unit_scaling_initializer()) 153 | else: 154 | cell = tf.nn.rnn_cell.GRUCell(size) 155 | if use_dropout: 156 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, 157 | self.input_keep_prob, self.output_keep_prob) 158 | return cell 159 | 160 | with tf.variable_scope("encoder_fw"): 161 | fw_cell = create_cell() 162 | with tf.variable_scope("encoder_bw"): 163 | bw_cell = create_cell() 164 | 165 | with tf.variable_scope("decoder_main"): 166 | dec_cell = create_cell() 167 | with tf.variable_scope("decoder_aux"): 168 | dec_aux_cell = create_cell(False) 169 | if self.decoder_type == data_utils.MEMORY_STACK_DECODER_STATE: 170 | with tf.variable_scope("decoder_lin_mem"): 171 | dec_mem_cell = create_cell() 172 | else: 173 | dec_mem_cell = None 174 | 175 | self.decoder_restrictions = [] 176 | num_decoder_restrictions = 0 177 | if restrict_decoder_structure: 178 | num_decoder_restrictions = data_utils.NUM_TR_STATES 179 | for i in xrange(num_decoder_restrictions): 180 | self.decoder_restrictions.append(tf.placeholder(tf.int32, shape=[None], 181 | name="restrictions{0}".format(i))) 182 | 183 | if self.transition_vocab_sets is None: 184 | self.decoder_transition_map = None 185 | else: 186 | self.decoder_transition_map = data_utils.construct_transition_map( 187 | self.transition_vocab_sets, False) 188 | 189 | # The seq2seq function: we use embedding for the input and attention. 190 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): 191 | return seq2seq.embedding_attention_seq2seq(self.decoder_type, 192 | encoder_inputs, decoder_inputs, 193 | fw_cell, bw_cell, dec_cell, dec_aux_cell, dec_mem_cell, 194 | source_vocab_sizes, target_vocab_sizes, source_embedding_sizes, 195 | target_embedding_sizes, 196 | predict_span_end_pointers=predict_span_end_pointers, 197 | decoder_restrictions=self.decoder_restrictions, 198 | output_projections=output_projections, 199 | word_vectors=self.embedding_weights, 200 | transition_state_map=self.transition_state_map, 201 | encoder_decoder_vocab_map=self.encoder_decoder_vocab_map, 202 | use_bidirectional_encoder=use_bidirectional_encoder, 203 | feed_previous=do_decode, dtype=dtype) 204 | 205 | # Feeds for inputs. 206 | self.encoder_inputs = [] 207 | self.decoder_inputs = [] 208 | self.target_weights = [] 209 | 210 | # For now assume that we only have embedding inputs, and single sequence 211 | # of target weights. 212 | 213 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one. 214 | self.encoder_inputs.append({}) 215 | for key in source_vocab_sizes.iterkeys(): 216 | self.encoder_inputs[-1][key] = tf.placeholder(tf.int32, shape=[None], 217 | name="encoder_{0}_{1}".format(key, i)) 218 | 219 | for i in xrange(buckets[-1][1] + 1): 220 | self.decoder_inputs.append({}) 221 | for key in target_data_types: 222 | self.decoder_inputs[-1][key] = tf.placeholder(tf.int32, shape=[None], 223 | name="decoder_{0}_{1}".format(key, i)) 224 | 225 | for i in xrange(buckets[-1][1] + 1): 226 | self.target_weights.append({}) 227 | for key in target_data_types: 228 | if key == "parse" or key == "predicate" or key == "ind": 229 | self.target_weights[-1][key] = tf.placeholder(dtype, shape=[None], 230 | name="weight_{0}_{1}".format(key, i)) 231 | 232 | # Our targets are decoder inputs shifted by one. 233 | targets = [self.decoder_inputs[i + 1] 234 | for i in xrange(len(self.decoder_inputs) - 1)] 235 | 236 | # Training outputs and losses. 
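# Following the standard TensorFlow bucketing pattern, model_with_buckets
# builds one encoder-decoder sub-graph per bucket, all sharing parameters;
# each sub-graph consumes a prefix of the placeholder lists matching its
# bucket's lengths, and the per-bucket loss is the softmax cross-entropy of
# the targets (decoder inputs shifted by one), optionally averaged across
# timesteps.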
237 | self.outputs, self.losses = seq2seq.model_with_buckets( 238 | self.encoder_inputs, self.decoder_inputs, targets, 239 | self.target_weights, buckets, 240 | lambda x, y: seq2seq_f(x, y, feed_previous), forward_only, 241 | softmax_loss_function=softmax_loss_function, 242 | average_across_timesteps=self.average_loss_across_timesteps) 243 | 244 | # Gradients and SGD update operation for training the model. 245 | params = tf.trainable_variables() 246 | if not forward_only: 247 | self.gradient_norms = [] 248 | self.updates = [] 249 | if use_adam: 250 | opt = tf.train.AdamOptimizer(learning_rate, epsilon=1e-02) 251 | else: 252 | opt = tf.train.GradientDescentOptimizer(self.learning_rate) 253 | for b in xrange(len(buckets)): 254 | gradients = tf.gradients(self.losses[b], params) 255 | if max_gradient_norm > 0: 256 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 257 | max_gradient_norm) 258 | self.gradient_norms.append(norm) 259 | self.updates.append(opt.apply_gradients( 260 | zip(clipped_gradients, params), global_step=self.global_step)) 261 | else: 262 | self.gradient_norms.append(tf.zeros([1])) 263 | self.updates.append(opt.apply_gradients( 264 | zip(gradients, params), global_step=self.global_step)) 265 | 266 | self.saver = tf.train.Saver(tf.all_variables()) 267 | 268 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 269 | bucket_id, forward_only, input_keep_prob=1.0, output_keep_prob=1.0, 270 | decoder_vocab=None): 271 | """Run a step of the model feeding the given inputs. 272 | 273 | Args: 274 | session: tensorflow session to use. 275 | encoder_inputs: list of int vectors to feed as encoder inputs. 276 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 277 | target_weights: list of numpy float vectors to feed as target weights. 278 | bucket_id: which bucket of the model to use. 279 | forward_only: whether to do the backward step or only forward. 280 | 281 | Returns: 282 | A triple consisting of gradient norm (or None if we did not do backward), 283 | average perplexity, and the outputs. 284 | 285 | Raises: 286 | ValueError: if length of enconder_inputs, decoder_inputs, or 287 | target_weights disagrees with bucket size for the specified bucket_id. 288 | """ 289 | # Check if the sizes match. 290 | encoder_size, decoder_size = self.buckets[bucket_id] 291 | if len(encoder_inputs["en"]) != encoder_size: 292 | raise ValueError("Encoder length must be equal to the one in bucket," 293 | " %d != %d." % (len(encoder_inputs["en"]), encoder_size)) 294 | if len(decoder_inputs["parse"]) != decoder_size: 295 | raise ValueError("Decoder length must be equal to the one in bucket," 296 | " %d != %d." % (len(decoder_inputs["parse"]), decoder_size)) 297 | if len(target_weights["parse"]) != decoder_size: 298 | raise ValueError("Weights length must be equal to the one in bucket," 299 | " %d != %d." % (len(target_weights["parse"]), decoder_size)) 300 | 301 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 
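# The feed dictionary is keyed by placeholder name: one entry per time step
# and per data type for the encoder inputs, the decoder inputs and the
# target weights, each holding a batch-sized vector as produced by
# get_batch().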
302 | input_feed = {} 303 | for l in xrange(encoder_size): 304 | for key in encoder_inputs.iterkeys(): 305 | input_feed[self.encoder_inputs[l][key].name] = encoder_inputs[key][l] 306 | 307 | for l in xrange(decoder_size): 308 | for key in decoder_inputs.iterkeys(): 309 | input_feed[self.decoder_inputs[l][key].name] = decoder_inputs[key][l] 310 | for key in target_weights.iterkeys(): 311 | input_feed[self.target_weights[l][key].name] = target_weights[key][l] 312 | 313 | # Since our targets are decoder inputs shifted by one, we need one more. 314 | for key in decoder_inputs.iterkeys(): 315 | last_target = self.decoder_inputs[decoder_size][key].name 316 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 317 | 318 | if self.decoder_restrictions: 319 | assert self.transition_vocab_sets is not None 320 | if len(self.decoder_restrictions) == 1: 321 | assert decoder_vocab is not None 322 | decoder_restrictions = [list(decoder_vocab.union(*self.transition_vocab_sets[1:]))] 323 | else: 324 | assert len(self.decoder_restrictions) == data_utils.NUM_TR_STATES 325 | decoder_restrictions = self.decoder_transition_map 326 | 327 | for l in xrange(len(decoder_restrictions)): 328 | input_feed[self.decoder_restrictions[l].name] = np.array( 329 | decoder_restrictions[l], dtype=int) 330 | 331 | # Add dropout probabilities to input feed. 332 | assert input_keep_prob >= 0.0 and input_keep_prob <= 1.0 333 | input_feed[self.input_keep_prob.name] = input_keep_prob 334 | assert output_keep_prob >= 0.0 and output_keep_prob <= 1.0 335 | input_feed[self.output_keep_prob.name] = output_keep_prob 336 | 337 | # Output feed: depends on whether we do a backward step or not. 338 | if not forward_only: 339 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 340 | self.gradient_norms[bucket_id], 341 | self.losses[bucket_id]] 342 | else: 343 | output_feed = [self.losses[bucket_id]] # Loss for this batch. 344 | for l in xrange(decoder_size): # Dicts of output logits. 345 | output_feed.append(self.outputs[bucket_id][l]) 346 | 347 | outputs = session.run(output_feed, input_feed) 348 | if not forward_only: 349 | # Gradient norm, loss, no outputs. 350 | return outputs[1], outputs[2], None 351 | else: 352 | # No gradient norm, loss, outputs. 353 | return None, outputs[0], outputs[1:decoder_size+1] 354 | 355 | def get_batch(self, data, data_types, bucket_id, batch_number, 356 | singleton_keep_prob=1.0, singleton_sets=None): 357 | """Get a random batch of data from the specified bucket, prepare for step. 358 | 359 | To feed data in step(..) it must be a list of batch-major vectors, while 360 | data here contains single length-major cases. So the main logic of this 361 | function is to re-index data cases to be in the proper format for feeding. 362 | 363 | Args: 364 | data: a tuple of size len(self.buckets) in which each element contains 365 | lists of pairs of input and output data that we use to create a batch. 366 | bucket_id: integer, which bucket to get the btch for. 367 | batch_number: integer, which batch in the bucket to get. 368 | 369 | Returns: 370 | The triple (encoder_inputs, decoder_inputs, target_weights) for 371 | the constructed batch that has the proper format to call step(...) later. 
372 | """ 373 | encoder_size, decoder_size = self.buckets[bucket_id] 374 | encoder_inputs, decoder_inputs = {}, {} 375 | 376 | for key in data_types[0]: 377 | encoder_inputs[key] = [] 378 | for key in data_types[1]: 379 | decoder_inputs[key] = [] 380 | 381 | # Get a random batch of encoder and decoder inputs from data, 382 | # pad them if needed, reverse encoder inputs and add GO to decoder. 383 | for batch_pos in xrange(self.batch_size): 384 | if batch_number == -1: 385 | encoder_input_data, decoder_input_data = random.choice(data[bucket_id]) 386 | else: 387 | # input_data is list (over types) of sequences 388 | encoder_input_data, decoder_input_data = \ 389 | data[bucket_id][min(self.batch_size*batch_number + batch_pos, 390 | len(data[bucket_id])-1)] 391 | 392 | for k, key in enumerate(data_types[0]): 393 | encoder_input = encoder_input_data[k] 394 | 395 | # Keep or replace all singletons per sentence. 396 | if (singleton_keep_prob > 0 and singleton_sets is not None 397 | and singleton_sets.has_key(key)): 398 | for i in xrange(len(encoder_input)): 399 | if encoder_input[i] in singleton_sets[key]: 400 | unk_singletons = (singleton_keep_prob < 1.0 401 | and random.random() > singleton_keep_prob) 402 | if unk_singletons: 403 | encoder_input[i] = data_utils.UNK_ID 404 | 405 | # Encoder inputs are padded. 406 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) 407 | encoder_inputs[key].append(list(encoder_input + encoder_pad)) 408 | 409 | for k, key in enumerate(data_types[1]): 410 | #decoder_input = [] # TODO for batch evaluation 411 | decoder_input = decoder_input_data[k] 412 | if key == "att" or key == "endatt": 413 | decoder_input = [min(inp, encoder_size - 1) for inp in decoder_input] 414 | 415 | # Decoder inputs get an extra "GO" symbol, and are padded then. 416 | decoder_pad_size = decoder_size - len(decoder_input) 417 | decoder_inputs[key].append([data_utils.GO_ID] + decoder_input + 418 | [data_utils.PAD_ID] * (decoder_pad_size - 1)) 419 | 420 | # Now we create batch-major vectors from the data selected above. 421 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = {}, {}, {} 422 | 423 | # Batch encoder inputs are just re-indexed encoder_inputs. 424 | for k, key in enumerate(data_types[0]): 425 | batch_encoder_inputs[key] = [] 426 | for length_idx in xrange(encoder_size): 427 | batch_encoder_inputs[key].append( 428 | np.array([encoder_inputs[key][batch_idx][length_idx] 429 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 430 | 431 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 432 | for k, key in enumerate(data_types[1]): 433 | batch_decoder_inputs[key] = [] 434 | if key == "parse" or key == "predicate" or key == "ind": 435 | batch_weights[key] = [] 436 | 437 | for length_idx in xrange(decoder_size): 438 | batch_input = np.array([decoder_inputs[key][batch_idx][length_idx] 439 | for batch_idx in xrange(self.batch_size)], dtype=np.int32) 440 | # Remove -1 indexes 441 | if key == "start" or key == "end" or key == "ind": 442 | batch_input = np.maximum(batch_input, np.zeros(self.batch_size, 443 | dtype=np.int32)) 444 | batch_decoder_inputs[key].append(batch_input) 445 | # target weights customized for certain keys. 446 | if key == "parse" or key == "predicate" or key == "ind": 447 | # Create target_weights to be 0 for targets that are padding. 
448 | batch_weight = np.ones(self.batch_size, dtype=np.float32) 449 | for batch_idx in xrange(self.batch_size): 450 | # We set weight to 0 if the corresponding target is a PAD symbol. 451 | # The corresponding target is decoder_input shifted by 1 forward. 452 | if length_idx < decoder_size - 1: 453 | target = decoder_inputs["parse"][batch_idx][length_idx + 1] 454 | if (length_idx == decoder_size - 1 or target == data_utils.PAD_ID 455 | or (key == "predicate" and target == data_utils.REDUCE_ID) 456 | or (key == "ind" and target <> data_utils.OPEN_ID 457 | and target <> data_utils.CLOSE_ID)): 458 | batch_weight[batch_idx] = 0.0 459 | 460 | batch_weights[key].append(batch_weight) 461 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights 462 | 463 | -------------------------------------------------------------------------------- /scripts/export-deepbank.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exports DeepBank 1.1 (WSJ section 00-21 with HPSG/MRS annotations). 4 | 5 | # Exports DeepBank. 6 | cd $ERG_DIR 7 | 8 | for dir in wsj00a wsj00b wsj00c wsj00d wsj01a wsj01b wsj01c wsj01d \ 9 | wsj02a wsj02b wsj02c wsj02d wsj03a wsj03b wsj03c \ 10 | wsj04a wsj04b wsj04c wsj04d wsj04e wsj05a wsj05b wsj05c wsj05d wsj05e \ 11 | wsj06a wsj06b wsj06c wsj06d wsj07a wsj07b wsj07c wsj07d wsj07e \ 12 | wsj08a wsj09a wsj09b wsj09c wsj09d wsj10a wsj10b wsj10c wsj10d \ 13 | wsj11a wsj11b wsj11c wsj11d wsj11e wsj12a wsj12b wsj12c wsj12d \ 14 | wsj13a wsj13b wsj13c wsj13d wsj13e wsj14a wsj14b wsj14c wsj14d wsj14e \ 15 | wsj15a wsj15b wsj15c wsj15d wsj15e \ 16 | wsj16a wsj16b wsj16c wsj16d wsj16e wsj16f wsj17a wsj17b wsj17c wsj17d \ 17 | wsj18a wsj18b wsj18c wsj18d wsj18e wsj19a wsj19b wsj19c wsj19d \ 18 | wsj20a wsj20b wsj20c wsj20d \ 19 | wsj21a wsj21b wsj21c wsj21d; do 20 | $LOGONROOT/redwoods --binary --erg --target export --export input,mrs,eds $dir >> export.log 21 | done 22 | 23 | -------------------------------------------------------------------------------- /scripts/extract-deepbank-sdp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exports the SDP version of DeepBank 1.1. 4 | 5 | # Downloads and extracts graphbank. 6 | #wget http://sdp.delph-in.net/osdp-12.tgz 7 | #tar -xvzf osdp-12.tgz # extracts to sdp/ 8 | DB_DIR="sdp/2015/eds" 9 | 10 | # Extracts DMRS and EDS (DMRS is first converted from MRS). 11 | # Sentences are extracted one by one from their individual files to form the train/dev/test split. 12 | 13 | for TYPE in dmrs eds; do 14 | MRS_DIR="deepbank-sdp-${TYPE}" 15 | mkdir -p $MRS_DIR 16 | EXTRACT_LINES="${HOME}/DeepDeepParser/mrs/extract_sdp_${TYPE}_lines.py" 17 | 18 | if [ $TYPE == "eds" ]; then 19 | ext="eds" 20 | else 21 | ext="mrs" 22 | fi 23 | 24 | for file in $DB_DIR/20*.${ext}.gz $DB_DIR/21*.${ext}.gz; do 25 | python $EXTRACT_LINES $MRS_DIR $file train 26 | done 27 | 28 | for file in $DB_DIR/220*.${ext}.gz; do 29 | python $EXTRACT_LINES $MRS_DIR $file dev 30 | done 31 | 32 | for file in $DB_DIR/221*.${ext}.gz; do 33 | python $EXTRACT_LINES $MRS_DIR $file test 34 | done 35 | done 36 | 37 | -------------------------------------------------------------------------------- /scripts/extract-deepbank.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exports DeepBank 1.1 (WSJ section 00-21 with HPSG/MRS annotations). 
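# Assumes ERG_DIR and LOGONROOT are set (the export step and DB_DIR below
# rely on them) and that this repository is checked out at
# $HOME/DeepDeepParser, matching the script paths used in the loops below.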
4 | bash $HOME/DeepDeepParser/scripts/export-deepbank.sh 5 | DB_DIR=$ERG_DIR 6 | 7 | # Extracts DMRS and EDS (DMRS is first converted from MRS). 8 | # Sentences are extracted one by one from their individual files to form the train/dev/test split. 9 | 10 | for TYPE in dmrs eds; do 11 | MRS_DIR="deepbank-${TYPE}" 12 | mkdir -p $MRS_DIR 13 | EXTRACT_LINES="${HOME}/DeepDeepParser/mrs/extract_${TYPE}_lines.py" 14 | 15 | for dir in wsj00a wsj00b wsj00c wsj00d wsj01a wsj01b wsj01c wsj01d \ 16 | wsj02a wsj02b wsj02c wsj02d wsj03a wsj03b wsj03c \ 17 | wsj04a wsj04b wsj04c wsj04d wsj04e wsj05a wsj05b wsj05c wsj05d wsj05e \ 18 | wsj06a wsj06b wsj06c wsj06d wsj07a wsj07b wsj07c wsj07d wsj07e \ 19 | wsj08a wsj09a wsj09b wsj09c wsj09d wsj10a wsj10b wsj10c wsj10d \ 20 | wsj11a wsj11b wsj11c wsj11d wsj11e wsj12a wsj12b wsj12c wsj12d \ 21 | wsj13a wsj13b wsj13c wsj13d wsj13e wsj14a wsj14b wsj14c wsj14d wsj14e \ 22 | wsj15a wsj15b wsj15c wsj15d wsj15e \ 23 | wsj16a wsj16b wsj16c wsj16d wsj16e wsj16f wsj17a wsj17b wsj17c wsj17d \ 24 | wsj18a wsj18b wsj18c wsj18d wsj18e wsj19a wsj19b wsj19c wsj19d; do 25 | for file in $DB_DIR/export/$dir/*; do 26 | python $EXTRACT_LINES $MRS_DIR $file train 27 | done 28 | done 29 | 30 | for dir in wsj20a wsj20b wsj20c wsj20d; do 31 | for file in $DB_DIR/export/$dir/*; do 32 | python $EXTRACT_LINES $MRS_DIR $file dev 33 | done 34 | done 35 | 36 | for dir in wsj21a wsj21b wsj21c wsj21d; do 37 | for file in $DB_DIR/export/$dir/*; do 38 | python $EXTRACT_LINES $MRS_DIR $file test 39 | done 40 | done 41 | 42 | done 43 | 44 | -------------------------------------------------------------------------------- /scripts/find_bucket_sizes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if __name__=='__main__': 4 | lines1 = [sent.split() for sent in open(sys.argv[1], 'r').read().split('\n')[:-1]] 5 | lines2 = [sent.split() for sent in open(sys.argv[2], 'r').read().split('\n')[:-1]] 6 | assert len(lines1) == len(lines2) 7 | max_length1 = 200 8 | max_length2 = 400 9 | lengths = [[0 for _ in xrange(max_length2)] for _ in xrange(max_length1)] 10 | extras = [] 11 | for line1, line2 in zip(lines1, lines2): 12 | if len(line1) < max_length1 and len(line2) < max_length2: 13 | lengths[len(line1)][len(line2)] += 1 14 | else: 15 | extras.append((len(line1), len(line2))) 16 | col_totals = [[0 for _ in xrange(max_length2+1)]] 17 | for i, leng in enumerate(lengths): 18 | col_totals.append([]) 19 | for j, l in enumerate(leng): 20 | col_totals[-1].append(col_totals[i][j]+l) 21 | 22 | acc_lengths = [[0 for _ in xrange(max_length2+1)]] 23 | for i, leng in enumerate(lengths): 24 | acc_lengths.append([0]) 25 | for j, l in enumerate(leng): 26 | acc_lengths[-1].append(acc_lengths[-1][-1] + col_totals[i][j]) 27 | 28 | thresholds = [0.4, 0.8, 0.98] 29 | brackets = [] 30 | # find smallest i+j greater than threshold 31 | for a in thresholds: 32 | b_i, b_j = max_length1, max_length2 33 | for i in xrange(max_length1+1): 34 | for j in xrange(max_length2+1): 35 | if acc_lengths[i][j]/(len(lines1)+0.0) > a and i + j < b_i + b_j: 36 | b_i, b_j = i, j 37 | brackets.append((b_i, b_j)) 38 | for pair in brackets: 39 | print str(pair[0]) + ' ' + str(pair[1]) 40 | 41 | -------------------------------------------------------------------------------- /scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for TYPE in dmrs eds; do 4 | 5 | MRS_DIR="deepbank-${TYPE}" 6 | MRS_WDIR=${MRS_DIR}-working 7 
| mkdir -p $MRS_WDIR 8 | 9 | # Construct lexicon. 10 | python $HOME/DeepDeepParser/mrs/extract_erg_lexicon.py $ERG_DIR $MRS_WDIR 11 | python $HOME/DeepDeepParser/mrs/extract_data_lexicon.py $MRS_DIR $MRS_WDIR 12 | 13 | # Runs Stanford NLP tools over input. 14 | 15 | printf "$MRS_DIR/train.raw\n$MRS_DIR/dev.raw\n$MRS_DIR/test.raw\n" > FILELIST 16 | $JAVA -cp "$STANFORD_NLP/*" -Xmx16g \ 17 | edu.stanford.nlp.pipeline.StanfordCoreNLP \ 18 | -annotators tokenize,ssplit,pos,lemma,ner \ 19 | -ssplit.eolonly \ 20 | -filelist FILELIST \ 21 | -outputFormat text -outputDirectory $MRS_WDIR \ 22 | -tokenize.options "normalizeCurrency=False,normalizeFractions=False"\ 23 | "normalizeParentheses=False,normalizeOtherBrackets=False,"\ 24 | "latexQuotes=False,unicodeQuotes=True,"\ 25 | "ptb3Ellipsis=False,unicodeEllipsis=True,"\ 26 | "escapeForwardSlashAsterisk=False" 27 | rm FILELIST 28 | 29 | # Processes Stanford NLP output. 30 | python $HOME/DeepDeepParser/mrs/stanford_to_linear.py $MRS_DIR $MRS_WDIR $MRS_WDIR 31 | 32 | # Converts MRS graphs to multiple linearizations. 33 | python $HOME/DeepDeepParser/mrs/read_mrs.py $MRS_DIR $MRS_WDIR $TYPE 34 | 35 | # Copies data for parser training. 36 | 37 | LIN_DIR=${TYPE}-parse-data-deepbank 38 | mkdir -p $LIN_DIR 39 | ORACLE=dmrs.ae.ao # Arc-eager parser, alignment-ordered oracle 40 | 41 | for SET in train dev test; do 42 | cp $MRS_WDIR/${SET}.en $MRS_WDIR/${SET}.pos $MRS_WDIR/${SET}.ne $LIN_DIR/ 43 | cp $MRS_WDIR/${SET}.${ORACLE}.nospan.unlex.lin $LIN_DIR/${SET}.parse 44 | cp $MRS_WDIR/${SET}.${ORACLE}.point.lin $LIN_DIR/${SET}.att 45 | cp $MRS_WDIR/${SET}.${ORACLE}.endpoint.lin $LIN_DIR/${SET}.endatt 46 | done 47 | 48 | python $HOME/DeepDeepParser/scripts/find_bucket_sizes.py $LIN_DIR/train.en $LIN_DIR/train.parse > $LIN_DIR/buckets 49 | 50 | done 51 | 52 | -------------------------------------------------------------------------------- /scripts/run-ace-deepbank.sh: -------------------------------------------------------------------------------- 1 | 2 | OUT_DIR=deepbank-ace 3 | mkdir -p $OUT_DIR 4 | MEM=16000 5 | 6 | for name in dev test; do 7 | ace -g $ERG_DIR/erg-1214-x86-64-0.9.24.dat deepbank-dmrs/$name.raw -1Tf --maxent=$ERG_DIR/wsj.mem --max-unpack-megabytes=$MEM --max-chart-megabytes=$MEM > ${OUT_DIR}/${name}.mrs 2> ${OUT_DIR}/${name}.log 8 | 9 | # Extracts DMRS and evaluates. 10 | for TYPE in dmrs eds; do 11 | MRS_WDIR=deepbank-${TYPE}-working 12 | python $HOME/DeepDeepParser/mrs/extract_ace_mrs.py $name $OUT_DIR -${TYPE} 13 | python $HOME/DeepDeepParser/mrs/eval_edm.py ${MRS_WDIR}/${name}.${TYPE}.edm ${OUT_DIR}/${name}.${TYPE}.edm 14 | python $HOME/smatch/smatch.py -f ${MRS_WDIR}/${name}.${TYPE}.amr ${OUT_DIR}/${name}.${TYPE}.amr 15 | done 16 | done 17 | 18 | --------------------------------------------------------------------------------