├── .gitignore ├── LICENSE ├── README.md ├── papers ├── 1904.05255.pdf ├── 1906.03088.pdf ├── 1906.04341.pdf ├── C16-1120.pdf ├── S10-1057.pdf └── intbert_acl19paper-3.pdf └── research ├── __init__.py ├── document_processor ├── Encoder.py ├── PrepareInputForSentenceEncoder.py ├── SentenceEncoder.py └── __init__.py ├── evaluation ├── conllsrlwriter.py ├── semeval2010_task8_format_checker.pl ├── semeval2010_task8_scorer.pl ├── semeval2010_writer.py └── srl-eval.pl ├── iohandler ├── SRLToDoc.py ├── SemEvalToDoc.py ├── __init__.py ├── part_whole_reader.py └── sampler.py ├── libnlp ├── Document.py ├── SemanticRelation.py ├── SemanticRole.py ├── Token.py └── __init__.py ├── main.py ├── models ├── BERTModels.py └── __init__.py ├── tester ├── __init__.py └── test_model.py └── trainer ├── __init__.py ├── semantic_relation_train_model.py ├── semantic_role_train_model.py └── train_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | .idea/ 3 | data/ 4 | __pycache__/ 5 | *.pyc 6 | *.log 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains code for training simple BERT models for semantic relation classification or semantic role labelling. 2 | 3 | 1. Semantic Relation Classification 4 | 5 | As introduced in the SemEval 2010 Task 8, this code aims to classify semantic relations between nominals in a sentence. It assumes that the arguments between which relation holds is already provided. The data for this can be downloaded from https://github.com/sahitya0000/Relation-Classification 6 | 7 | 2. Semantic Role Labelling 8 | 9 | As introduced in CoNLL 2005 Shared Task, this code aims to classify semantic roles with respect to a given verb in a sentence. The data for this task is available at https://www.cs.upc.edu/~srlconll/. Please note that you will need access to the LDC for downloading this dataset. 
-------------------------------------------------------------------------------- /papers/1904.05255.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/1904.05255.pdf -------------------------------------------------------------------------------- /papers/1906.03088.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/1906.03088.pdf -------------------------------------------------------------------------------- /papers/1906.04341.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/1906.04341.pdf -------------------------------------------------------------------------------- /papers/C16-1120.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/C16-1120.pdf -------------------------------------------------------------------------------- /papers/S10-1057.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/S10-1057.pdf -------------------------------------------------------------------------------- /papers/intbert_acl19paper-3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/papers/intbert_acl19paper-3.pdf -------------------------------------------------------------------------------- /research/__init__.py: 
import sys

import torch
from pytorch_transformers import *

from research.document_processor import PrepareInputForSentenceEncoder
from research.libnlp.Document import Document


class Encoder:
    """Wrap a pretrained transformer encoder/tokenizer pair and attach
    encoded inputs to Document objects.

    NOTE(review): this class is nearly identical to SentenceEncoder in
    research/document_processor/SentenceEncoder.py -- consider merging.
    """

    # (model class, tokenizer class, pretrained weight name) triples,
    # selected by integer index via `model_type`.
    MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
              (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
              (GPT2Model, GPT2Tokenizer, 'gpt2'),
              (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
              (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
              (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024')]

    def __init__(self, model_type, task_type, logger, max_len):
        self.model_type = model_type  # index into MODELS: which sentence encoder to use
        self.task_type = task_type  # 0 = semantic relations, 1 = semantic roles
        self.max_len = max_len
        model, tokenizer, pretrained_weights = self.validate_model(model_type, logger)
        self.model = model.from_pretrained(pretrained_weights)
        self.tokenizer = tokenizer.from_pretrained(pretrained_weights)
        self.logger = logger

    def validate_model(self, type, logger):
        """Return the (model, tokenizer, weights) triple for index `type`,
        or log an error and exit when the index is out of range.

        Bug fix: the original test was `type > len(self.MODELS)`, so the
        off-by-one index len(MODELS) raised an unlogged IndexError and
        negative indices silently wrapped around.
        """
        if not 0 <= type < len(self.MODELS):
            logger.error("Incorrect model-tokenizer-pretrained_weights combination")
            sys.exit()
        return self.MODELS[type]

    def get_embedding(self, d):
        """Convert Document `d` into padded model inputs (plus positional
        vectors, and labels for SRL) and attach them via linkTokenIDs."""
        if self.task_type == 0:  # semantic relation classification
            input_ids, position_vect1, position_vect2 = PrepareInputForSentenceEncoder.convert_to_input(
                d, self.model_type, self.task_type, self.tokenizer,
                self.max_len, add_positional_features=True)
            with torch.no_grad():
                d.linkTokenIDs([input_ids, position_vect1, position_vect2])
        if self.task_type == 1:  # semantic role labelling
            input_ids, position_vect, labels = PrepareInputForSentenceEncoder.convert_to_input(
                d, self.model_type, self.task_type, self.tokenizer,
                self.max_len, add_positional_features=True)
            with torch.no_grad():
                d.linkTokenIDs([input_ids, position_vect, labels])
        self.logger.info("Document " + str(d.doc_id) + " encoded ")
        return d

    def encode_text(self, document):
        """Encode a single Document, or every value of a {id: Document} dict."""
        if isinstance(document, Document):
            document = self.get_embedding(document)
        if isinstance(document, dict):
            for idx in document.keys():
                document[idx] = self.get_embedding(document[idx])
        return document
def extend_list(l, max_len):
    """Return `l` right-padded with zeros to length `max_len`.

    Bug fix: the original aliased the argument (`m = l`) and extended it in
    place, mutating the caller's list; a padded copy is returned instead.
    If len(l) >= max_len the list is returned unchanged (no truncation),
    matching the original behavior.
    """
    return l + [0] * (max_len - len(l))


def extend_labels(l, max_len):
    """Return label list `l` right-padded with the 'O' (outside) tag to
    length `max_len`.

    Bug fix: same in-place aliasing as extend_list; now returns a copy.
    """
    return l + ['O'] * (max_len - len(l))


def add_positional_features_to_text(ttext, token):
    """Return one distance per position of tokenized text `ttext` relative
    to the span `token`: 0 inside the span, otherwise the distance to the
    nearest span boundary.

    NOTE(review): list.index finds the *first* occurrence of the span's
    first/last subword, so repeated subwords can mis-anchor the span --
    TODO confirm callers guarantee unique anchors.
    """
    position_vector = list()
    start_position = ttext.index(token[0])
    end_position = ttext.index(token[-1])
    for i, t in enumerate(ttext):
        if i < start_position:
            position_vector.append(start_position - i)
        if start_position <= i <= end_position:
            position_vector.append(0)
        if i > end_position:
            position_vector.append(i - end_position)
    return position_vector


def convert_to_input(document, type, task_type, tokenizer, max_len, add_positional_features=False):
    """Build [input_ids, segment_ids, attention_masks] for `document`,
    padded to `max_len`, plus positional vectors (and labels for SRL).

    type: encoder index -- only 0 (BERT) is implemented; others return None.
    task_type: 0 = semantic relation classification, 1 = semantic role labelling.
    """
    text = document.text
    if type == 0:  # for BERT
        if task_type == 0:  # semantic relations: "[CLS] text [SEP] arg1 [SEP] arg2"
            position_vect1 = None
            position_vect2 = None
            token1 = document.sr.token1
            token2 = document.sr.token2
            ttext = "[CLS] " + text + " [SEP] " + token1 + " [SEP] " + token2
            tokenized_text = tokenizer.tokenize(ttext)
            input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
            if add_positional_features:
                position_vect1 = add_positional_features_to_text(tokenized_text, tokenizer.tokenize(token1))
                position_vect2 = add_positional_features_to_text(tokenized_text, tokenizer.tokenize(token2))
                position_vect1 = extend_list(position_vect1, max_len)
                position_vect2 = extend_list(position_vect2, max_len)
            # Segment 0 covers the sentence (+ [CLS]); segment 1 covers both
            # argument spans (+ their [SEP] markers).
            input_segments = [0] * (len(tokenizer.tokenize(text)) + 1) \
                + [1] * (len(tokenizer.tokenize(token1)) + 1) \
                + [1] * (len(tokenizer.tokenize(token2)) + 1)
            input_masks = [1] * (len(tokenized_text))
            input_ids = extend_list(input_ids, max_len)
            input_segments = extend_list(input_segments, max_len)
            input_masks = extend_list(input_masks, max_len)
            return [input_ids, input_segments, input_masks], position_vect1, position_vect2

        if task_type == 1:  # semantic roles: "[CLS] text [SEP] verb"
            # Imported lazily: only the SRL path needs it, which keeps the
            # padding helpers above usable without the iohandler package.
            from research.iohandler.SRLToDoc import realign_data
            position_vect = None
            tokens = document.sr.tokens
            labels = document.sr.labels
            tokenized_tokens = tokenizer.tokenize(text)
            # Re-align word-level labels onto the wordpiece tokenization.
            labels = realign_data(tokens, labels, tokenized_tokens)
            ttext = "[CLS] " + text + " [SEP] " + document.sr.get_verb()
            tokenized_text = tokenizer.tokenize(ttext)
            input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
            if add_positional_features:
                position_vect = add_positional_features_to_text(tokenized_text, tokenizer.tokenize(document.sr.get_verb()))
                position_vect = extend_list(position_vect, max_len)
            input_segments = [0] * (len(tokenizer.tokenize(text)) + 1) \
                + [1] * (len(tokenizer.tokenize(document.sr.get_verb())) + 1)
            input_masks = [1] * (len(tokenized_text))
            input_ids = extend_list(input_ids, max_len)
            input_segments = extend_list(input_segments, max_len)
            input_masks = extend_list(input_masks, max_len)
            labels = extend_labels(labels, max_len)
            return [input_ids, input_segments, input_masks], position_vect, labels

    return None
import sys

import torch
from pytorch_transformers import *

from research.document_processor import PrepareInputForSentenceEncoder
from research.libnlp.Document import Document


class SentenceEncoder:
    """Wrap a pretrained transformer encoder/tokenizer pair and attach
    encoded inputs to Document objects.

    NOTE(review): near-duplicate of research/document_processor/Encoder.py
    -- consider merging the two classes.
    """

    # (model class, tokenizer class, pretrained weight name) triples,
    # selected by integer index via `model_type`.
    MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
              (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
              (GPT2Model, GPT2Tokenizer, 'gpt2'),
              (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
              (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
              (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024')]

    def __init__(self, model_type, task_type, logger, max_len):
        self.model_type = model_type  # index into MODELS: which sentence encoder to use
        self.task_type = task_type  # which classification task is being performed
        self.max_len = max_len
        model, tokenizer, pretrained_weights = self.validate_model(model_type, logger)
        self.model = model.from_pretrained(pretrained_weights)
        self.tokenizer = tokenizer.from_pretrained(pretrained_weights)
        self.logger = logger

    def validate_model(self, type, logger):
        """Return the (model, tokenizer, weights) triple for index `type`,
        or log an error and exit when the index is out of range.

        Bug fix: the original `type > len(self.MODELS)` let the off-by-one
        index len(MODELS) raise IndexError before the error was logged.
        """
        if not 0 <= type < len(self.MODELS):
            logger.error("Incorrect model-tokenizer-pretrained_weights combination")
            sys.exit()
        return self.MODELS[type]

    def get_embedding(self, d, take_hypernyms):
        """Convert Document `d` into padded model inputs and attach them
        via linkTokenIDs.

        Bug fix: the original forwarded `take_hypernyms` positionally into
        convert_to_input's `add_positional_features` slot while *also*
        passing add_positional_features=True, raising
        "TypeError: got multiple values for argument". convert_to_input has
        no hypernym parameter, so `take_hypernyms` is currently unused and
        kept only for interface compatibility.
        """
        input_ids, position_vect1, position_vect2 = PrepareInputForSentenceEncoder.convert_to_input(
            d, self.model_type, self.task_type, self.tokenizer,
            self.max_len, add_positional_features=True)
        with torch.no_grad():
            d.linkTokenIDs([input_ids, position_vect1, position_vect2])
        self.logger.info("Document " + str(d.doc_id) + " encoded ")
        return d

    def encode_text(self, document, take_hypernyms=True):
        """Encode a single Document, or every value of a {id: Document} dict."""
        if isinstance(document, Document):
            document = self.get_embedding(document, take_hypernyms)
        if isinstance(document, dict):
            for idx in document.keys():
                document[idx] = self.get_embedding(document[idx], take_hypernyms)
        return document
def get_prediction(relations):
    """Placeholder prediction extractor: currently echoes `relations` as-is."""
    return relations


def file_writer(document, relations, doc_ids, file, inverse_sr_dict, class_type="true"):
    """Walk the documents selected by `doc_ids` and gather their tokens and
    labels (gold labels when class_type == "true", otherwise predictions).

    NOTE(review): still a stub -- the gathered tokens/labels are not yet
    written anywhere; `file` and `inverse_sr_dict` are currently unused.
    """
    for doc_id in doc_ids:
        current = document[doc_id]
        tokens = current.sr.tokens
        if class_type == "true":
            labels = current.sr.labels
        else:
            labels = get_prediction(relations)


def conll_writer(tokens, labels):
    """Stub: intended to emit CoNLL-formatted SRL output for one sentence."""
    pass
#!/usr/bin/perl -w
#
#
#  Author: Preslav Nakov
#          nakov@comp.nus.edu.sg
#          National University of Singapore
#
#  WHAT: This is an official output file format checker for SemEval-2010 Task #8.
#
#  Use:
#     semeval2010_task8_format_checker.pl <PROPOSED_ANSWERS_FILE_NAME>
#
#  Examples:
#     semeval2010_task8_format_checker.pl proposed_answer1.txt
#     semeval2010_task8_format_checker.pl proposed_answer2.txt
#     semeval2010_task8_format_checker.pl proposed_answer3.txt
#     semeval2010_task8_format_checker.pl proposed_answer4.txt
#
#   In the examples above, the first three files are OK, while the last one contains four errors.
#   And answer_key2.txt contains the true labels for the *training* dataset.
#
#  Description:
#     The scorer takes as input a proposed classification file,
#     which should contain one prediction per line in the format "<SENT_ID>	<RELATION>"
#     with a TAB as a separator, e.g.,
#           1	Component-Whole(e2,e1)
#           2	Other
#           3	Instrument-Agency(e2,e1)
#               ...
#     The file does not have to be sorted in any way.
#     Repetitions of IDs are not allowed.
#
#     In case of problems, the checker outputs the problematic line and its number.
#     Finally, the total number of problems found is reported
#     or a message is output saying that the file format is OK.
#
#     Participants are expected to check their output using this checker before submission.
#
#  Last modified: March 10, 2010
#
#

use strict;

###############
###   I/O   ###
###############

if ($#ARGV != 0) {
	die "Usage:\nsemeval2010_task8_format_checker.pl <PROPOSED_ANSWERS_FILE_NAME>\n";
}

my $INPUT_FILE_NAME = $ARGV[0];

################
###   MAIN   ###
################
my %ids = ();

my $errCnt = 0;
open(INPUT, $INPUT_FILE_NAME) or die "Failed to open $INPUT_FILE_NAME for text reading.\n";
# Bug fix (extraction damage): the readline "<INPUT>" had been stripped from
# the loop condition, leaving an infinite loop that never reads a line.
for (my $lineNo = 1; <INPUT>; $lineNo++) {
	my ($id, $label) = &getIDandLabel($_);
	if ($id < 0) {
		s/[\n\r]*$//;
		print "Bad file format on line $lineNo: '$_'\n";
		$errCnt++;
	}
	elsif (defined $ids{$id}) {
		s/[\n\r]*$//;
		print "Bad file format on line $lineNo (ID $id is already defined): '$_'\n";
		$errCnt++;
	}
	$ids{$id}++;
}
close(INPUT) or die "Failed to close $INPUT_FILE_NAME.\n";

if (0 == $errCnt) {
	print "\n<<< The file format is OK.\n";
}
else {
	print "\n<<< The format is INCORRECT: $errCnt problematic line(s) found!\n";
}


################
###   SUBS   ###
################

# Parse one line "<ID>\t<LABEL>"; return (id, label) on success,
# (-1, ()) on any malformed line or unknown relation label.
sub getIDandLabel() {
	my $line = shift;

	return (-1,()) if ($line !~ /^([0-9]+)\t([^\r]+)\r?\n$/);
	my ($id, $label) = ($1, $2);

	# 'Other' is renamed so it sorts after the directed relation labels.
	return ($id, '_Other') if ($label eq 'Other');

	return ($id, $label)
		if (($label eq 'Cause-Effect(e1,e2)')       || ($label eq 'Cause-Effect(e2,e1)')       ||
		    ($label eq 'Component-Whole(e1,e2)')    || ($label eq 'Component-Whole(e2,e1)')    ||
		    ($label eq 'Content-Container(e1,e2)')  || ($label eq 'Content-Container(e2,e1)')  ||
		    ($label eq 'Entity-Destination(e1,e2)') || ($label eq 'Entity-Destination(e2,e1)') ||
		    ($label eq 'Entity-Origin(e1,e2)')      || ($label eq 'Entity-Origin(e2,e1)')      ||
		    ($label eq 'Instrument-Agency(e1,e2)')  || ($label eq 'Instrument-Agency(e2,e1)')  ||
		    ($label eq 'Member-Collection(e1,e2)')  || ($label eq 'Member-Collection(e2,e1)')  ||
		    ($label eq 'Message-Topic(e1,e2)')      || ($label eq 'Message-Topic(e2,e1)')      ||
		    ($label eq 'Product-Producer(e1,e2)')   || ($label eq 'Product-Producer(e2,e1)'));

	return (-1, ());
}
39 | # 40 | # The scorer calculates and outputs the following statistics: 41 | # (1) confusion matrix, which shows 42 | # - the sums for each row/column: -SUM- 43 | # - the number of skipped examples: skip 44 | # - the number of examples with correct relation, but wrong directionality: xDIRx 45 | # - the number of examples in the answer key file: ACTUAL ( = -SUM- + skip + xDIRx ) 46 | # (2) accuracy and coverage 47 | # (3) precision (P), recall (R), and F1-score for each relation 48 | # (4) micro-averaged P, R, F1, where the calculations ignore the Other category. 49 | # (5) macro-averaged P, R, F1, where the calculations ignore the Other category. 50 | # 51 | # Note that in scores (4) and (5), skipped examples are equivalent to those classified as Other. 52 | # So are examples classified as relations that do not exist in the key file (which is probably not optimal). 53 | # 54 | # The scoring is done three times: 55 | # (i) as a (2*9+1)-way classification 56 | # (ii) as a (9+1)-way classification, with directionality ignored 57 | # (iii) as a (9+1)-way classification, with directionality taken into account. 58 | # 59 | # The official score is the macro-averaged F1-score for (iii). 60 | # 61 | 62 | use strict; 63 | 64 | 65 | ############### 66 | ### I/O ### 67 | ############### 68 | 69 | if ($#ARGV != 1) { 70 | die "Usage:\nsemeval2010_task8_scorer.pl \n"; 71 | } 72 | 73 | my $PROPOSED_ANSWERS_FILE_NAME = $ARGV[0]; 74 | my $ANSWER_KEYS_FILE_NAME = $ARGV[1]; 75 | 76 | 77 | ################ 78 | ### MAIN ### 79 | ################ 80 | 81 | my (%confMatrix19way, %confMatrix10wayNoDir, %confMatrix10wayWithDir) = (); 82 | my (%idsProposed, %idsAnswer) = (); 83 | my (%allLabels19waylAnswer, %allLabels10wayAnswer) = (); 84 | my (%allLabels19wayProposed, %allLabels10wayNoDirProposed, %allLabels10wayWithDirProposed) = (); 85 | 86 | ### 1. 
Read the file contents 87 | my $totalProposed = &readFileIntoHash($PROPOSED_ANSWERS_FILE_NAME, \%idsProposed); 88 | my $totalAnswer = &readFileIntoHash($ANSWER_KEYS_FILE_NAME, \%idsAnswer); 89 | 90 | ### 2. Calculate the confusion matrices 91 | foreach my $id (keys %idsProposed) { 92 | 93 | ### 2.1. Unexpected IDs are not allowed 94 | die "File $PROPOSED_ANSWERS_FILE_NAME contains a bad ID: '$id'" 95 | if (!defined($idsAnswer{$id})); 96 | 97 | ### 2.2. Update the 19-way confusion matrix 98 | my $labelProposed = $idsProposed{$id}; 99 | my $labelAnswer = $idsAnswer{$id}; 100 | $confMatrix19way{$labelProposed}{$labelAnswer}++; 101 | $allLabels19wayProposed{$labelProposed}++; 102 | 103 | ### 2.3. Update the 10-way confusion matrix *without* direction 104 | my $labelProposedNoDir = $labelProposed; 105 | my $labelAnswerNoDir = $labelAnswer; 106 | $labelProposedNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//; 107 | $labelAnswerNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//; 108 | $confMatrix10wayNoDir{$labelProposedNoDir}{$labelAnswerNoDir}++; 109 | $allLabels10wayNoDirProposed{$labelProposedNoDir}++; 110 | 111 | ### 2.4. Update the 10-way confusion matrix *with* direction 112 | if ($labelProposed eq $labelAnswer) { ## both relation and direction match 113 | $confMatrix10wayWithDir{$labelProposedNoDir}{$labelAnswerNoDir}++; 114 | $allLabels10wayWithDirProposed{$labelProposedNoDir}++; 115 | } 116 | elsif ($labelProposedNoDir eq $labelAnswerNoDir) { ## the relations match, but the direction is wrong 117 | $confMatrix10wayWithDir{'WRONG_DIR'}{$labelAnswerNoDir}++; 118 | $allLabels10wayWithDirProposed{'WRONG_DIR'}++; 119 | } 120 | else { ### Wrong relation 121 | $confMatrix10wayWithDir{$labelProposedNoDir}{$labelAnswerNoDir}++; 122 | $allLabels10wayWithDirProposed{$labelProposedNoDir}++; 123 | } 124 | } 125 | 126 | ### 3. Calculate the ground truth distributions 127 | foreach my $id (keys %idsAnswer) { 128 | 129 | ### 3.1. 
Update the 19-way answer distribution 130 | my $labelAnswer = $idsAnswer{$id}; 131 | $allLabels19waylAnswer{$labelAnswer}++; 132 | 133 | ### 3.2. Update the 10-way answer distribution 134 | my $labelAnswerNoDir = $labelAnswer; 135 | $labelAnswerNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//; 136 | $allLabels10wayAnswer{$labelAnswerNoDir}++; 137 | } 138 | 139 | ### 4. Check for proposed classes that are not contained in the answer key file: this may happen in cross-validation 140 | foreach my $labelProposed (sort keys %allLabels19wayProposed) { 141 | if (!defined($allLabels19waylAnswer{$labelProposed})) { 142 | print "!!!WARNING!!! The proposed file contains $allLabels19wayProposed{$labelProposed} label(s) of type '$labelProposed', which is NOT present in the key file.\n\n"; 143 | } 144 | } 145 | 146 | ### 4. 19-way evaluation with directionality 147 | print "<<< (2*9+1)-WAY EVALUATION (USING DIRECTIONALITY)>>>:\n\n"; 148 | &evaluate(\%confMatrix19way, \%allLabels19wayProposed, \%allLabels19waylAnswer, $totalProposed, $totalAnswer, 0); 149 | 150 | ### 5. Evaluate without directionality 151 | print "<<< (9+1)-WAY EVALUATION IGNORING DIRECTIONALITY >>>:\n\n"; 152 | &evaluate(\%confMatrix10wayNoDir, \%allLabels10wayNoDirProposed, \%allLabels10wayAnswer, $totalProposed, $totalAnswer, 0); 153 | 154 | ### 6. Evaluate without directionality 155 | print "<<< (9+1)-WAY EVALUATION TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL >>>:\n\n"; 156 | my $officialScore = &evaluate(\%confMatrix10wayWithDir, \%allLabels10wayWithDirProposed, \%allLabels10wayAnswer, $totalProposed, $totalAnswer, 1); 157 | 158 | ### 7. 
# Parse one prediction line of the form "<id>\t<label>".
# Returns ($id, $label) on success, ($id, '_Other') for the Other class
# (renamed so it sorts last in the confusion matrix), and (-1, ()) for a
# malformed line or an unknown relation label.
sub getIDandLabel() {
    my $line = shift;
    return (-1, ()) if ($line !~ /^([0-9]+)\t([^\r]+)\r?\n$/);

    my ($id, $label) = ($1, $2);

    return ($id, '_Other') if ($label eq 'Other');

    # Build the set of valid directed labels from the 9 relation names,
    # rather than maintaining an 18-branch eq-chain by hand.
    my %validLabel =
        map { ("$_(e1,e2)" => 1, "$_(e2,e1)" => 1) }
        ('Cause-Effect',       'Component-Whole',  'Content-Container',
         'Entity-Destination', 'Entity-Origin',    'Instrument-Agency',
         'Member-Collection',  'Message-Topic',    'Product-Producer');

    return ($id, $label) if (defined $validLabel{$label});

    return (-1, ());
}
$allLabelsAnswer, $totalProposed, $totalAnswer, $useWrongDir) = @_; 210 | 211 | ### 0. Create a merged list for the confusion matrix 212 | my @allLabels = (); 213 | &mergeLabelLists($allLabelsAnswer, $allLabelsProposed, \@allLabels); 214 | 215 | ### 1. Print the confusion matrix heading 216 | print "Confusion matrix:\n"; 217 | print " "; 218 | foreach my $label (@allLabels) { 219 | printf " %4s", &getShortRelName($label, $allLabelsAnswer); 220 | } 221 | print " <-- classified as\n"; 222 | print " +"; 223 | foreach my $label (@allLabels) { 224 | print "-----"; 225 | } 226 | if ($useWrongDir) { 227 | print "+ -SUM- xDIRx skip ACTUAL\n"; 228 | } 229 | else { 230 | print "+ -SUM- skip ACTUAL\n"; 231 | } 232 | 233 | ### 2. Print the rest of the confusion matrix 234 | my $freqCorrect = 0; 235 | my $ind = 1; 236 | my $otherSkipped = 0; 237 | foreach my $labelAnswer (sort keys %{$allLabelsAnswer}) { 238 | 239 | ### 2.1. Output the short relation label 240 | printf " %4s |", &getShortRelName($labelAnswer, $allLabelsAnswer); 241 | 242 | ### 2.2. Output a row of the confusion matrix 243 | my $sumProposed = 0; 244 | foreach my $labelProposed (@allLabels) { 245 | $$confMatrix{$labelProposed}{$labelAnswer} = 0 246 | if (!defined($$confMatrix{$labelProposed}{$labelAnswer})); 247 | printf "%4d ", $$confMatrix{$labelProposed}{$labelAnswer}; 248 | $sumProposed += $$confMatrix{$labelProposed}{$labelAnswer}; 249 | } 250 | 251 | ### 2.3. Output the horizontal sums 252 | if ($useWrongDir) { 253 | my $ans = defined($$allLabelsAnswer{$labelAnswer}) ? 
$$allLabelsAnswer{$labelAnswer} : 0; 254 | $$confMatrix{'WRONG_DIR'}{$labelAnswer} = 0 if (!defined $$confMatrix{'WRONG_DIR'}{$labelAnswer}); 255 | printf "| %4d %4d %4d %6d\n", $sumProposed, $$confMatrix{'WRONG_DIR'}{$labelAnswer}, $ans - $sumProposed - $$confMatrix{'WRONG_DIR'}{$labelAnswer}, $ans; 256 | if ($labelAnswer eq '_Other') { 257 | $otherSkipped = $ans - $sumProposed - $$confMatrix{'WRONG_DIR'}{$labelAnswer}; 258 | } 259 | } 260 | else { 261 | my $ans = defined($$allLabelsAnswer{$labelAnswer}) ? $$allLabelsAnswer{$labelAnswer} : 0; 262 | printf "| %4d %4d %4d\n", $sumProposed, $ans - $sumProposed, $ans; 263 | if ($labelAnswer eq '_Other') { 264 | $otherSkipped = $ans - $sumProposed; 265 | } 266 | } 267 | 268 | $ind++; 269 | 270 | $$confMatrix{$labelAnswer}{$labelAnswer} = 0 271 | if (!defined($$confMatrix{$labelAnswer}{$labelAnswer})); 272 | $freqCorrect += $$confMatrix{$labelAnswer}{$labelAnswer}; 273 | } 274 | print " +"; 275 | foreach (@allLabels) { 276 | print "-----"; 277 | } 278 | print "+\n"; 279 | 280 | ### 3. Print the vertical sums 281 | print " -SUM- "; 282 | foreach my $labelProposed (@allLabels) { 283 | $$allLabelsProposed{$labelProposed} = 0 284 | if (!defined $$allLabelsProposed{$labelProposed}); 285 | printf "%4d ", $$allLabelsProposed{$labelProposed}; 286 | } 287 | if ($useWrongDir) { 288 | printf " %4d %4d %4d %6d\n\n", $totalProposed - $$allLabelsProposed{'WRONG_DIR'}, $$allLabelsProposed{'WRONG_DIR'}, $totalAnswer - $totalProposed, $totalAnswer; 289 | } 290 | else { 291 | printf " %4d %4d %4d\n\n", $totalProposed, $totalAnswer - $totalProposed, $totalAnswer; 292 | } 293 | 294 | ### 4. Output the coverage 295 | my $coverage = 100.0 * $totalProposed / $totalAnswer; 296 | printf "%s%d%s%d%s%5.2f%s", 'Coverage = ', $totalProposed, '/', $totalAnswer, ' = ', $coverage, "\%\n"; 297 | 298 | ### 5. 
Output the accuracy 299 | my $accuracy = 100.0 * $freqCorrect / $totalProposed; 300 | printf "%s%d%s%d%s%5.2f%s", 'Accuracy (calculated for the above confusion matrix) = ', $freqCorrect, '/', $totalProposed, ' = ', $accuracy, "\%\n"; 301 | 302 | ### 6. Output the accuracy considering all skipped to be wrong 303 | $accuracy = 100.0 * $freqCorrect / $totalAnswer; 304 | printf "%s%d%s%d%s%5.2f%s", 'Accuracy (considering all skipped examples as Wrong) = ', $freqCorrect, '/', $totalAnswer, ' = ', $accuracy, "\%\n"; 305 | 306 | ### 7. Calculate accuracy with all skipped examples considered Other 307 | my $accuracyWithOther = 100.0 * ($freqCorrect + $otherSkipped) / $totalAnswer; 308 | printf "%s%d%s%d%s%5.2f%s", 'Accuracy (considering all skipped examples as Other) = ', ($freqCorrect + $otherSkipped), '/', $totalAnswer, ' = ', $accuracyWithOther, "\%\n"; 309 | 310 | ### 8. Output P, R, F1 for each relation 311 | my ($macroP, $macroR, $macroF1) = (0, 0, 0); 312 | my ($microCorrect, $microProposed, $microAnswer) = (0, 0, 0); 313 | print "\nResults for the individual relations:\n"; 314 | foreach my $labelAnswer (sort keys %{$allLabelsAnswer}) { 315 | 316 | ### 8.1. Consider all wrong directionalities as wrong classification decisions 317 | my $wrongDirectionCnt = 0; 318 | if ($useWrongDir && defined $$confMatrix{'WRONG_DIR'}{$labelAnswer}) { 319 | $wrongDirectionCnt = $$confMatrix{'WRONG_DIR'}{$labelAnswer}; 320 | } 321 | 322 | ### 8.2. Prevent Perl complains about unintialized values 323 | if (!defined($$allLabelsProposed{$labelAnswer})) { 324 | $$allLabelsProposed{$labelAnswer} = 0; 325 | } 326 | 327 | ### 8.3. Calculate P/R/F1 328 | my $P = (0 == $$allLabelsProposed{$labelAnswer}) ? 0 329 | : 100.0 * $$confMatrix{$labelAnswer}{$labelAnswer} / ($$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt); 330 | my $R = (0 == $$allLabelsAnswer{$labelAnswer}) ? 
0 331 | : 100.0 * $$confMatrix{$labelAnswer}{$labelAnswer} / $$allLabelsAnswer{$labelAnswer}; 332 | my $F1 = (0 == $P + $R) ? 0 : 2 * $P * $R / ($P + $R); 333 | 334 | ### 8.4. Output P/R/F1 335 | if ($useWrongDir) { 336 | printf "%25s%s%4d%s(%4d +%4d)%s%6.2f", $labelAnswer, 337 | " : P = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', $$allLabelsProposed{$labelAnswer}, $wrongDirectionCnt, ' = ', $P; 338 | } 339 | else { 340 | printf "%25s%s%4d%s%4d%s%6.2f", $labelAnswer, 341 | " : P = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', ($$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt), ' = ', $P; 342 | } 343 | printf"%s%4d%s%4d%s%6.2f%s%6.2f%s\n", 344 | "% R = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', $$allLabelsAnswer{$labelAnswer}, ' = ', $R, 345 | "% F1 = ", $F1, '%'; 346 | 347 | ### 8.5. Accumulate statistics for micro/macro-averaging 348 | if ($labelAnswer ne '_Other') { 349 | $macroP += $P; 350 | $macroR += $R; 351 | $macroF1 += $F1; 352 | $microCorrect += $$confMatrix{$labelAnswer}{$labelAnswer}; 353 | $microProposed += $$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt; 354 | $microAnswer += $$allLabelsAnswer{$labelAnswer}; 355 | } 356 | } 357 | 358 | ### 9. Output the micro-averaged P, R, F1 359 | my $microP = (0 == $microProposed) ? 0 : 100.0 * $microCorrect / $microProposed; 360 | my $microR = (0 == $microAnswer) ? 0 : 100.0 * $microCorrect / $microAnswer; 361 | my $microF1 = (0 == $microP + $microR) ? 0 : 2.0 * $microP * $microR / ($microP + $microR); 362 | print "\nMicro-averaged result (excluding Other):\n"; 363 | printf "%s%4d%s%4d%s%6.2f%s%4d%s%4d%s%6.2f%s%6.2f%s\n", 364 | "P = ", $microCorrect, '/', $microProposed, ' = ', $microP, 365 | "% R = ", $microCorrect, '/', $microAnswer, ' = ', $microR, 366 | "% F1 = ", $microF1, '%'; 367 | 368 | ### 10. 
def file_writer(classes, dirs, ids, file_, inv_map, class_type="true"):
    """Write SemEval-2010 Task 8 predictions to *file_* in scorer format.

    Each output line is "<id>\t<relation>" with a "(e1,e2)"/"(e2,e1)"
    directionality suffix appended for every relation except "Other".

    Args:
        classes: per-example relation labels; scalar tensors when
            ``class_type == "true"``, score vectors otherwise (argmax taken).
        dirs: per-example directions, same convention as *classes*
            (0 -> "(e1,e2)", anything else -> "(e2,e1)").
        ids: per-example sentence ids (scalar tensors).
        file_: open writable file-like object; written to in place.
        inv_map: mapping from label index to relation-name string.
        class_type: "true" for gold labels, anything else for predictions.

    Returns:
        The same *file_* object, for convenience.
    """
    for idx in range(len(dirs)):
        if class_type == "true":
            label_idx = classes[idx].item()
            direction = dirs[idx].item()
        else:
            # Prediction scores: take the argmax index along dim 0.
            label_idx = classes[idx].max(0)[1].item()
            direction = dirs[idx].max(0)[1].item()
        relation = inv_map[label_idx]
        line = str(ids[idx].item()) + "\t" + relation
        if relation != "Other":
            line += "(e1,e2)" if direction == 0 else "(e2,e1)"
        file_.write(line + "\n")
    return file_
$!\n"; 70 | } 71 | else { 72 | open GOLD, $goldfile or die "$script: could not open file of gold props ($goldfile)! $!\n"; 73 | } 74 | if ($predfile =~ /\.gz/) { 75 | open PRED, "gunzip -c $predfile |" or die "$script: could not open gzipped file of predicted props ($predfile)! $!\n"; 76 | } 77 | else { 78 | open PRED, $predfile or die "$script: could not open file of predicted props ($predfile)! $!\n"; 79 | } 80 | 81 | 82 | ## 83 | # read and evaluate propositions, sentence by sentence 84 | 85 | my $s = SRL::sentence->read_props($ns, GOLD => \*GOLD, PRED => \*PRED); 86 | 87 | while ($s) { 88 | 89 | my $prop; 90 | 91 | my (@G, @P, $i); 92 | 93 | map { $G[$_->position] = $_ } $s->gold_props; 94 | map { $P[$_->position] = $_ } $s->pred_props; 95 | 96 | for($i=0; $i<@G; $i++) { 97 | my $gprop = $G[$i]; 98 | my $pprop = $P[$i]; 99 | 100 | if ($pprop and !$gprop) { 101 | !$options{noW} and print STDERR "WARNING : sentence $ns : verb ", $pprop->verb, 102 | " at position ", $pprop->position, " : found predicted prop without its gold reference! Skipping prop!\n"; 103 | } 104 | elsif ($gprop) { 105 | if (!$pprop) { 106 | !$options{noW} and print STDERR "WARNING : sentence $ns : verb ", $gprop->verb, 107 | " at position ", $gprop->position, " : missing predicted prop! Counting all arguments as missed!\n"; 108 | $pprop = SRL::prop->new($gprop->verb, $gprop->position); 109 | } 110 | elsif ($gprop->verb ne $pprop->verb) { 111 | !$options{noW} and print STDERR "WARNING : sentence $ns : props do not match : expecting ", 112 | $gprop->verb, " at position ", $gprop->position, 113 | ", found ", $pprop->verb, " at position ", $pprop->position, "! 
Counting all gold arguments as missed!\n"; 114 | $pprop = SRL::prop->new($gprop->verb, $gprop->position); 115 | } 116 | 117 | $ntargets++; 118 | my %e = evaluate_proposition($gprop, $pprop); 119 | 120 | 121 | # Update global evaluation results 122 | 123 | $E{ok} += $e{ok}; 124 | $E{op} += $e{op}; 125 | $E{ms} += $e{ms}; 126 | $E{ptv} += $e{ptv}; 127 | 128 | my $t; 129 | foreach $t ( keys %{$e{T}} ) { 130 | $E{T}{$t}{ok} += $e{T}{$t}{ok}; 131 | $E{T}{$t}{op} += $e{T}{$t}{op}; 132 | $E{T}{$t}{ms} += $e{T}{$t}{ms}; 133 | } 134 | foreach $t ( keys %{$e{E}} ) { 135 | $E{E}{$t}{ok} += $e{E}{$t}{ok}; 136 | $E{E}{$t}{op} += $e{E}{$t}{op}; 137 | $E{E}{$t}{ms} += $e{E}{$t}{ms}; 138 | } 139 | 140 | if ($options{C}) { 141 | update_confusion_matrix(\%C, $gprop, $pprop); 142 | } 143 | } 144 | } 145 | 146 | $ns++; 147 | $s = SRL::sentence->read_props($ns, GOLD => \*GOLD, PRED => \*PRED); 148 | 149 | } 150 | 151 | 152 | # Print Evaluation results 153 | my $t; 154 | 155 | if ($options{latex}) { 156 | print '\begin{table}[t]', "\n"; 157 | print '\centering', "\n"; 158 | print '\begin{tabular}{|l|r|r|r|}\cline{2-4}', "\n"; 159 | print '\multicolumn{1}{l|}{}', "\n"; 160 | print ' & Precision & Recall & F$_{\beta=1}$', '\\\\', "\n", '\hline', "\n"; #' 161 | 162 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", "Overall", precrecf1($E{ok}, $E{op}, $E{ms})); 163 | print '\hline', "\n"; 164 | 165 | foreach $t ( sort keys %{$E{T}} ) { 166 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", $t, precrecf1($E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms})); 167 | } 168 | print '\hline', "\n"; 169 | 170 | if (%excluded) { 171 | print '\hline', "\n"; 172 | foreach $t ( sort keys %{$E{E}} ) { 173 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", $t, precrecf1($E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms})); 174 | } 175 | print '\hline', "\n"; 176 | } 177 | 178 | print '\end{tabular}', "\n"; 179 | print '\end{table}', "\n"; 180 | } 181 | else { 182 | printf("Number of Sentences : %6d\n", 
$ns); 183 | printf("Number of Propositions : %6d\n", $ntargets); 184 | printf("Percentage of perfect props : %6.2f\n",($ntargets>0 ? 100*$E{ptv}/$ntargets : 0)); 185 | print "\n"; 186 | 187 | printf("%10s %6s %6s %6s %6s %6s %6s\n", "", "corr.", "excess", "missed", "prec.", "rec.", "F1"); 188 | print "------------------------------------------------------------\n"; 189 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 190 | "Overall", $E{ok}, $E{op}, $E{ms}, precrecf1($E{ok}, $E{op}, $E{ms})); 191 | # print "------------------------------------------------------------\n"; 192 | print "----------\n"; 193 | 194 | # printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 195 | # "all - {V}", $O2{ok}, $O2{op}, $O2{ms}, precrecf1($O2{ok}, $O2{op}, $O2{ms})); 196 | # print "------------------------------------------------------------\n"; 197 | 198 | foreach $t ( sort keys %{$E{T}} ) { 199 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 200 | $t, $E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms}, precrecf1($E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms})); 201 | } 202 | print "------------------------------------------------------------\n"; 203 | 204 | foreach $t ( sort keys %{$E{E}} ) { 205 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 206 | $t, $E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms}, precrecf1($E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms})); 207 | } 208 | print "------------------------------------------------------------\n"; 209 | } 210 | 211 | 212 | # print confusion matrix 213 | if ($options{C}) { 214 | 215 | my $k; 216 | 217 | # Evaluation of Unlabelled arguments 218 | my ($uok, $uop, $ums, $uacc) = (0,0,0,0); 219 | foreach $k ( grep { $_ ne "-NONE-" && $_ ne "V" } keys %C ) { 220 | map { $uok += $C{$k}{$_} } grep { $_ ne "-NONE-" && $_ ne "V" } keys %{$C{$k}}; 221 | $uacc += $C{$k}{$k}; 222 | $ums += $C{$k}{"-NONE-"}; 223 | } 224 | map { $uop += $C{"-NONE-"}{$_} } grep { $_ ne "-NONE-" && $_ ne "V" } keys %{$C{"-NONE-"}}; 225 | 226 | print 
"--------------------------------------------------------------------\n"; 227 | printf("%10s %6s %6s %6s %6s %6s %6s %6s\n", "", "corr.", "excess", "missed", "prec.", "rec.", "F1", "lAcc"); 228 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f %6.2f\n", 229 | "Unlabeled", $uok, $uop, $ums, precrecf1($uok, $uop, $ums), 100*$uacc/$uok); 230 | print "--------------------------------------------------------------------\n"; 231 | 232 | 233 | 234 | print "\n---- Confusion Matrix: (one row for each correct role, with the distribution of predictions)\n"; 235 | 236 | my %AllKeys; 237 | map { $AllKeys{$_} = 1 } map { $_, keys %{$C{$_}} } keys %C; 238 | my @AllKeys = sort keys %AllKeys; 239 | 240 | 241 | 242 | my $i = -1; 243 | print " "; 244 | map { printf("%4d ", $i); $i++} @AllKeys; 245 | print "\n"; 246 | $i = -1; 247 | foreach $k ( @AllKeys ) { 248 | printf("%2d: %-8s ", $i++, $k); 249 | map { printf("%4d ", $C{$k}{$_}) } @AllKeys; 250 | print "\n"; 251 | } 252 | 253 | 254 | my ($t1,$t2); 255 | foreach $t1 ( sort keys %C ) { 256 | foreach $t2 ( sort keys %{$C{$t1}} ) { 257 | # printf(" %-6s vs %-6s : %-5d\n", $t1, $t2, $C{$t1}{$t2}); 258 | } 259 | } 260 | } 261 | 262 | # end of main program 263 | ##################### 264 | 265 | ############################################################ 266 | # S U B R O U T I N E S 267 | 268 | 269 | # evaluates a predicted proposition wrt the gold correct proposition 270 | # returns a hash with the following keys 271 | # ok : number of correctly predicted args 272 | # ms : number of missed args 273 | # op : number of over-predicted args 274 | # T : a hash indexed by argument types, where 275 | # each value is in turn a hash of {ok,ms,op} numbers 276 | # E : a hash indexed by excluded argument types, where 277 | # each value is in turn a hash of {ok,ms,op} numbers 278 | sub evaluate_proposition { 279 | my ($gprop, $pprop) = @_; 280 | 281 | my $o = $gprop->discriminate_args($pprop); 282 | 283 | my %e; 284 | 285 | my $a; 286 | foreach $a 
# Compute precision, recall and F1 from counts of correct (ok),
# over-predicted (op) and missed (ms) arguments.
# Returns a three-element list (precision, recall, F1), each in [0, 100];
# any ratio with a zero denominator is reported as 0.
sub precrecf1 {
    my ($correct, $excess, $missed) = @_;

    my $predicted = $correct + $excess;   # everything the system proposed
    my $expected  = $correct + $missed;   # everything the gold standard holds

    my $prec = $predicted > 0 ? 100 * $correct / $predicted : 0;
    my $rec  = $expected  > 0 ? 100 * $correct / $expected  : 0;

    # Harmonic mean of precision and recall; defined as 0 when both are 0.
    my $f1 = ($prec + $rec > 0) ? 2 * $prec * $rec / ($prec + $rec) : 0;

    return ($prec, $rec, $f1);
}
387 | # Provides methods for reading/writing sentences from/to files in 388 | # CoNLL-2004/CoNLL-2005 formats. 389 | # 390 | # 391 | ################################################################################ 392 | 393 | 394 | package SRL::sentence; 395 | use strict; 396 | 397 | 398 | 399 | sub new { 400 | my ($pkg, $id) = @_; 401 | 402 | my $s = []; 403 | 404 | $s->[0] = $id; # sentence number 405 | $s->[1] = undef; # words (the list or the number of words) 406 | $s->[2] = []; # gold props 407 | $s->[3] = []; # predicted props 408 | $s->[4] = undef; # chunks 409 | $s->[5] = undef; # clauses 410 | $s->[6] = undef; # full syntactic tree 411 | $s->[7] = undef; # named entities 412 | 413 | return bless $s, $pkg; 414 | } 415 | 416 | #----- 417 | 418 | sub id { 419 | my $s = shift; 420 | return $s->[0]; 421 | } 422 | 423 | #----- 424 | 425 | sub length { 426 | my $s = shift; 427 | if (ref($s->[1])) { 428 | return scalar(@{$s->[1]}); 429 | } 430 | else { 431 | return $s->[1]; 432 | } 433 | } 434 | 435 | sub set_length { 436 | my $s = shift; 437 | $s->[1] = shift; 438 | } 439 | 440 | #----- 441 | 442 | # returns the i-th word of the sentence 443 | sub word { 444 | my ($s, $i) = @_; 445 | return $s->[1][$i]; 446 | } 447 | 448 | 449 | # returns the list of words of the sentence 450 | sub words { 451 | my $s = shift; 452 | if (@_) { 453 | return map { $s->[1][$_] } @_; 454 | } 455 | else { 456 | return @{$s->[1]}; 457 | } 458 | } 459 | 460 | sub ref_words { 461 | my $s = shift; 462 | return $s->[1]; 463 | } 464 | 465 | 466 | sub chunking { 467 | my $s = shift; 468 | return $s->[4]; 469 | } 470 | 471 | sub clausing { 472 | my $s = shift; 473 | return $s->[5]; 474 | } 475 | 476 | sub syntree { 477 | my $s = shift; 478 | return $s->[6]; 479 | } 480 | 481 | sub named_entities { 482 | my $s = shift; 483 | return $s->[7]; 484 | } 485 | 486 | #----- 487 | 488 | sub add_gold_props { 489 | my $s = shift; 490 | push @{$s->[2]}, @_; 491 | } 492 | 493 | sub gold_props { 494 | my $s 
= shift; 495 | return @{$s->[2]}; 496 | } 497 | 498 | sub add_pred_props { 499 | my $s = shift; 500 | push @{$s->[3]}, @_; 501 | } 502 | 503 | sub pred_props { 504 | my $s = shift; 505 | return @{$s->[3]}; 506 | } 507 | 508 | 509 | #------------------------------------------------------------ 510 | # I/O F U N C T I O N S 511 | #------------------------------------------------------------ 512 | 513 | # Reads a complete (words, synt, props) sentence from a stream 514 | # Returns: the reference to the sentence object or 515 | # undef if no sentence found 516 | # The propositions in the file are stored as gold props 517 | # For each gold prop, an empty predicted prop is created 518 | # 519 | # The %C hash contains the column number for each annotation of 520 | # the datafile. 521 | # 522 | sub read_from_stream { 523 | my ($pkg, $id, $fh, %C) = @_; 524 | 525 | if (!%C) { 526 | %C = ( words => 0, 527 | pos => 1, 528 | chunks => 2, 529 | clauses => 3, 530 | syntree => 4, 531 | ne => 5, 532 | props => 6 533 | ) 534 | } 535 | 536 | # my $k; 537 | # foreach $k ( "words", "pos", "props" ) { 538 | # if (!exists($C{$k}) { 539 | # die "sentence->read_from_stream :: undefined column number for $k.\n"; 540 | # } 541 | # } 542 | 543 | my $cols = read_columns($fh); 544 | 545 | if (!@$cols) { 546 | return undef; 547 | } 548 | 549 | my $s = $pkg->new($id); 550 | 551 | # words and PoS 552 | my $words = $cols->[$C{words}]; 553 | my $pos = $cols->[$C{pos}]; 554 | 555 | # initialize list of words 556 | $s->[1] = []; 557 | my $i; 558 | for ($i=0;$i<@$words;$i++) { 559 | push @{$s->[1]}, SRL::word->new($i, $words->[$i], $pos->[$i]); 560 | } 561 | 562 | my $c; 563 | 564 | # chunks 565 | if (exists($C{chunks})) { 566 | $c = $cols->[$C{chunks}]; 567 | # initialize chunking 568 | $s->[4] = SRL::phrase_set->new(); 569 | $s->[4]->load_SE_tagging(@$c); 570 | } 571 | 572 | # clauses 573 | if (exists($C{clauses})) { 574 | $c = $cols->[$C{clauses}]; 575 | # initialize clauses 576 | $s->[5] = 
SRL::phrase_set->new(); 577 | $s->[5]->load_SE_tagging(@$c); 578 | } 579 | 580 | # syntree 581 | if (exists($C{syntree})) { 582 | $c = $cols->[$C{syntree}]; 583 | # initialize syntree 584 | $s->[6] = SRL::syntree->new(); 585 | $s->[6]->load_SE_tagging($s->[1], @$c); 586 | } 587 | 588 | # named entities 589 | if (exists($C{ne})) { 590 | $c = $cols->[$C{ne}]; 591 | $s->[7] = SRL::phrase_set->new(); 592 | $s->[7]->load_SE_tagging(@$c); 593 | } 594 | 595 | 596 | my $i = 0; 597 | while ($i<$C{props}) { 598 | shift @$cols; 599 | $i++; 600 | } 601 | 602 | # gold props 603 | my $targets = shift @$cols or die "error :: reading sentence $id :: no targets found!\n"; 604 | if (@$cols) { 605 | $s->load_props($s->[2], $targets, $cols); 606 | } 607 | 608 | # initialize predicted props 609 | foreach $i ( grep { $targets->[$_] ne "-" } ( 0 .. scalar(@$targets)-1 ) ) { 610 | push @{$s->[3]}, SRL::prop->new($targets->[$i], $i); 611 | } 612 | 613 | return $s; 614 | } 615 | 616 | 617 | 618 | #------------------------------------------------------------ 619 | 620 | 621 | # reads the propositions of a sentence from files 622 | # allows to store propositions as gold and/or predicted, 623 | # by specifying filehandles as values in the %FILES hash 624 | # indexed by {GOLD,PRED} keys 625 | # expects: each prop file: first column specifying target verbs, 626 | # and remaining columns specifying arguments 627 | # returns a new sentence, containing the list of prop 628 | # objects, one for each column, in gold/pred contexts 629 | # returns undef when EOF 630 | sub read_props { 631 | my ($pkg, $id, %FILES) = @_; 632 | 633 | my $s = undef; 634 | my $length = undef; 635 | 636 | if (exists($FILES{GOLD})) { 637 | my $cols = read_columns($FILES{GOLD}); 638 | 639 | # end of file 640 | if (!@$cols) { 641 | return undef; 642 | } 643 | 644 | $s = $pkg->new($id); 645 | my $targets = shift @$cols; 646 | $length = scalar(@$targets); 647 | $s->set_length($length); 648 | $s->load_props($s->[2], $targets, 
$cols); 649 | } 650 | if (exists($FILES{PRED})) { 651 | my $cols = read_columns($FILES{PRED}); 652 | 653 | if (!defined($s)) { 654 | # end of file 655 | if (!@$cols) { 656 | return undef; 657 | } 658 | $s = $pkg->new($id); 659 | } 660 | my $targets = shift @$cols; 661 | 662 | if (defined($length)) { 663 | ($length != scalar(@$targets)) and 664 | die "ERROR : sentence $id : gold and pred sentences do not align correctly!\n"; 665 | } 666 | else { 667 | $length = scalar(@$targets); 668 | $s->set_length($length); 669 | } 670 | $s->load_props($s->[3], $targets, $cols); 671 | } 672 | 673 | return $s; 674 | } 675 | 676 | 677 | sub load_props { 678 | my ($s, $where, $targets, $cols) = @_; 679 | 680 | my $i; 681 | for ($i=0; $i<@$targets; $i++) { 682 | if ($targets->[$i] ne "-") { 683 | my $prop = SRL::prop->new($targets->[$i], $i); 684 | 685 | my $col = shift @$cols; 686 | if (defined($col)) { 687 | # print "SE Tagging: ", join(" ", @$col), "\n"; 688 | $prop->load_SE_tagging(@$col); 689 | } 690 | else { 691 | print STDERR "WARNING : sentence ", $s->id, " : can't find column of args for prop ", $prop->verb, "!\n"; 692 | } 693 | push @$where, $prop; 694 | } 695 | } 696 | } 697 | 698 | 699 | # writes a sentence to an output stream 700 | # allows to specify which parts of the sentence are written 701 | # by giving true values to the %WHAT hash, indexed by 702 | # {WORDS,SYNT,GOLD,PRED} keys 703 | sub write_to_stream { 704 | my ($s, $fh, %WHAT) = @_; 705 | 706 | if (!%WHAT) { 707 | %WHAT = ( WORDS => 1, 708 | PSYNT => 1, 709 | FSYNT => 1, 710 | GOLD => 0, 711 | PRED => 1 712 | ); 713 | } 714 | 715 | my @columns; 716 | 717 | if ($WHAT{WORDS}) { 718 | my @words = map { $_->form } $s->words; 719 | push @columns, \@words; 720 | } 721 | if ($WHAT{PSYNT}) { 722 | my @pos = map { $_->pos } $s->words; 723 | push @columns, \@pos; 724 | my @chunks = $s->chunking->to_SE_tagging($s->length); 725 | push @columns, \@chunks; 726 | my @clauses = $s->clausing->to_SE_tagging($s->length); 727 | 
push @columns, \@clauses; 728 | } 729 | if ($WHAT{FSYNT}) { 730 | my @pos = map { $_->pos } $s->words; 731 | push @columns, \@pos; 732 | my @sttags = $s->syntree->to_SE_tagging(); 733 | push @columns, \@sttags; 734 | } 735 | if ($WHAT{GOLD}) { 736 | push @columns, $s->props_to_columns($s->[2]); 737 | } 738 | if ($WHAT{PRED}) { 739 | push @columns, $s->props_to_columns($s->[3]); 740 | } 741 | if ($WHAT{PROPS}) { 742 | push @columns, $s->props_to_columns($WHAT{PROPS}); 743 | } 744 | 745 | 746 | reformat_columns(\@columns); 747 | 748 | # finally, print columns word by word 749 | my $i; 750 | for ($i=0;$i<$s->length;$i++) { 751 | print $fh join(" ", map { $_->[$i] } @columns), "\n"; 752 | } 753 | print $fh "\n"; 754 | 755 | 756 | } 757 | 758 | # turns a set of propositions (target verbs + args for each one) into a set of 759 | # columns in the CoNLL Start-End format 760 | sub props_to_columns { 761 | my ($s, $Pref) = @_; 762 | 763 | my @props = sort { $a->position <=> $b->position } @{$Pref}; 764 | 765 | my $l = $s->length; 766 | my $verbs = []; 767 | my @cols = ( $verbs ); 768 | my $p; 769 | 770 | foreach $p ( @props ) { 771 | defined($verbs->[$p->position]) and die "sentence->preds_to_columns: already defined verb at sentence ", $s->id, " position ", $p->position, "!\n"; 772 | $verbs->[$p->position] = sprintf("%-15s", $p->verb); 773 | 774 | my @tags = $p->to_SE_tagging($l); 775 | push @cols, \@tags; 776 | } 777 | 778 | # finally, define empty verb positions 779 | my $i; 780 | for ($i=0;$i<$l;$i++) { 781 | if (!defined($verbs->[$i])) { 782 | $verbs->[$i] = sprintf("%-15s", "-"); 783 | } 784 | } 785 | 786 | return @cols; 787 | } 788 | 789 | 790 | 791 | # Writes the predicted propositions of the sentence to an output file handler ($fh) 792 | # Specifically, writes a column of target verbs, and a column of arguments 793 | # for each target verb 794 | # OBSOLETE : the same can be done with write_to_stream($s, PRED => 1) 795 | sub write_pred_props { 796 | my ($s, $fh) = 
@_;   # ($s, $fh) — continuation of write_pred_props; signature is on the previous line

    # sort the predicted props left-to-right by target-verb position
    my @props = sort { $a->position <=> $b->position } $s->pred_props;

    my $l = $s->length;
    my @verbs = ();
    my @cols = ();
    my $p;

    foreach $p ( @props ) {
        defined($verbs[$p->position]) and die "prop->write_pred_props: already defined verb at sentence ", $s->id, " position ", $p->position, "!\n";
        $verbs[$p->position] = $p->verb;

        my @tags = $p->to_SE_tagging($l);
        push @cols, \@tags;
    }

    # finally, print columns word by word
    my $i;
    for ($i=0;$i<$l;$i++) {
        printf $fh "%-15s %s\n", (defined($verbs[$i])? $verbs[$i] : "-"),
        join(" ", map { $_->[$i] } @cols);
    }
    # FIX: the trailing blank line was written with a bare 'print "\n";',
    # i.e. to STDOUT, while every other line of this sub goes to $fh.
    print $fh "\n";
}



# reads columns until blank line or EOF
# returns an array of columns (each column is a reference to an array containing the column)
# each column in the returned array should be the same size
sub read_columns {
    my $fh = shift;

    # read columns until blank line or eof
    my @cols;
    my $i;
    my @line = split(" ", <$fh>);
    while (@line) {
        for ($i=0; $i<@line; $i++) {
            push @{$cols[$i]}, $line[$i];
        }
        @line = split(" ", <$fh>);
    }

    return \@cols;
}



# reformats the tags of a list of columns, so that each
# column has a fixed width along all tags
#
sub reformat_columns {
    my $cols = shift; # a reference to the list of columns of a sentence

    # FIX: the loop header had been mangled by angle-bracket stripping
    # ("for ($i=0;$i[$i]);"); restored: pretty-format every column.
    my $i;
    for ($i=0;$i<scalar(@$cols);$i++) {
        column_pretty_format($cols->[$i]);
    }
}



# reformats the tags of a column, so that each
# tag has the same width
#
# tag sequences are left justified
# start-end annotations are centered at the asterisk
#
sub column_pretty_format {
    my $col = shift; # a reference to the column (array) of tags

    (!@$col) and return undef;

    my ($i);
    if ($col->[0] =~ /\*/) {
874 | 875 | # Start-End 876 | my $ok = 1; 877 | 878 | my (@s,@e,$t,$ms,$me); 879 | $ms = 2; $me = 2; 880 | $i = 0; 881 | while ($ok and $i<@$col) { 882 | if ($col->[$i] =~ /^(.*\*)(.*)$/) { 883 | $s[$i] = $1; 884 | $e[$i] = $2; 885 | if (length($s[$i]) > $ms) { 886 | $ms = length($s[$i]); 887 | } 888 | if (length($e[$i]) > $me) { 889 | $me = length($e[$i]); 890 | } 891 | } 892 | else { 893 | # In this case, the current token is not compliant with SE format 894 | # So, we treat format the column as a sequence of tags 895 | $ok = 0; 896 | } 897 | $i++; 898 | } 899 | # print "M $ms $me\n"; 900 | 901 | if ($ok) { 902 | my $f = "%".($ms+1)."s%-".($me+1)."s"; 903 | for ($i=0; $i<@$col; $i++) { 904 | $col->[$i] = sprintf($f, $s[$i], $e[$i]); 905 | } 906 | return; 907 | } 908 | } 909 | 910 | # Tokens 911 | my $l=0; 912 | map { (length($_)>$l) and ($l=length($_)) } @$col; 913 | my $f = "%-".($l+1)."s"; 914 | for ($i=0; $i<@$col; $i++) { 915 | $col->[$i] = sprintf($f,$col->[$i]); 916 | } 917 | 918 | } 919 | 920 | 921 | 922 | 1; 923 | 924 | 925 | 926 | 927 | 928 | 929 | 930 | 931 | ################################################################## 932 | # 933 | # Package p r o p : A proposition (verb + args) 934 | # 935 | # January 2004 936 | # 937 | ################################################################## 938 | 939 | 940 | package SRL::prop; 941 | 942 | use strict; 943 | 944 | 945 | # Constructor: creates a new prop, with empty arguments 946 | # Parameters: verb form, position of verb 947 | sub new { 948 | my ($pkg, $v, $position) = @_; 949 | 950 | my $p = []; 951 | 952 | $p->[0] = $v; # the verb 953 | $p->[1] = $position; # verb position 954 | $p->[2] = undef; # verb sense 955 | $p->[3] = []; # args, empty by default 956 | 957 | return bless $p, $pkg; 958 | } 959 | 960 | ## Accessor/Initialization methods 961 | 962 | # returns the verb form of the prop 963 | sub verb { 964 | my $p = shift; 965 | return $p->[0]; 966 | } 967 | 968 | # returns the verb position of 
the verb in the prop 969 | sub position { 970 | my $p = shift; 971 | return $p->[1]; 972 | } 973 | 974 | # returns the verb sense of the verb in the prop 975 | sub sense { 976 | my $p = shift; 977 | return $p->[2]; 978 | } 979 | 980 | # initializes the verb sense of the verb in the prop 981 | sub set_sense { 982 | my $p = shift; 983 | $p->[2] = shift; 984 | } 985 | 986 | 987 | # returns the list of arguments of the prop 988 | sub args { 989 | my $p = shift; 990 | return @{$p->[3]}; 991 | } 992 | 993 | # initializes the list of arguments of the prop 994 | sub set_args { 995 | my $p = shift; 996 | @{$p->[3]} = @_; 997 | } 998 | 999 | # adds arguments to the prop 1000 | sub add_args { 1001 | my $p = shift; 1002 | push @{$p->[3]}, @_; 1003 | } 1004 | 1005 | # Returns the list of phrases of the prop 1006 | # Each argument corresponds to one phrase, except for 1007 | # discontinuous arguments, where each piece forms a phrase 1008 | sub phrases { 1009 | my $p = shift; 1010 | return map { $_->single ? 
$_ : $_->phrases} @{$p->[3]}; 1011 | } 1012 | 1013 | 1014 | ###### Methods 1015 | 1016 | # Adds arguments represented in Start-End tagging 1017 | # Receives a list of Start-End tags (one per word in the sentence) 1018 | # Creates an arg object for each argument in the taggging 1019 | # and modifies the prop so that the arguments are part of it 1020 | # Takes into account special treatment for discontinuous arguments 1021 | sub load_SE_tagging { 1022 | my ($prop, @tags) = @_; 1023 | 1024 | # auxiliar phrase set 1025 | my $set = SRL::phrase_set->new(); 1026 | $set->load_SE_tagging(@tags); 1027 | 1028 | # store args per type, to be able to continue them 1029 | my %ARGS; 1030 | my $a; 1031 | 1032 | # add each phrase as an argument, with special treatment for multi-phrase arguments (A C-A C-A) 1033 | foreach $a ( $set->phrases ) { 1034 | 1035 | # the phrase continues a started arg 1036 | if ($a->type =~ /^C\-/) { 1037 | my $type = $'; # ' 1038 | if (exists($ARGS{$type})) { 1039 | my $pc = $a; 1040 | $a = $ARGS{$type}; 1041 | if ($a->single) { 1042 | # create the head phrase, considered arg until now 1043 | my $ph = SRL::phrase->new($a->start, $a->end, $type); 1044 | $a->add_phrases($ph); 1045 | } 1046 | $a->add_phrases($pc); 1047 | $a->set_end($pc->end); 1048 | } 1049 | else { 1050 | # print STDERR "WARNING : found continuation phrase \"C-$type\" without heading phrase: turned into regular $type argument.\n"; 1051 | # turn the phrase into arg 1052 | bless $a, "SRL::arg"; 1053 | $a->set_type($type); 1054 | push @{$prop->[3]}, $a; 1055 | $ARGS{$a->type} = $a; 1056 | } 1057 | } 1058 | else { 1059 | # turn the phrase into arg 1060 | bless $a, "SRL::arg"; 1061 | push @{$prop->[3]}, $a; 1062 | $ARGS{$a->type} = $a; 1063 | } 1064 | } 1065 | 1066 | } 1067 | 1068 | 1069 | ## discriminates the args of prop $pb wrt the args of prop $pa, returning intersection(a^b), a-b and b-a 1070 | # returns a hash reference containing three lists: 1071 | # $out->{ok} : args in $pa and $pb 1072 
# $out->{ms} : args in $pa and not in $pb
# $out->{op} : args in $pb and not in $pa
sub discriminate_args {
    my $pa = shift;   # proposition whose args act as the reference (gold)
    my $pb = shift;   # proposition whose args are being scored (predicted)
    # optional third parameter: whether argument types must match (default: yes)
    my $check_type = @_ ? shift : 1;

    my $out = {};
    # when types are ignored, also collect the matching $pa args under {eq}
    !$check_type and @{$out->{eq}} = ();
    @{$out->{ok}} = ();
    @{$out->{ms}} = ();
    @{$out->{op}} = ();

    my $a;
    my %ok;

    # index $pa's args by (start,end) span for O(1) lookup
    my %ARGS;

    foreach $a ($pa->args) {
        $ARGS{$a->start}{$a->end} = $a;
    }

    foreach $a ($pb->args) {
        my $s = $a->start;
        my $e = $a->end;

        my $gold = $ARGS{$s}{$e};
        if (!defined($gold)) {
            # no reference arg with this span: over-predicted
            push @{$out->{op}}, $a;
        }
        elsif ($gold->single and $a->single) {
            # both continuous: span already matches; optionally check type
            if (!$check_type or ($gold->type eq $a->type)) {
                !$check_type and push @{$out->{eq}}, $gold;
                push @{$out->{ok}}, $a;
                # consume the gold arg so it is not counted as missed
                delete($ARGS{$s}{$e});
            }
            else {
                push @{$out->{op}}, $a;
            }
        }
        elsif (!$gold->single and $a->single) {
            # continuity mismatch (gold discontinuous, pred continuous)
            push @{$out->{op}}, $a;
        }
        elsif ($gold->single and !$a->single) {
            # continuity mismatch (gold continuous, pred discontinuous)
            push @{$out->{op}}, $a;
        }
        else {
            # Check phrases of arg: both are discontinuous; every predicted
            # piece must exactly match a gold piece, with no gold piece left
            my %P;
            my $ok = (!$check_type or ($gold->type eq $a->type));
            $ok and map { $P{ $_->start.".".$_->end } = 1 } $gold->phrases;
            my @P = $a->phrases;
            while ($ok and @P) {
                my $p = shift @P;
                if ($P{ $p->start.".".$p->end }) {
                    delete $P{ $p->start.".".$p->end }
                }
                else {
                    $ok = 0;
                }
            }
            if ($ok and !(values %P)) {
                !$check_type and push @{$out->{eq}}, $gold;
                push @{$out->{ok}}, $a;
                delete $ARGS{$s}{$e}
            }
            else {
                push @{$out->{op}}, $a;
            }
        }
    }

    # any gold arg never consumed above was missed
    my ($s);
    foreach $s ( keys %ARGS ) {
        foreach $a ( values %{$ARGS{$s}} ) {
            push @{$out->{ms}}, $a;
        }
    }

    return $out;
}
| 1154 | 1155 | # Generates a Start-End tagging for the prop arguments 1156 | # Expects the prop object, and l=length of the sentence 1157 | # Returns a list of l tags 1158 | sub to_SE_tagging { 1159 | my $prop = shift; 1160 | my $l = shift; 1161 | my @tags = (); 1162 | 1163 | my ($a, $p); 1164 | foreach $a ( $prop->args ) { 1165 | my $t = $a->type; 1166 | my $cont = 0; 1167 | foreach $p ( $a->single ? $a : $a->phrases ) { 1168 | if (defined($tags[$p->start])) { 1169 | die "prop->to_SE_tagging: Already defined tag in position ", $p->start, "! Prop phrases overlap or embed!\n"; 1170 | } 1171 | if ($p->start != $p->end) { 1172 | $tags[$p->start] = sprintf("%7s", "(".$t)."* "; 1173 | if (defined($tags[$p->end])) { 1174 | die "prop->to_SE_tagging: Already defined tag in position ", $p->end, "! Prop phrases overlap or embed!\n"; 1175 | } 1176 | # $tags[$p->end] = " *".sprintf("%-7s", $t.")"); 1177 | $tags[$p->end] = " *".sprintf("%-3s", ")"); 1178 | } 1179 | else { 1180 | # $tags[$p->start] = sprintf("%7s", "(".$t)."*".sprintf("%-7s", $t.")"); 1181 | $tags[$p->start] = sprintf("%7s", "(".$t)."*".sprintf("%-3s",")"); 1182 | } 1183 | 1184 | if (!$cont) { 1185 | $cont = 1; 1186 | $t = "C-".$t; 1187 | } 1188 | } 1189 | } 1190 | 1191 | my $i; 1192 | for ($i=0; $i<$l; $i++) { 1193 | if (!defined($tags[$i])) { 1194 | $tags[$i] = " * "; 1195 | } 1196 | } 1197 | 1198 | return @tags; 1199 | } 1200 | 1201 | 1202 | # generates a string representing the proposition 1203 | sub to_string { 1204 | my $p = shift; 1205 | 1206 | my $s = "[". $p->verb . "@" . $p->position . 
": "; 1207 | $s .= join(" ", map { $_->to_string } $p->args); 1208 | $s .= " ]"; 1209 | 1210 | return $s; 1211 | } 1212 | 1213 | 1214 | 1; 1215 | 1216 | 1217 | ################################################################################ 1218 | # 1219 | # Package p h r a s e _ s e t 1220 | # 1221 | # A set of phrases 1222 | # Each phrase is indexed by (start,end) positions 1223 | # 1224 | # Holds non-overlapping phrase sets. 1225 | # Embedding of phrases allowed and exploited in class methods 1226 | # 1227 | # Brings useful functions on phrase sets, such as: 1228 | # - Load phrases from tag sequences in IOB1, IOB2, Start-End formats 1229 | # - Retrieve a phrase given its (start,end) positions 1230 | # - List phrases found within a given (s,e) segment 1231 | # - Discriminate a predicted set of phrases with respect to the gold set 1232 | # 1233 | ################################################################################ 1234 | 1235 | use strict; 1236 | 1237 | 1238 | package SRL::phrase_set; 1239 | 1240 | ## $phrase_types global variable 1241 | # If defined, contains a hash table specifying the phrase types to be considered 1242 | # If undefined, any phrase type is considered 1243 | my $phrase_types = undef; 1244 | sub set_phrase_types { 1245 | $phrase_types = {}; 1246 | my $t; 1247 | foreach $t ( @_ ) { 1248 | $phrase_types->{$t} = 1; 1249 | } 1250 | } 1251 | 1252 | # Constructor: creates a new phrase set 1253 | # Arguments: an initial set of phrases, which are added to the set 1254 | sub new { 1255 | my ($pkg, @P) = @_; 1256 | my $s = []; 1257 | @{$s->[0]} = (); # NxN half-matrix, storing phrases 1258 | $s->[1] = 0; # N (length of the sentence) 1259 | bless $s, $pkg; 1260 | 1261 | $s->add_phrases(@P); 1262 | 1263 | return $s; 1264 | } 1265 | 1266 | 1267 | # Adds phrases represented in IOB2 tagging 1268 | # Receives a list of IOB2 tags (one per word in the sentence) 1269 | # Creates a phrase object for each phrase in the taggging 1270 | # and modifies the 
# set so that the phrases are part of it
sub load_IOB2_tagging {
    my ($set, @tags) = @_;

    my $wid = 0;          # word id
    my $phrase = undef;   # current phrase
    my $t;
    foreach $t (@tags) {
        # any tag that is not "I-..." closes the phrase in progress
        if ($phrase and $t !~ /^I/) {
            $phrase->set_end($wid-1);
            $set->add_phrases($phrase);
            $phrase = undef;
        }
        if ($t =~ /^B-/) {
            my $type = $';   # regex postmatch: the phrase type after "B-"
            # honour the global $phrase_types filter when it is defined
            if (!defined($phrase_types) or $phrase_types->{$type}) {
                $phrase = SRL::phrase->new($wid);
                $phrase->set_type($type);
            }
        }
        $wid++;
    }
    # close a phrase still open when the tag sequence ends
    if ($phrase) {
        $phrase->set_end($wid-1);
        $set->add_phrases($phrase);
    }
}


# Adds phrases represented in IOB1 tagging
# Receives a list of IOB1 tags (one per word in the sentence)
# Creates a phrase object for each phrase in the tagging
# and modifies the set so that the phrases are part of it
sub load_IOB1_tagging {
    my ($set, @tags) = @_;

    my $wid = 0;          # word id
    my $phrase = undef;   # current phrase
    my $t = shift @tags;
    while (defined($t)) {
        # in IOB1 a phrase may open with either a B- or an I- tag
        if ($t =~ /^[BI]-/) {
            my $type = $';   # regex postmatch: the phrase type
            if (!defined($phrase_types) or $phrase_types->{$type}) {
                $phrase = SRL::phrase->new($wid);
                $phrase->set_type($type);
                my $tag = "I-".$type;
                $t = shift @tags;
                $wid++;
                # consume consecutive I- tags of the same type
                while ($t eq $tag) {
                    $t = shift @tags;
                    $wid++;
                }
                $phrase->set_end($wid-1);
                $set->add_phrases($phrase);
            }
            else {
                # filtered-out type: just advance
                $t = shift @tags;
                $wid++;
            }
        }
        else {
            $t = shift @tags;
            $wid++;
        }
    }
}

# Adds phrases represented in Start-End tagging
# Receives a list of Start-End tags (one per word in the sentence)
# Creates a phrase object for each phrase in the tagging
# and modifies the set so that the phrases are part of it
sub load_SE_tagging { 1342 | my ($set, @tags) = @_; 1343 | 1344 | my (@SP); # started phrases 1345 | my $wid = 0; 1346 | my ($tag, $p); 1347 | foreach $tag ( @tags ) { 1348 | while ($tag !~ /^\*/) { 1349 | $tag =~ /^\(((\\\*|[^*(])+)/ or die "phrase_set->load_SE_tagging: opening nodes -- bad format in $tag at $wid-th position!\n"; 1350 | my $type = $1; 1351 | $tag = $'; 1352 | if (!defined($phrase_types) or $phrase_types->{$type}) { 1353 | $p = SRL::phrase->new($wid); 1354 | $p->set_type($type); 1355 | push @SP, $p; 1356 | } 1357 | } 1358 | $tag =~ s/^\*//; 1359 | while ($tag ne "") { 1360 | $tag =~ /^([^\)]*)\)/ or die "phrase_set->load_SE_tagging: closing phrases -- bad format in $tag!\n"; 1361 | my $type = $1; 1362 | $tag = $'; 1363 | if (!$type or !defined($phrase_types) or $phrase_types->{$type}) { 1364 | $p = pop @SP; 1365 | (!$type) or ($type eq $p->type) or die "phrase_set->load_SE_tagging: types do not match!\n"; 1366 | $p->set_end($wid); 1367 | 1368 | if (@SP) { 1369 | $SP[$#SP]->add_phrases($p); 1370 | } 1371 | else { 1372 | $set->add_phrases($p); 1373 | } 1374 | } 1375 | } 1376 | $wid++; 1377 | } 1378 | (!@SP) or die "phrase_set->load_SE_tagging: some phrases are unclosed!\n"; 1379 | } 1380 | 1381 | 1382 | sub refs_start_end_tags { 1383 | my ($s, $l) = @_; 1384 | 1385 | my (@S,@E,$i); 1386 | for ($i=0; $i<$l; $i++) { 1387 | $S[$i] = ""; 1388 | $E[$i] = ""; 1389 | } 1390 | 1391 | my $p; 1392 | foreach $p ( $s->phrases ) { 1393 | $S[$p->start] .= "(".$p->type; 1394 | # $E[$p->end] = $E[$p->end].$p->type.")"; 1395 | $E[$p->end] .= ")"; 1396 | } 1397 | 1398 | return (\@S,\@E); 1399 | } 1400 | 1401 | 1402 | sub to_SE_tagging { 1403 | my ($s, $l) = @_; 1404 | 1405 | # my (@S,@E,$i); 1406 | # for ($i=0; $i<$l; $i++) { 1407 | # $S[$i] = ""; 1408 | # $E[$i] = ""; 1409 | # } 1410 | 1411 | # my $p; 1412 | # foreach $p ( $s->phrases ) { 1413 | # $S[$p->start] .= "(".$p->type; 1414 | # # $E[$p->end] = $E[$p->end].$p->type.")"; 1415 | # $E[$p->end] .= ")"; 1416 | # } 
1417 | 1418 | my ($S,$E) = refs_start_end_tags($s,$l); 1419 | 1420 | my $i; 1421 | my @tags; 1422 | for ($i=0; $i<$l; $i++) { 1423 | # $tags[$i] = sprintf("%8s*%-12s", $S->[$i], $E->[$i]); 1424 | $tags[$i] = sprintf("%8s*%-5s", $S->[$i], $E->[$i]); 1425 | } 1426 | return @tags; 1427 | } 1428 | 1429 | 1430 | sub to_IOB2_tagging { 1431 | my ($s, $l) = @_; 1432 | 1433 | my (@tags,$p,$i); 1434 | 1435 | foreach $p ( $s->phrases ) { 1436 | my $tag = $p->type; 1437 | $i = $p->start; 1438 | $tags[$i] and $tags[$i] .= "/"; 1439 | $tags[$i] .= "B-".$tag; 1440 | $i++; 1441 | while ($i<=$p->end) { 1442 | $tags[$i] and $tags[$i] .= "/"; 1443 | $tags[$i] .= "I-".$tag; 1444 | $i++; 1445 | } 1446 | } 1447 | for ($i=0; $i<$l; $i++) { 1448 | if (!defined($tags[$i])) { 1449 | $tags[$i] = "O "; 1450 | } 1451 | else { 1452 | $tags[$i] = sprintf("%-6s", $tags[$i]); 1453 | } 1454 | } 1455 | return @tags; 1456 | } 1457 | 1458 | 1459 | # ------------------------------------------------------------ 1460 | 1461 | # Adds phrases in the set, recursively (ie. 
# internal phrases are also added)
sub add_phrases {
    my ($s, @P) = @_;
    my $ph;
    # dfs expands each phrase into itself plus all its nested subphrases
    foreach $ph ( map { $_->dfs } @P ) {
        $s->[0][$ph->start][$ph->end] = $ph;
        # grow the recorded sentence length to cover the phrase
        if ($ph->end >= $s->[1]) {
            $s->[1] = $ph->end +1;
        }
    }
}

# returns the number of phrases in the set
sub size {
    my $set = shift;

    my ($i,$j);
    # FIX: $n was declared without initialization ("my $n;"), so an
    # empty set returned undef rather than 0.
    my $n = 0;
    for ($i=0; $i<@{$set->[0]}; $i++) {
        if (defined($set->[0][$i])) {
            for ($j=$i; $j<@{$set->[0][$i]}; $j++) {
                if (defined($set->[0][$i][$j])) {
                    $n++;
                }
            }
        }
    }
    return $n;
}

# returns the phrase starting at word position $s and ending at $e
# or undef if it doesn't exist
sub phrase {
    my ($set, $s, $e) = @_;
    return $set->[0][$s][$e];
}


# Returns phrases in the set, recursively in depth first search order
# that is, if a phrase is returned, all its subphrases are also returned
# If no parameters, returns all phrases
# If a pair of positions is given ($s,$e), returns phrases included
# within the $s and $e positions
sub phrases {
    my $set = shift;
    my ($s, $e);
    if (!@_) {
        # default: the whole sentence span
        $s = 0;
        $e = $set->[1]-1;
    }
    else {
        ($s,$e) = @_;
    }
    my ($i,$j);
    my @P = ();
    # starts scan left-to-right; ends scan right-to-left for each start,
    # so enclosing phrases are emitted before their subphrases (DFS order)
    for ($i=$s;$i<=$e;$i++) {
        if (defined($set->[0][$i])) {
            for ($j=$e;$j>=$i;$j--) {
                if (defined($set->[0][$i][$j])) {
                    push @P, $set->[0][$i][$j];
                }
            }
        }
    }
    return @P;
}


# Returns phrases in the set, non-recursively in sequential order
# that is, if a phrase is returned, its subphrases are not returned
# If no parameters, returns all phrases
# If a pair of positions is given ($s,$e), returns phrases included
# within the $s and $e positions
sub
top_phrases { 1535 | my $set = shift; 1536 | my ($s, $e); 1537 | if (!@_) { 1538 | $s = 0; 1539 | $e = $set->[1]-1; 1540 | } 1541 | else { 1542 | ($s,$e) = @_; 1543 | } 1544 | my ($i,$j); 1545 | my @P = (); 1546 | $i = $s; 1547 | while ($i<=$e) { 1548 | $j=$e; 1549 | while ($j>=$s) { 1550 | if (defined($set->[0][$i][$j])) { 1551 | push @P, $set->[0][$i][$j]; 1552 | $i=$j; 1553 | $j=-1; 1554 | } 1555 | else { 1556 | $j--; 1557 | } 1558 | } 1559 | $i++; 1560 | } 1561 | return @P; 1562 | } 1563 | 1564 | 1565 | # returns the phrases which contain the terminal $wid, in bottom-up order 1566 | sub ancestors { 1567 | my ($set, $wid) = @_; 1568 | 1569 | my @A; 1570 | my $N = $set->[1]; 1571 | 1572 | my ($s,$e); 1573 | 1574 | for ($s = $wid; $s>=0; $s--) { 1575 | if (defined($set->[0][$s])) { 1576 | for ($e = $wid; $e<$N; $e++) { 1577 | if (defined($set->[0][$s][$e])) { 1578 | push @A, $set->[0][$s][$e]; 1579 | } 1580 | } 1581 | } 1582 | } 1583 | 1584 | return @A; 1585 | } 1586 | 1587 | 1588 | # returns a TRUE value if the phrase $p ovelaps with some phrase in 1589 | # the set; the returned value is the reference to the conflicting phrase 1590 | # returns FALSE otherwise 1591 | sub check_overlapping { 1592 | my ($set, $p) = @_; 1593 | 1594 | my ($s,$e); 1595 | for ($s=0; $s<$p->start; $s++) { 1596 | if (defined($set->[0][$s])) { 1597 | for ($e=$p->start; $e<$p->end; $e++) { 1598 | if (defined($set->[0][$s][$e])) { 1599 | return $set->[0][$s][$e]; 1600 | } 1601 | } 1602 | } 1603 | } 1604 | for ($s=$p->start+1; $s<=$p->end; $s++) { 1605 | if (defined($set->[0][$s])) { 1606 | for ($e=$p->end+1; $e<$set->[1]; $e++) { 1607 | if (defined($set->[0][$s][$e])) { 1608 | return $set->[0][$s][$e]; 1609 | } 1610 | } 1611 | } 1612 | } 1613 | 1614 | return 0; 1615 | } 1616 | 1617 | 1618 | ## ---------------------------------------- 1619 | 1620 | # Discriminates a set of phrases (s1) wrt the current set (s0), returning 1621 | # intersection (s0^s1), over-predicted (s1-s0) and missed (s0-s1) 
# Returns a hash reference containing three lists:
# $out->{ok} : phrases in $s0 and $1
# $out->{op} : phrases in $s1 and not in $0
# $out->{ms} : phrases in $s0 and not in $1
sub discriminate {
    my ($s0, $s1) = @_;

    my $out;
    @{$out->{ok}} = ();
    @{$out->{ms}} = ();
    @{$out->{op}} = ();

    my $ph;
    my %ok;   # marks the (start,end) spans matched as correct

    foreach $ph ($s1->phrases) {
        my $s = $ph->start;
        my $e = $ph->end;

        # a predicted phrase is correct when the reference set holds a
        # phrase with the same span and the same type
        my $gph = $s0->phrase($s,$e);
        if ($gph and $gph->type eq $ph->type) {
            # correct
            $ok{$s}{$e} = 1;
            push @{$out->{ok}}, $ph;
        }
        else {
            # overpredicted
            push @{$out->{op}}, $ph;
        }
    }

    # reference phrases never matched above are missed
    foreach $ph ($s0->phrases) {
        my $s = $ph->start;
        my $e = $ph->end;

        if (!$ok{$s}{$e}) {
            # missed
            push @{$out->{ms}}, $ph;
        }
    }
    return $out;
}


# compares the current set (s0) to another set (s1)
# returns the number of correct, missed and over-predicted phrases
sub evaluation {
    my ($s0, $s1) = @_;

    my $o = $s0->discriminate($s1);

    my %e;
    $e{ok} = scalar(@{$o->{ok}});
    $e{op} = scalar(@{$o->{op}});
    $e{ms} = scalar(@{$o->{ms}});

    return %e;
}


# generates a string representing the phrase set,
# for printing purposes
sub to_string {
    my $s = shift;
    return join(" ", map { $_->to_string } $s->top_phrases);
}


1;


##################################################################
#
# Package p h r a s e : a generic phrase
#
# January 2004
#
# This class represents generic phrases.
1714 | # A phrase is a sequence of contiguous words in a sentence. 1715 | # A phrase is identified by the positions of the start/end words 1716 | # of the sequence that the phrase spans. 1717 | # A phrase has a type. 1718 | # A phrase may contain a list of internal subphrases, that is, 1719 | # phrases found within the phrase. Thus, a phrase object is seen 1720 | # eventually as a hierarchical structure. 1721 | # 1722 | # A syntactic base chunk is a phrase with no internal phrases. 1723 | # A clause is a phrase which may have internal phrases 1724 | # A proposition argument is implemented as a special class which 1725 | # inherits from the phrase class. 1726 | # 1727 | ################################################################## 1728 | 1729 | use strict; 1730 | 1731 | package SRL::phrase; 1732 | 1733 | # Constructor: creates a new phrase 1734 | # Parameters: start position, end position and type 1735 | sub new { 1736 | my $pkg = shift; 1737 | 1738 | my $ph = []; 1739 | 1740 | # start word index 1741 | $ph->[0] = (@_) ? shift : undef; 1742 | # end word index 1743 | $ph->[1] = (@_) ? shift : undef; 1744 | # phrase type 1745 | $ph->[2] = (@_) ? 
shift : undef; 1746 | # 1747 | @{$ph->[3]} = (); 1748 | 1749 | return bless $ph, $pkg; 1750 | } 1751 | 1752 | # returns the start position of the phrase 1753 | sub start { 1754 | my $ph = shift; 1755 | return $ph->[0]; 1756 | } 1757 | 1758 | # initializes the start position of the phrase 1759 | sub set_start { 1760 | my $ph = shift; 1761 | $ph->[0] = shift; 1762 | } 1763 | 1764 | # returns the end position of the phrase 1765 | sub end { 1766 | my $ph = shift; 1767 | return $ph->[1]; 1768 | } 1769 | 1770 | # initializes the end position of the phrase 1771 | sub set_end { 1772 | my $ph = shift; 1773 | $ph->[1] = shift; 1774 | } 1775 | 1776 | # returns the type of the phrase 1777 | sub type { 1778 | my $ph = shift; 1779 | return $ph->[2]; 1780 | } 1781 | 1782 | # initializes the type of the phrase 1783 | sub set_type { 1784 | my $ph = shift; 1785 | $ph->[2] = shift; 1786 | } 1787 | 1788 | # returns the subphrases of the current phrase 1789 | sub phrases { 1790 | my $ph = shift; 1791 | return @{$ph->[3]}; 1792 | } 1793 | 1794 | # adds phrases as subphrases 1795 | sub add_phrases { 1796 | my $ph = shift; 1797 | push @{$ph->[3]}, @_; 1798 | } 1799 | 1800 | # initializes the set of subphrases 1801 | sub set_phrases { 1802 | my $ph = shift; 1803 | @{$ph->[3]} = @_; 1804 | } 1805 | 1806 | 1807 | # depth first search 1808 | # returns the phrases rooted int the current phrase in dfs order 1809 | sub dfs { 1810 | my $ph = shift; 1811 | return ($ph, map { $_->dfs } $ph->phrases); 1812 | } 1813 | 1814 | 1815 | # generates a string representing the phrase (and subphrases if arg is a TRUE value), for printing 1816 | sub to_string { 1817 | my $ph = shift; 1818 | my $rec = ( @_ ) ? shift : 1; 1819 | 1820 | my $str = "(" . $ph->start . " "; 1821 | 1822 | $rec and map { $str .= $_->to_string." " } $ph->phrases; 1823 | 1824 | $str .= $ph->end . 
")"; 1825 | if (defined($ph->type)) { 1826 | $str .= "_".$ph->type; 1827 | } 1828 | return $str; 1829 | } 1830 | 1831 | 1832 | 1; 1833 | 1834 | ################################################################## 1835 | # 1836 | # Package a r g : An argument 1837 | # 1838 | # January 2004 1839 | # 1840 | # This class inherits from the class "phrase". 1841 | # An argument is identified by start-end positions of the 1842 | # string spanned by the argument in the sentence. 1843 | # An argument has a type. 1844 | # 1845 | # Most of the arguments consist of a single phrase; in this 1846 | # case the argument and the phrase objects are the same. 1847 | # 1848 | # In the special case of discontinuous arguments, the argument 1849 | # is an "arg" object which contains a number of phrases (one 1850 | # for each discontinuous piece). Then, the argument spans from 1851 | # the start word of its first phrase to the end word of its last 1852 | # phrase. As for the composing phrases, the type of the first one 1853 | # is the type of the argument, say A, whereas the type of the 1854 | # subsequent phrases is "C-A" (continuation tag). 1855 | # 1856 | ################################################################## 1857 | 1858 | package SRL::arg; 1859 | 1860 | use strict; 1861 | 1862 | #push @SRL::arg::ISA, 'SRL::phrase'; 1863 | use base qw(SRL::phrase); 1864 | 1865 | 1866 | # Constructor "new" inherited from SRL::phrase 1867 | 1868 | # Checks whether the argument is single (returning true) 1869 | # or discontinuous (returning false) 1870 | sub single { 1871 | my ($a) = @_; 1872 | return scalar(@{$a->[3]}==0); 1873 | } 1874 | 1875 | # Generates a string representing the argument 1876 | sub to_string { 1877 | my $a = shift; 1878 | 1879 | my $s = $a->type."_(" . $a->start . " "; 1880 | map { $s .= $_->to_string." " } $a->phrases; 1881 | $s .= $a->end . 
")";

    return $s;
}


1;


##################################################################
#
#  Package   w o r d  :  a word
#
#  April 2004
#
#  A word, containing id (position in sentence), form and PoS tag
#
##################################################################

use strict;

package SRL::word;

# Constructor: creates a new word
# Parameters: id (position), form and PoS tag
sub new {
    my ($pkg, @fields) = @_;

    my $w = [];

    $w->[0] = shift @fields;   # id (position in sentence)
    $w->[1] = shift @fields;   # form
    $w->[2] = shift @fields;   # PoS

    return bless $w, $pkg;
}

# returns the id of the word
sub id {
    my $w = shift;
    return $w->[0];
}

# returns the form of the word
sub form {
    my $w = shift;
    return $w->[1];
}

# returns the PoS tag of the word
sub pos {
    my $w = shift;
    return $w->[2];
}

# debug representation, e.g. "w@3:dog:NN"
sub to_string {
    my $w = shift;
    return "w@".$w->[0].":".$w->[1].":".$w->[2];
}

1;

--------------------------------------------------------------------------------
/research/iohandler/SRLToDoc.py:
--------------------------------------------------------------------------------

import sys

from research.libnlp import Document, SemanticRole

class SRLReader:
    """Reads an SRL-formatted data file and converts each line into a Document."""

    # reader modes
    TRAIN = 0
    TEST = 1

    def __init__(self, path, logger, mode):
        self.path = path
        self.mode = mode
        self.logger = logger
        self.documents = dict()   # doc_id -> Document

    def read_file(self):
        """Parse the file at self.path; returns {doc_id: Document}."""
        try:
            reader_object = open(self.path)
            for doc_id, line in enumerate(reader_object.readlines()):
text = ' '.join(line.split()[1:]) 20 | doc_text = text.split("|||")[0] 21 | labels = text.split("|||")[1].split() 22 | d = Document.Document(doc_text, doc_id) 23 | s = SemanticRole.SemanticRole(doc_text.split(), labels) 24 | d.linkSemanticRelation(s) 25 | self.documents[doc_id] = d 26 | reader_object.close() 27 | return self.documents 28 | except IOError: 29 | sys.exit("File not found at "+self.path) 30 | 31 | def realign_data(actual_tokens, class_list, bert_tokens): 32 | actual_tokens = [i.lower() for i in actual_tokens] 33 | new_class_list = list() 34 | i = 0 35 | j = 0 36 | while j < len(bert_tokens): 37 | a_token = actual_tokens[i] 38 | b_token = bert_tokens[j] 39 | token_c = class_list[i] 40 | if b_token == a_token: 41 | new_class_list.append(token_c) 42 | i += 1 43 | j += 1 44 | else: 45 | b_token = b_token.replace('#', '') 46 | while b_token in a_token: 47 | new_class_list.append(token_c) 48 | a_token = a_token.replace(b_token, '', 1) 49 | j += 1 50 | if j < len(bert_tokens): 51 | b_token = bert_tokens[j].replace('#', '') 52 | else: 53 | break 54 | i += 1 55 | return new_class_list -------------------------------------------------------------------------------- /research/iohandler/SemEvalToDoc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | from research.libnlp import Document, SemanticRelation 5 | 6 | 7 | class SemEvalReader: 8 | TRAIN = 0 9 | TEST = 1 10 | 11 | def __init__(self, path, logger, mode): 12 | self.path = path 13 | self.mode = mode 14 | self.logger = logger 15 | self.documents = dict() 16 | 17 | 18 | def get_formatted_text(self, text, replace_by_mask = False, predictor = None): 19 | text = text[1:-1] # removing double quotes from the beginning and end of sentence 20 | e1 = re.search("(.*?)", text) 21 | e1_text = e1.group().replace("", "").replace("", "") 22 | text = replace_by_entity_mask(text, e1_text, 1, replace_by_mask = replace_by_mask, predictor= predictor) 23 | e2 = 
re.search("(.*?)", text) 24 | e2_text = e2.group().replace("", "").replace("", "") 25 | text = replace_by_entity_mask(text, e2_text, 2, replace_by_mask = replace_by_mask, predictor= predictor) 26 | return text, e1_text, e2_text, e1.start(), e2.start() 27 | 28 | def read_file(self, predictor = None): 29 | try: 30 | reader_object = open(self.path) 31 | try: 32 | for sample in reader_object.read().split("\n\n"): 33 | lines = sample.split("\n") 34 | doc_id, dtext = lines[0].split("\t") 35 | document_text, token1, token2, token1_offset, token2_offset = self.get_formatted_text(dtext, replace_by_mask=True, predictor=predictor) 36 | d = Document.Document(document_text, int(doc_id)) 37 | relation_name, relation_direction = getSemanticRelation(lines[1], self.logger) 38 | sr = SemanticRelation.SemanticRelation(token1, token2, relation_name, relation_direction) 39 | d.linkSemanticRelation(sr) 40 | self.documents[doc_id] = d 41 | except ValueError: 42 | self.logger.info("Reached end of file") 43 | reader_object.close() 44 | return self.documents 45 | except IOError: 46 | self.logger.error("File not found at path " + self.path) 47 | sys.exit("File not found at path " + self.path) 48 | 49 | 50 | def getSemanticRelation(text, logger): 51 | relation_name = text.split("(")[0] 52 | if "(e1,e2)" in text: 53 | relation_direction = 0 54 | elif "(e2,e1)" in text: 55 | relation_direction = 1 56 | else: 57 | relation_direction = SemanticRelation.SemanticRelation.NO_DIRECTION 58 | logger.info("Relation " + relation_name + " found with directionality " + str(relation_direction)) 59 | return relation_name, relation_direction 60 | 61 | def replace_by_entity_mask(text, ent_text, ent_mention, replace_by_mask = False, predictor = None): 62 | if ent_mention == 1: 63 | e = "" 64 | e_ = "" 65 | pattern = r"(.*?)" 66 | else: 67 | e = r"" 68 | e_ = r"" 69 | pattern = r"(.*?)" 70 | if replace_by_mask is False: 71 | text = text.replace(e, " ").replace(e_, " ") 72 | else: 73 | if predictor is None: 74 | 
text = re.sub(pattern, " [MASK] ", text) 75 | else: 76 | ent_tokens = ent_text.split() 77 | doc = predictor(text) 78 | ent = " [ O ] " 79 | for token in doc.ents: # TODO: Fix based on offset values 80 | if token.text == ent_tokens[0]: 81 | ent = " [ " + token.label_ + " ] " 82 | text = re.sub(pattern, ent, text) 83 | return text 84 | -------------------------------------------------------------------------------- /research/iohandler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/research/iohandler/__init__.py -------------------------------------------------------------------------------- /research/iohandler/part_whole_reader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from research.libnlp import Document, SemanticRelation 4 | 5 | 6 | class PartWholeReader: 7 | TRAIN = 0 8 | TEST = 1 9 | 10 | def __init__(self, path, logger, mode): 11 | self.path = path 12 | self.mode = mode 13 | self.logger = logger 14 | self.documents = dict() 15 | 16 | def read_file(self, predictor=None): 17 | try: 18 | reader_object = open(self.path) 19 | for doc_id, line in enumerate(reader_object.readlines()): 20 | token1 = line.split("\t")[0] 21 | token2 = line.split("\t")[2] 22 | sentence = line.split("\t")[6] 23 | relation_name = line.split("\t")[8] 24 | if predictor is not None: 25 | sentence = replace_by_entity_mask(token1, sentence, predictor) 26 | sentence = replace_by_entity_mask(token2, sentence, predictor) 27 | d = Document.Document(sentence, doc_id) 28 | sr = SemanticRelation.SemanticRelation(token1, token2, relation_name, SemanticRelation.SemanticRelation.NO_DIRECTION) 29 | d.linkSemanticRelation(sr) 30 | self.logger.info("Relation " + relation_name + " found!") 31 | self.documents[doc_id] = d 32 | reader_object.close() 33 | return self.documents 34 | except IOError: 35 | 
self.logger.error("File not found at path " + self.path) 36 | sys.exit("File not found at path " + self.path) 37 | 38 | def replace_by_entity_mask(entity, sentence, predictor): 39 | if predictor is None: 40 | sentence = sentence.replace(entity, " [MASK] ") 41 | else: 42 | ent_tokens = sentence.split() 43 | doc = predictor(sentence) 44 | ent = " [ O ] " 45 | for token in doc.ents: # TODO: Fix based on offset values 46 | if token.text == ent_tokens[0]: 47 | ent = " [ " + token.label_ + " ] " 48 | sentence = sentence.replace(entity, ent) 49 | return sentence 50 | -------------------------------------------------------------------------------- /research/iohandler/sampler.py: -------------------------------------------------------------------------------- 1 | import random -------------------------------------------------------------------------------- /research/libnlp/Document.py: -------------------------------------------------------------------------------- 1 | class Document: 2 | def __init__(self, text, doc_id): 3 | self.tokens = list() 4 | self.doc_id = doc_id 5 | self.text = text 6 | self.sr = None 7 | self.input_features = None 8 | 9 | def linkTextToDoc(self, text): 10 | self.text = text 11 | 12 | def linkSemanticRelation(self, sr): 13 | self.sr = sr 14 | 15 | def linkTokenIDs(self, input_features): 16 | self.input_features = input_features 17 | -------------------------------------------------------------------------------- /research/libnlp/SemanticRelation.py: -------------------------------------------------------------------------------- 1 | class SemanticRelation: 2 | 3 | NO_DIRECTION = 2 # indicates no directionality 4 | 5 | def __init__(self, token1, token2, sr, dir): 6 | self.token1 = token1 7 | self.token2 = token2 8 | self.sr = sr 9 | self.dir = dir 10 | -------------------------------------------------------------------------------- /research/libnlp/SemanticRole.py: -------------------------------------------------------------------------------- 1 | 
class SemanticRole: 2 | def __init__(self, tokens, labels): 3 | self.tokens = tokens 4 | self.labels = labels 5 | 6 | def get_verb(self): 7 | for i, label in enumerate(self.labels): 8 | if 'B-V' in label: 9 | verb_start = i 10 | verb_end = verb_start 11 | if 'I-V' in label: 12 | verb_end = i 13 | return ' '.join(self.tokens[verb_start:verb_end + 1]) -------------------------------------------------------------------------------- /research/libnlp/Token.py: -------------------------------------------------------------------------------- 1 | class Token: 2 | def __init__(self, text, start, named_entity): 3 | self.text = text 4 | self.start = start 5 | self.ne = named_entity 6 | -------------------------------------------------------------------------------- /research/libnlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/research/libnlp/__init__.py -------------------------------------------------------------------------------- /research/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import sys 4 | 5 | import torch 6 | import spacy 7 | 8 | from research.document_processor.Encoder import Encoder 9 | from research.iohandler import SemEvalToDoc, SRLToDoc 10 | from research.tester.test_model import test_semantic_relation_model, test_semantic_role_model 11 | from research.trainer import semantic_relation_train_model, semantic_role_train_model 12 | 13 | argument_parser = argparse.ArgumentParser() 14 | argument_parser.add_argument("-tr", "--path_to_training_file", help="Provide path to training file") 15 | argument_parser.add_argument("-te", "--path_to_test_file", help="Provide path to test file") 16 | argument_parser.add_argument("-log", "--path_to_log_file", help="Provide path to log file") 17 | argument_parser.add_argument("-model", 
"--model_type", help="Provide model type") 18 | argument_parser.add_argument("-task", "--task_type", help="Provide task type") 19 | argument_parser.add_argument("-max_len", "--max_len", help="Provide maximum sequence length") 20 | argument_parser.add_argument("-b", "--batch_size", help="Provide batch size to work with") 21 | argument_parser.add_argument("-lr", "--learning_rate", help="Provide learning rate to work with") 22 | argument_parser.add_argument("-epochs", "--epochs", help="Number of training epochs") 23 | argument_parser.add_argument("-model_path", "--model_path", help="Path where you want to save the model") 24 | argument_parser.add_argument("-true_file", "--true_file", help="Path where you want to save the true relations") 25 | argument_parser.add_argument("-prediction_file", "--prediction_file", 26 | help="Path where you want to save the predictions") 27 | argument_parser.add_argument("-perl", "--perl_eval_script", help="Path to the Perl evaluation script") 28 | 29 | parse = argument_parser.parse_args() 30 | 31 | # set up GPU 32 | torch.manual_seed(0) 33 | torch.cuda.set_device(0) 34 | device = torch.device('cuda:0') # TODO: ensure default device is first item in list 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | # set up logging session 39 | logging.basicConfig(filename=parse.path_to_log_file, format='%(asctime)s %(message)s', filemode='w') 40 | logger = logging.getLogger() 41 | logger.setLevel(logging.DEBUG) 42 | 43 | # declare objects: 44 | 45 | if parse.task_type == "0": 46 | s = Encoder(int(parse.model_type), int(parse.task_type), logger, int(parse.max_len)) # TODO: Change to arg parse parameter 47 | train_reader = SemEvalToDoc.SemEvalReader(parse.path_to_training_file, logger, SemEvalToDoc.SemEvalReader.TRAIN) 48 | test_reader = SemEvalToDoc.SemEvalReader(parse.path_to_test_file, logger, SemEvalToDoc.SemEvalReader.TEST) 49 | 50 | # go through pipeline: 51 | 52 | training_docs = 
train_reader.read_file(predictor=al) 53 | training_docs = s.encode_text(training_docs) 54 | test_docs = test_reader.read_file(predictor=al) 55 | test_docs = s.encode_text(test_docs) 56 | 57 | model, sr_dict = semantic_relation_train_model.train_model(training_docs, int(parse.batch_size), s.model, device, float(parse.learning_rate), 58 | int(parse.epochs), logger) 59 | 60 | test_semantic_relation_model(test_docs, sr_dict, int(parse.batch_size), model, device, parse.true_file, parse.prediction_file, logger) 61 | 62 | if parse.task_type == '1': 63 | 64 | s = Encoder(int(parse.model_type), int(parse.task_type), logger, int(parse.max_len)) 65 | train_reader = SRLToDoc.SRLReader(parse.path_to_training_file, logger, SRLToDoc.SRLReader.TRAIN) 66 | test_reader = SRLToDoc.SRLReader(parse.path_to_test_file, logger, SRLToDoc.SRLReader.TEST) 67 | 68 | training_docs = train_reader.read_file() 69 | training_docs = s.encode_text(training_docs) 70 | test_docs = test_reader.read_file() 71 | test_docs = s.encode_text(test_docs) 72 | 73 | model, sr_dict = semantic_role_train_model.train_model(training_docs, int(parse.batch_size), s.model, device, float(parse.learning_rate), int(parse.epochs), logger) 74 | test_semantic_role_model(test_docs, sr_dict, int(parse.batch_size), model, device, None, None, logger) 75 | 76 | else: 77 | 78 | sys.exit("Incorrect/Unsupported task type") -------------------------------------------------------------------------------- /research/models/BERTModels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from research.libnlp.SemanticRelation import SemanticRelation 5 | 6 | 7 | def custom_loss_fct(loss_fct, predicted, labels, odd_one=SemanticRelation.NO_DIRECTION): 8 | true, pred = list(), list() 9 | for i, label in enumerate(labels): 10 | if label != odd_one: 11 | true.append(label) 12 | pred.append(predicted[i]) 13 | try: 14 | return loss_fct(torch.stack(pred), 
torch.stack(true))
    except RuntimeError:  # if either list is empty
        return 0


class BERTForRelationClassification(nn.Module):
    """BERT encoder with two linear heads: relation label and relation direction.

    forward() returns the summed loss when labels are given, otherwise the
    (relation, direction) logits.
    """

    def __init__(self, lm, num_relations, logger, num_directions=2):
        super(BERTForRelationClassification, self).__init__()
        self.language_model = lm
        # 770 = presumably 768 BERT hidden units + the 2 position features
        # concatenated in forward() — TODO confirm against the encoder config
        self.relation_classifier = nn.Linear(770, num_relations)  # TODO: get from config, for small or large
        self.direction_classifier = nn.Linear(770, num_directions)
        self.num_relations = num_relations
        self.num_directions = num_directions
        self.logger = logger

    def forward(self, input_ids, input_segments, input_masks, position_vector1 = None, position_vector2 = None, relation_labels=None, direction_labels=None):
        # [0] of the LM output: the per-token hidden states
        pooled_output = self.language_model(input_ids, token_type_ids = input_segments, attention_mask = input_masks)[0]
        if position_vector1 is not None:
            # append the two entity-position features to every token vector
            pooled_output = torch.cat((pooled_output, position_vector1.float().unsqueeze(-1), position_vector2.float().unsqueeze(-1)), -1)
        # NOTE(review): this takes the *last* sequence position as the sentence
        # representation (not the [CLS] token at index 0) — confirm intended.
        pooled_output = pooled_output[:,-1]
        predicted_relations = self.relation_classifier(pooled_output)
        predicted_directions = self.direction_classifier(pooled_output)
        if relation_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            relation_loss = loss_fct(predicted_relations.view(-1, self.num_relations), relation_labels.view(-1))
            # direction loss skips NO_DIRECTION samples (see custom_loss_fct)
            direction_loss = custom_loss_fct(loss_fct, predicted_directions, direction_labels)
            return relation_loss + direction_loss
        else:
            return predicted_relations, predicted_directions

class BERTForRoleLabeling(nn.Module):
    """BERT encoder with a per-token linear head for semantic role labels."""

    def __init__(self, lm, num_relations, logger):
        super(BERTForRoleLabeling, self).__init__()
        # 769 = presumably 768 hidden units + 1 position feature — TODO confirm
        self.language_model = lm
        self.relation_classifier = nn.Linear(769, num_relations)  # TODO: get from config, for small or large
        self.num_relations = num_relations
        self.logger = logger

    def forward(self, input_ids, input_segments, input_masks, position_vector = None, relation_labels=None):
        pooled_output = self.language_model(input_ids, token_type_ids = input_segments, attention_mask = input_masks)[0]
        if position_vector is not None:
            # append the predicate-position feature to every token vector
            pooled_output = torch.cat((pooled_output, position_vector.float().unsqueeze(-1)), -1)
        # per-token logits (no pooling: one prediction per sequence position)
        predicted_relations = self.relation_classifier(pooled_output)
        if relation_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            relation_loss = loss_fct(predicted_relations.view(-1, self.num_relations), relation_labels.view(-1))
            return relation_loss
        else:
            return predicted_relations

class BERTForMultiTaskRelationClassification(nn.Module):
    """Relation classifier with an auxiliary part-whole head, selected by `flag`."""

    def __init__(self, lm, num_orig_relations, num_pw_relations, logger, num_directions=2):
        super(BERTForMultiTaskRelationClassification, self).__init__()
        self.language_model = lm
        self.relation_classifier = nn.Linear(770, num_orig_relations)  # TODO: get from config, for small or large
        self.pw_classifier = nn.Linear(770, num_pw_relations)
        self.direction_classifier = nn.Linear(770, num_directions)
        self.num_orig_relations = num_orig_relations
        self.num_pw_relations = num_pw_relations
        self.num_directions = num_directions
        self.logger = logger

    def forward(self, input_ids, input_segments, input_masks, position_vector1 = None, position_vector2 = None, relation_labels=None, direction_labels=None, flag = False):
        pooled_output = self.language_model(input_ids, token_type_ids = input_segments, attention_mask = input_masks)[0]
        if position_vector1 is not None:
            pooled_output = torch.cat((pooled_output, position_vector1.float().unsqueeze(-1), position_vector2.float().unsqueeze(-1)), -1)
        pooled_output = pooled_output[:,-1]
        if flag:
            # auxiliary part-whole task: no direction head
            predicted_relations = self.pw_classifier(pooled_output)
            predicted_directions = None
        else:
            predicted_relations = self.relation_classifier(pooled_output)
            predicted_directions = self.direction_classifier(pooled_output)
        if relation_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            if flag:
                relation_loss = loss_fct(predicted_relations.view(-1, self.num_pw_relations), relation_labels.view(-1))
                return relation_loss
            else:
                relation_loss = loss_fct(predicted_relations.view(-1, self.num_orig_relations), relation_labels.view(-1))
                direction_loss = custom_loss_fct(loss_fct, predicted_directions, direction_labels)
                return relation_loss + direction_loss
        else:
            return predicted_relations, predicted_directions
--------------------------------------------------------------------------------
/research/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/research/models/__init__.py
--------------------------------------------------------------------------------
/research/tester/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/research/tester/__init__.py
--------------------------------------------------------------------------------
/research/tester/test_model.py:
--------------------------------------------------------------------------------

import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

from research.evaluation.semeval2010_writer import file_writer
from research.libnlp.Document import Document


def test_semantic_relation_model(document, sr_dict, batch_size, model, device, true_file, prediction_file, logger):
    """Build the test DataLoader from encoded documents and run evaluation."""
    test_inputs = list()
    test_segments = list()
    test_masks = list()
    test_relations = list()
    test_directions = list()
    test_doc_ids = list()
    position_vect1 = list()
    position_vect2 = list()
    if 
isinstance(document, dict):
        for d in document.values():
            # input_features = [token ids, segment ids, attention mask], plus
            # the two entity-position vectors
            [ip, segment, mask], v1, v2 = d.input_features
            test_inputs.append(ip)
            test_segments.append(segment)
            test_masks.append(mask)
            test_relations.append(sr_dict[d.sr.sr])
            test_directions.append(d.sr.dir)
            test_doc_ids.append(d.doc_id)
            position_vect1.append(v1)
            position_vect2.append(v2)
    elif isinstance(document, Document):
        # single-document case
        test_inputs, test_segments, test_masks = document.input_features
        test_relations.append(sr_dict[document.sr.sr])
        test_directions.append(document.sr.dir)
        test_doc_ids.append(document.doc_id)
    test_inputs = torch.tensor(test_inputs)
    test_segments = torch.tensor(test_segments)
    test_masks = torch.tensor(test_masks)
    test_relations = torch.tensor(test_relations)
    test_directions = torch.tensor(test_directions)
    test_doc_ids = torch.tensor(test_doc_ids)
    position_vect1 = torch.tensor(position_vect1)
    position_vect2 = torch.tensor(position_vect2)

    test_data = TensorDataset(test_inputs, test_segments, test_masks, test_relations, test_directions, test_doc_ids, position_vect1, position_vect2)
    test_sampler = RandomSampler(test_data)
    test_data_loader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

    return get_relation_predictions(test_data_loader, model, device, sr_dict, true_file, prediction_file, logger)


def get_relation_predictions(test_data_loader, model, device, sr_dict, true_file, prediction_file, logger, num_directions=2):
    """Run the relation model over the test loader; write gold and predicted
    relations to true_file / prediction_file for the external Perl scorer."""
    eval_loss = 0.0
    nb_eval_steps = 0
    inverse_sr_dict = {v: k for k, v in sr_dict.items()}  # class index -> name
    file1 = open(true_file, "w+")
    file2 = open(prediction_file, "w+")
    model.eval()
    for batch in test_data_loader:
        batch = tuple(t.to(device) for t in batch)
        bt_features, bt_segments, bt_masks, bt_relations, bt_directions, bt_ids, bt_pos1, bt_pos2 = batch
        with torch.no_grad():
            # first call (labels given) returns the loss; second returns logits
            loss = model(bt_features, bt_segments, bt_masks, position_vector1 = bt_pos1, position_vector2 = bt_pos2, relation_labels=bt_relations,
                         direction_labels=bt_directions)
            predicted_relations, predicted_directions = model(bt_features, bt_segments, bt_masks, position_vector1 = bt_pos1, position_vector2 = bt_pos2)
            predicted_relations = predicted_relations.view(-1, len(sr_dict.keys()))
            predicted_directions = predicted_directions.view(-1, num_directions)
        eval_loss += loss.mean().item()
        nb_eval_steps += 1
        file1 = file_writer(bt_relations, bt_directions, bt_ids, file1, inverse_sr_dict)
        file2 = file_writer(predicted_relations, predicted_directions, bt_ids, file2, inverse_sr_dict,
                            class_type="predicted")
    logger.info("Total test loss: {}".format(eval_loss))
    logger.info("Test loss: {}".format(eval_loss / nb_eval_steps))

    file1.close()
    file2.close()

def get_role_predictions(document, test_data_loader, model, device, sr_dict, true_file, prediction_file, logger):
    """Run the role-labeling model over the test loader and write results.

    NOTE(review): test_semantic_role_model passes true_file/prediction_file as
    None, which would make open() below raise — confirm callers supply paths.
    NOTE(review): file_writer is called here with `document` prepended, unlike
    the relation variant above — confirm against semeval2010_writer.file_writer.
    """
    eval_loss = 0.0
    nb_eval_steps = 0
    inverse_sr_dict = {v: k for k, v in sr_dict.items()}  # class index -> name
    file1 = open(true_file, "w+")
    file2 = open(prediction_file, "w+")
    model.eval()
    for batch in test_data_loader:
        batch = tuple(t.to(device) for t in batch)
        bt_features, bt_segments, bt_masks, bt_relations, bt_ids, bt_pos = batch
        with torch.no_grad():
            loss = model(bt_features, bt_segments, bt_masks, position_vector = bt_pos, relation_labels=bt_relations)
            predicted_relations = model(bt_features, bt_segments, bt_masks, position_vector = bt_pos)

        eval_loss += loss.mean().item()
        nb_eval_steps += 1
        file1 = file_writer(document, bt_relations, bt_ids, file1, inverse_sr_dict)
        file2 = file_writer(document, predicted_relations, bt_ids, file2, inverse_sr_dict, class_type="predicted")
    logger.info("Total test loss: {}".format(eval_loss))
    logger.info("Test loss: {}".format(eval_loss / nb_eval_steps))

    file1.close()
    file2.close()


def test_semantic_role_model(document, sr_dict, batch_size, model, device, true_file, prediction_file, logger):
    """Build the role-labeling test DataLoader and run evaluation."""
    test_inputs = list()
    test_segments = list()
    test_masks = list()
    test_relations = list()
    test_doc_ids = list()
    position_vect = list()
    if isinstance(document, dict):
        for d in document.values():
            tags = list()
            [ip, segment, mask], pos, labels = d.input_features
            test_inputs.append(ip)
            test_segments.append(segment)
            test_masks.append(mask)
            for label in labels:
                # labels unseen at training time fall back to the 'O' class
                if label in sr_dict:
                    tags.append(sr_dict[label])
                else:
                    tags.append(sr_dict["O"])  # TODO: Need a better fix for this!!
            test_relations.append(tags)
            test_doc_ids.append(d.doc_id)
            position_vect.append(pos)
    test_inputs = torch.tensor(test_inputs)
    test_segments = torch.tensor(test_segments)
    test_masks = torch.tensor(test_masks)
    test_relations = torch.tensor(test_relations)
    test_doc_ids = torch.tensor(test_doc_ids)
    position_vect = torch.tensor(position_vect)

    test_data = TensorDataset(test_inputs, test_segments, test_masks, test_relations, test_doc_ids, position_vect)
    test_sampler = RandomSampler(test_data)
    test_data_loader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

    return get_role_predictions(document, test_data_loader, model, device, sr_dict, true_file, prediction_file, logger)
--------------------------------------------------------------------------------
/research/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takshakpdesai/bert_srl_src/499d0c2db4cca807b296af579e592596f2a9a199/research/trainer/__init__.py
--------------------------------------------------------------------------------
/research/trainer/semantic_relation_train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.utils.data import TensorDataset, RandomSampler, DataLoader 5 | from tqdm import trange 6 | 7 | from research.libnlp.Document import Document 8 | from research.models.BERTModels import BERTForRelationClassification 9 | 10 | 11 | def create_class_map(document): 12 | sr_dict = dict() 13 | if isinstance(document, dict): 14 | for d in document.values(): 15 | sr = d.sr.sr 16 | if sr not in sr_dict.keys(): 17 | sr_dict[sr] = len(sr_dict) 18 | if isinstance(document, Document): 19 | sr = document.sr.sr 20 | sr_dict[sr] = len(sr_dict) 21 | return sr_dict 22 | 23 | 24 | def train_model(document, batch_size, lm_model, device, lr, epochs, logger): 25 | train_inputs = list() 26 | train_segments = list() 27 | train_masks = list() 28 | train_relations = list() 29 | train_directions = list() 30 | position_vect1 = list() 31 | position_vect2 = list() 32 | sr_dict = create_class_map(document) 33 | if isinstance(document, dict): 34 | for d in document.values(): 35 | [ip, segment, mask], v1, v2 = d.input_features 36 | train_inputs.append(ip) 37 | train_segments.append(segment) 38 | train_masks.append(mask) 39 | position_vect1.append(v1) 40 | position_vect2.append(v2) 41 | train_relations.append(sr_dict[d.sr.sr]) 42 | train_directions.append(d.sr.dir) 43 | elif isinstance(document, Document): 44 | train_inputs, train_segments, train_masks = document.input_features 45 | train_relations.append(sr_dict[document.sr.sr]) 46 | train_directions.append(document.sr.dir) 47 | train_inputs = torch.tensor(train_inputs) 48 | train_segments = torch.tensor(train_segments) 49 | train_masks = torch.tensor(train_masks) 50 | train_relations = torch.tensor(train_relations) 51 | train_directions = torch.tensor(train_directions) 52 | position_vect1 = torch.tensor(position_vect1) 53 | 
position_vect2 = torch.tensor(position_vect2) 54 | 55 | print("Training set size increased to "+ str(len(train_inputs))) 56 | 57 | train_data = TensorDataset(train_inputs, train_segments, train_masks, train_relations, train_directions, position_vect1, position_vect2) 58 | train_sampler = RandomSampler(train_data) 59 | train_data_loader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 60 | 61 | our_model = BERTForRelationClassification(lm_model, len(sr_dict), logger) 62 | our_model = nn.DataParallel(our_model, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]) # TODO: make param 63 | our_model.to(device) 64 | logger.info(str(our_model)) 65 | 66 | param_optimizer = list(our_model.named_parameters()) 67 | no_decay = ['bias', 'gamma', 'beta'] 68 | optimizer_grouped_parameters = \ 69 | [{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 70 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}] 71 | 72 | optimizer = Adam(optimizer_grouped_parameters, lr=lr) 73 | tr_t, tr_l, vl_l, vl_a, f1_s = [], [], [], [], [] 74 | 75 | for _ in trange(epochs, desc="Epoch"): 76 | our_model.train() 77 | tr_loss = 0.0 78 | nb_tr_examples, nb_tr_steps = 0, 0 79 | for step, batch in enumerate(train_data_loader): 80 | batch = tuple(item.to(device) for item in batch) 81 | bt_features, bt_segments, bt_masks, bt_relations, bt_dirs, bt_pos1, bt_pos2 = batch 82 | loss = our_model(bt_features, bt_segments, bt_masks, position_vector1 = bt_pos1, position_vector2 = bt_pos2, relation_labels=bt_relations, direction_labels=bt_dirs) 83 | loss.sum().backward() 84 | tr_loss += loss.sum().item() 85 | nb_tr_examples += bt_features.size(0) 86 | nb_tr_steps += 1 87 | optimizer.step() 88 | our_model.zero_grad() 89 | tr_t.append(tr_loss) 90 | tr_l.append(tr_loss / nb_tr_steps) 91 | logger.info("Total training loss: {}".format(tr_loss)) 92 | logger.info("Train loss: {}".format(tr_loss / 
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from tqdm import trange

from research.libnlp.Document import Document
from research.models.BERTModels import BERTForRoleLabeling


def create_class_map(document):
    """Build a semantic-role label -> contiguous class-index map.

    Accepts either a dict of Documents (a corpus) or a single Document.
    Indices are assigned in first-seen order.
    """
    sr_dict = dict()
    if isinstance(document, dict):
        for d in document.values():
            for label in d.sr.labels:
                # setdefault assigns the next free index only on first sight
                sr_dict.setdefault(label, len(sr_dict))
    elif isinstance(document, Document):
        for label in document.sr.labels:
            sr_dict.setdefault(label, len(sr_dict))
    return sr_dict


def train_model(document, batch_size, lm_model, device, lr, epochs, logger,
                device_ids=tuple(range(8))):
    """Train a BERTForRoleLabeling model on token-level role labels.

    Args:
        document: dict of Documents (corpus) or a single Document.  Each
            Document must expose ``input_features`` as
            ``([input_ids, segment_ids, attention_mask], position_vector, padded_labels)``.
        batch_size: mini-batch size for the DataLoader.
        lm_model: pre-trained language model wrapped by BERTForRoleLabeling.
        device: torch device the model and batches are moved to.
        lr: Adam learning rate.
        epochs: number of passes over the training data.
        logger: logger used for model / loss messages.
        device_ids: GPU ids handed to nn.DataParallel.  Defaults to the
            previously hard-coded 0-7 (resolves the old ``TODO: make param``);
            pass a shorter sequence on smaller machines.

    Returns:
        Tuple of (trained model, label -> class-index map).
    """
    train_inputs = list()
    train_segments = list()
    train_masks = list()
    train_relations = list()
    position_vect1 = list()
    sr_dict = create_class_map(document)
    if isinstance(document, dict):
        for d in document.values():
            [ip, segment, mask], v, padded_labels = d.input_features
            train_inputs.append(ip)
            train_segments.append(segment)
            train_masks.append(mask)
            position_vect1.append(v)
            # map string labels to class indices for the loss
            train_relations.append([sr_dict[label] for label in padded_labels])
    elif isinstance(document, Document):
        [train_inputs, train_segments, train_masks], v, _ = document.input_features
        position_vect1.append(v)
        # NOTE(review): this branch reads document.sr.labels (unpadded) while
        # the dict branch uses padded_labels — confirm the lengths agree.
        train_relations.append([sr_dict[label] for label in document.sr.labels])
    train_inputs = torch.tensor(train_inputs)
    train_segments = torch.tensor(train_segments)
    train_masks = torch.tensor(train_masks)
    train_relations = torch.tensor(train_relations)
    position_vect1 = torch.tensor(position_vect1)

    train_data = TensorDataset(train_inputs, train_segments, train_masks,
                               train_relations, position_vect1)
    train_sampler = RandomSampler(train_data)
    train_data_loader = DataLoader(train_data, sampler=train_sampler,
                                   batch_size=batch_size)

    our_model = BERTForRoleLabeling(lm_model, len(sr_dict), logger)
    our_model = nn.DataParallel(our_model, device_ids=list(device_ids))
    our_model.to(device)
    logger.info(str(our_model))

    # Standard BERT recipe: no weight decay on bias / LayerNorm parameters.
    # BUG FIX: torch.optim.Adam reads the per-group key 'weight_decay', not
    # 'weight_decay_rate'; the old key was silently ignored, so the decay
    # split never took effect.
    param_optimizer = list(our_model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    optimizer = Adam(optimizer_grouped_parameters, lr=lr)
    tr_t, tr_l = [], []  # per-epoch total / mean loss history (not returned)

    for _ in trange(epochs, desc="Epoch"):
        our_model.train()
        tr_loss = 0.0
        nb_tr_steps = 0
        for step, batch in enumerate(train_data_loader):
            batch = tuple(item.to(device) for item in batch)
            bt_features, bt_segments, bt_masks, bt_relations, bt_pos = batch
            loss = our_model(bt_features, bt_segments, bt_masks,
                             position_vector=bt_pos,
                             relation_labels=bt_relations)
            # DataParallel returns one loss per replica -> reduce once,
            # reuse for both backward() and logging
            batch_loss = loss.sum()
            batch_loss.backward()
            tr_loss += batch_loss.item()
            nb_tr_steps += 1
            optimizer.step()
            our_model.zero_grad()
        tr_t.append(tr_loss)
        tr_l.append(tr_loss / max(nb_tr_steps, 1))  # guard: empty loader
        logger.info("Total training loss: {}".format(tr_loss))
        logger.info("Train loss: {}".format(tr_loss / max(nb_tr_steps, 1)))

    return our_model, sr_dict
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from tqdm import trange

from research.libnlp.Document import Document
from research.models.BERTModels import BERTForRelationClassification


def create_class_map(document):
    """Map each semantic-relation type to a contiguous class index.

    Accepts either a dict of Documents (a corpus) or a single Document.
    Indices are assigned in first-seen order.
    """
    sr_dict = dict()
    if isinstance(document, dict):
        for d in document.values():
            # setdefault assigns the next free index only on first sight
            sr_dict.setdefault(d.sr.sr, len(sr_dict))
    elif isinstance(document, Document):
        # consistency fix: same duplicate guard as the dict branch
        # (identical result here since sr_dict is freshly created)
        sr_dict.setdefault(document.sr.sr, len(sr_dict))
    return sr_dict


def train_model(document, batch_size, lm_model, device, lr, epochs, logger,
                device_ids=tuple(range(8))):
    """Train a BERTForRelationClassification model.

    Args:
        document: dict of Documents (corpus) or a single Document.  Each
            Document must expose ``input_features`` as
            ``([input_ids, segment_ids, attention_mask], position_vector1, position_vector2)``.
        batch_size: mini-batch size for the DataLoader.
        lm_model: pre-trained language model wrapped by
            BERTForRelationClassification.
        device: torch device the model and batches are moved to.
        lr: Adam learning rate.
        epochs: number of passes over the training data.
        logger: logger used for model / loss messages.
        device_ids: GPU ids handed to nn.DataParallel.  Defaults to the
            previously hard-coded 0-7 (resolves the old ``TODO: make param``);
            pass a shorter sequence on smaller machines.

    Returns:
        Tuple of (trained model, relation -> class-index map).
    """
    train_inputs = list()
    train_segments = list()
    train_masks = list()
    train_relations = list()
    train_directions = list()
    position_vect1 = list()
    position_vect2 = list()
    sr_dict = create_class_map(document)
    if isinstance(document, dict):
        for d in document.values():
            [ip, segment, mask], v1, v2 = d.input_features
            train_inputs.append(ip)
            train_segments.append(segment)
            train_masks.append(mask)
            position_vect1.append(v1)
            position_vect2.append(v2)
            train_relations.append(sr_dict[d.sr.sr])
            train_directions.append(d.sr.dir)
    elif isinstance(document, Document):
        # BUG FIX: input_features is ([ids, segments, mask], v1, v2) — the
        # old code unpacked the outer 3-tuple straight into the three input
        # lists and never filled the position-vector lists, so TensorDataset
        # received empty (mismatched) position tensors for a single Document.
        [train_inputs, train_segments, train_masks], v1, v2 = document.input_features
        position_vect1.append(v1)
        position_vect2.append(v2)
        train_relations.append(sr_dict[document.sr.sr])
        train_directions.append(document.sr.dir)
    train_inputs = torch.tensor(train_inputs)
    train_segments = torch.tensor(train_segments)
    train_masks = torch.tensor(train_masks)
    train_relations = torch.tensor(train_relations)
    train_directions = torch.tensor(train_directions)
    position_vect1 = torch.tensor(position_vect1)
    position_vect2 = torch.tensor(position_vect2)

    train_data = TensorDataset(train_inputs, train_segments, train_masks,
                               train_relations, train_directions,
                               position_vect1, position_vect2)
    train_sampler = RandomSampler(train_data)
    train_data_loader = DataLoader(train_data, sampler=train_sampler,
                                   batch_size=batch_size)

    our_model = BERTForRelationClassification(lm_model, len(sr_dict), logger)
    our_model = nn.DataParallel(our_model, device_ids=list(device_ids))
    our_model.to(device)
    logger.info(str(our_model))

    # Standard BERT recipe: no weight decay on bias / LayerNorm parameters.
    # BUG FIX: torch.optim.Adam reads the per-group key 'weight_decay', not
    # 'weight_decay_rate'; the old key was silently ignored, so the decay
    # split never took effect.
    param_optimizer = list(our_model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    optimizer = Adam(optimizer_grouped_parameters, lr=lr)
    tr_t, tr_l = [], []  # per-epoch total / mean loss history (not returned)

    for _ in trange(epochs, desc="Epoch"):
        our_model.train()
        tr_loss = 0.0
        nb_tr_steps = 0
        for step, batch in enumerate(train_data_loader):
            batch = tuple(item.to(device) for item in batch)
            (bt_features, bt_segments, bt_masks,
             bt_relations, bt_dirs, bt_pos1, bt_pos2) = batch
            loss = our_model(bt_features, bt_segments, bt_masks,
                             position_vector1=bt_pos1,
                             position_vector2=bt_pos2,
                             relation_labels=bt_relations,
                             direction_labels=bt_dirs)
            # DataParallel returns one loss per replica -> reduce once,
            # reuse for both backward() and logging
            batch_loss = loss.sum()
            batch_loss.backward()
            tr_loss += batch_loss.item()
            nb_tr_steps += 1
            optimizer.step()
            our_model.zero_grad()
        tr_t.append(tr_loss)
        tr_l.append(tr_loss / max(nb_tr_steps, 1))  # guard: empty loader
        logger.info("Total training loss: {}".format(tr_loss))
        logger.info("Train loss: {}".format(tr_loss / max(nb_tr_steps, 1)))

    return our_model, sr_dict