├── engine ├── __init__.py ├── srl.py ├── utils.py └── modules.py ├── README.md ├── run_inference.py ├── evaluation.py ├── LICENSE ├── run_srl.py └── data_house └── scripts └── srl-eval.pl /engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Heterogeneous Syntax Fuser (HeSyFu) for SRL 2 | 3 | 4 | Code for the ACL2021 (Finding) work [**Better Combine Them Together! Integrating Syntactic Constituency and Dependency Representations for Semantic Role Labeling**](https://aclanthology.org/2021.findings-acl.49/) 5 | 6 | --------- 7 | 8 | 🎉 Visit the project page here [HeSyFu](https://haofei.vip/HeSyFu-SRL/) 9 | 10 | --------- 11 | 12 | 13 | ## Data 14 | 15 | #### Donwload the SRL dataset: 16 | * [CoNLL05](https://www.cs.upc.edu/˜srlconll/soft.html) 17 | * [CoNLL09](https://catalog.ldc.upenn.edu/LDC2012T03) 18 | * [CoNLL12](https://catalog.ldc.upenn.edu/LDC2013T19) 19 | 20 | 21 | #### Ensemble the heterogeneous syntax annotations by linking the CoNLL sentences to the UPB. 22 | 23 | #### Format the data as CoNLLU-like, make sure the dependency syntax (head & dependent label) and the constituency syntax are will compatible to the conllu style. 24 | 25 | 26 | #### Word embedding: 27 | * [GloVe](https://github.com/stanfordnlp/GloVe) 28 | * [RoBERTa (base)](https://github.com/pytorch/fairseq/tree/master/examples/RoBERTa) 29 | 30 | 31 | 32 | ## Training & Evaluating 33 | 34 | 35 | #### Run _run_srl.py_ 36 | 37 | #### Run _run_inference.py_ 38 | 39 | 40 | *** 41 | 42 | ``` 43 | @inproceedings{fei-etal-2021-better, 44 | title = "Better Combine Them Together! Integrating Syntactic Constituency and Dependency Representations for Semantic Role Labeling", 45 | author = "Fei, Hao and 46 | Wu, Shengqiong and 47 | Ren, Yafeng and 48 | Li, Fei and 49 | Ji, Donghong", 50 | booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021", 51 | year = "2021", 52 | publisher = "Association for Computational Linguistics", 53 | pages = "549--559", 54 | } 55 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | from engine.utils import read_data, create_vocs, create_labels_voc, build_vocab_GCN, get_indexes 6 | from evaluation import evaluate 7 | from itertools import chain 8 | from engine.srl import SRLer 9 | from engine.modules import CRF 10 | from pytorch_transformers import RobertaTokenizer, RobertaModel 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser(description="HeSyFu inference") 14 | parser.add_argument("--dir", type=str, default="savedir/", help="Output directory") 15 | parser.add_argument("--modelname", type=str, default="model.pickle") 16 | parser.add_argument("--batch-size", type=int, default=64) 17 | parser.add_argument( 18 | "--emb-dim", type=int, default=300, help="word embedding dimension" 19 | ) 20 | parser.add_argument( 21 | "--use-roberta", type=int, default=0, help="default do not use RoBERTa embeddings" 22 | ) 23 | 24 | parser.add_argument( 25 | "--embedding-layer-norm", 26 | type=int, 27 | default=0, 28 | help="default do not embedding layer norm", 29 | ) 30 | 31 | parser.add_argument( 32 | "--rep_dim", type=int, default=350, help="const_gcn/dep_gcn dimension" 33 | ) 34 | parser.add_argument("--n-layers", type=int, default=2, help="encoder num layers") 35 | 36 | parser.add_argument( 37 | "--train-file", type=str, required=True, help="path of the training file" 38 | ) 39 | parser.add_argument( 40 | "--dev-file", type=str, required=True, help="path of the training file" 41 | ) 42 | 43 | parser.add_argument( 44 | "--test-file", type=str, required=True, help="path of the training file" 45 | ) 46 | 47 | parser.add_argument( 48 | "--glove-path", 49 | type=str, 50 | required=True, 51 | help="path of Glove glove.6B.300d.txt embeddings", 52 | ) 53 | 54 | parser.add_argument( 55 | "--bilinear-dropout", 56 | type=float, 57 | default=0.0, 58 | help="dropout at the bilinear module", 59 | ) 60 | parser.add_argument( 61 | "--gcn-dropout", type=float, default=0.1, help="dropout of the const_gcn/dep_gcn input module" 62 | ) 63 | parser.add_argument( 64 | "--emb-dropout", 65 | type=float, 66 | default=0.2, 67 | help="dropout of the embedding , default off", 68 | ) 69 | parser.add_argument( 70 | "--edge-dropout", 71 | type=float, 72 | default=0.1, 73 | help="dropout of the const_gcn/dep_gcn edges , default off", 74 | ) 75 | parser.add_argument( 76 | "--label-dropout", 77 | type=float, 78 | default=0.1, 79 | help="dropout of the const_gcn/dep_gcn labels , default off", 80 | ) 81 | parser.add_argument( 82 | "--non-linearity", 83 | type=str, 84 | default="relu", 85 | help="nonlinearity used, default relu", 86 | ) 87 | 88 | # gpu 89 | parser.add_argument("--gpu_id", type=int, default=-1, help="GPU ID") 90 | parser.add_argument("--seed", type=int, default=3070, help="seed") 91 | 92 | params, _ = parser.parse_known_args() 93 | 94 | # set gpu device 95 | torch.cuda.set_device(params.gpu_id) 96 | 97 | """ 98 | SEED 99 | """ 100 | np.random.seed(params.seed) 101 | torch.manual_seed(params.seed) 102 | torch.cuda.manual_seed(params.seed) 103 | 104 | GLOVE_PATH = params.glove_path 105 | 106 | train_file = params.train_file 107 | dev_file = params.dev_file 108 | test_file = params.test_file 109 | 110 | train, train_data_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx = read_data(train_file, {}, {}, {}) 111 | dev, dev_data_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx = read_data( 112 | dev_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx 113 | ) 114 | test, test_data_file, _, _ = read_data(test_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx) 115 | 116 | word_to_idx = create_vocs(train) 117 | 118 | roles_to_idx, idx_to_roles = create_labels_voc(train + dev) 119 | word_vec = build_vocab_GCN( 120 | [t["text"] for t in train] 121 | + [t["text"] for t in dev] 122 | + [t["text"] for t in test], 123 | GLOVE_PATH, 124 | ) 125 | 126 | test = get_indexes(test, word_to_idx, roles_to_idx) 127 | 128 | srl = SRLer( 129 | params.rep_dim, 130 | len(roles_to_idx), 131 | params.n_layers, 132 | len(dep_lb_to_idx), 133 | len(w_c_to_idx), 134 | len(c_c_to_idx), 135 | params.embedding_layer_norm, 136 | params.use_bert, 137 | params, 138 | params.gpu_id, 139 | ) 140 | 141 | print(srl) 142 | 143 | crf = CRF( 144 | len(roles_to_idx), None, include_start_end_transitions=True 145 | ) 146 | print(crf) 147 | 148 | model_parameters = filter( 149 | lambda p: p.requires_grad, chain(srl.parameters(), crf.parameters()) 150 | ) 151 | 152 | num_params = sum([np.prod(p.size()) for p in model_parameters]) 153 | print("Total parameters =", num_params) 154 | print(params) 155 | 156 | if params.use_bert: 157 | bert_tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 158 | bert_model = RobertaModel.from_pretrained( 159 | "roberta-base", output_hidden_states=True 160 | ) 161 | if params.gpu_id > -1: 162 | bert_model.cuda() 163 | else: 164 | bert_tokenizer = None 165 | bert_model = None 166 | if params.gpu_id > -1: 167 | srl.cuda() 168 | crf.cuda() 169 | 170 | srl.load_state_dict(torch.load(os.path.join(params.dir, params.modelname))) 171 | 172 | crf.load_state_dict(torch.load(os.path.join(params.dir, params.modelname + "crf"))) 173 | 174 | evaluate( 175 | srl, 176 | 1000, 177 | test, 178 | -1, 179 | word_vec, 180 | idx_to_roles, 181 | params, 182 | params.modelname, 183 | params.dir, 184 | crf, 185 | 20, 186 | False, 187 | bert_tokenizer, 188 | bert_model, 189 | eval_type="test", 190 | final_eval=True, 191 | gold_data_file=test_data_file, 192 | gold_file_path=test_file, 193 | ) 194 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from subprocess import check_output 4 | import copy 5 | import time 6 | import os.path 7 | from engine.utils import get_batch_sup 8 | 9 | 10 | def evaluate( 11 | model, 12 | epoch, 13 | data, 14 | val_acc_best, 15 | word_vec, 16 | idx_to_roles, 17 | params, 18 | modelname, 19 | save_dir, 20 | crf, 21 | adam_stop, 22 | stop_training, 23 | bert_tokenizer, 24 | bert_model, 25 | eval_type="valid", 26 | final_eval=False, 27 | gold_data_file=None, 28 | gold_file_path=None, 29 | ): 30 | last_time = time.time() 31 | 32 | srl = model 33 | srl.eval() 34 | crf.eval() 35 | 36 | if eval_type == "valid": 37 | print("\nVALIDATION : Epoch {0}".format(epoch)) 38 | 39 | batch_size = 16 40 | 41 | all_pred = [] 42 | data_len = len(data) 43 | for stidx in range(0, data_len, batch_size): 44 | 45 | labels_batch, sentences_batch, predicate_flags_batch, mask_batch, lengths_batch, fixed_embs, \ 46 | dependency_arcs, dependency_labels, constituent_labels, const_GCN_w_c, const_GCN_c_w, const_GCN_c_c, \ 47 | mask_const_batch, predicate_index, bert_embs = get_batch_sup( 48 | data[stidx : stidx + batch_size], 49 | word_vec, 50 | params.gpu_id, 51 | params.enc_lstm_dim, 52 | 0.0, 53 | bert_tokenizer, 54 | bert_model, 55 | ) 56 | 57 | output = srl( 58 | sentences_batch, 59 | predicate_flags_batch, 60 | mask_batch, 61 | lengths_batch, 62 | fixed_embs, 63 | dependency_arcs, 64 | dependency_labels, 65 | constituent_labels, 66 | const_GCN_w_c, 67 | const_GCN_c_w, 68 | const_GCN_c_c, 69 | mask_const_batch, 70 | predicate_index, 71 | bert_embs, 72 | ) 73 | best_paths = crf.viterbi_tags(output, mask_batch) 74 | 75 | for x, _ in best_paths: 76 | all_pred += x 77 | 78 | print("Eval took", str(round((time.time() - last_time) / 60, 2))) 79 | if gold_data_file: 80 | try: 81 | annotated_data = _prep_conll_predictions( 82 | all_pred, gold_data_file, idx_to_roles 83 | ) 84 | _print_conll_predictions(annotated_data, modelname + "_" + ".txt") 85 | 86 | gold_standard_file = gold_file_path 87 | precision, recall, eval_acc = _evaluate_conll( 88 | modelname + "_" + ".txt", gold_standard_file 89 | ) 90 | 91 | 92 | except IndexError: 93 | print(all_pred), 94 | print(gold_data_file) 95 | print(idx_to_roles) 96 | 97 | if final_eval: 98 | print( 99 | "finalgrep : F1 {0} : {1} precision {0} : {2} recall {0} : {3}".format( 100 | eval_type, eval_acc, precision, recall 101 | ) 102 | ) 103 | else: 104 | print( 105 | "togrep : results : epoch {0} ; F1 score {1} : {2} " 106 | "precision {1} : {3} recall {1} : {4}".format( 107 | epoch, eval_type, eval_acc, precision, recall 108 | ) 109 | ) 110 | 111 | if eval_type == "valid" and epoch <= params.n_epochs: 112 | if eval_acc > val_acc_best: 113 | print("saving model at epoch {0}".format(epoch)) 114 | if not os.path.exists(save_dir): 115 | os.makedirs(save_dir) 116 | torch.save(model.state_dict(), os.path.join(save_dir, modelname)) 117 | torch.save(crf.state_dict(), os.path.join(save_dir, modelname + "crf")) 118 | val_acc_best = eval_acc 119 | adam_stop = 20 120 | stop_training = False 121 | else: 122 | adam_stop -= 1 123 | if adam_stop == 0: 124 | stop_training = True 125 | return eval_acc, val_acc_best, stop_training, adam_stop 126 | 127 | 128 | def _prep_conll_predictions(pred, conll_gold, idx_to_roles): 129 | data = copy.deepcopy(conll_gold) 130 | cur_sent_len = 0 131 | sent_lenghts = [] 132 | for li, line in enumerate(data): 133 | if len(line) == 0: 134 | sent_lenghts.append(cur_sent_len) 135 | cur_sent_len = 0 136 | else: 137 | cur_sent_len += 1 138 | curr_sent = 0 139 | line_count = 0 140 | n_predicates = 0 141 | prev_open = [] 142 | for li, line in enumerate(data): 143 | if len(line) == 0: 144 | if n_predicates > 0: 145 | line_count += sent_lenghts[curr_sent] * (n_predicates - 1) 146 | for la, label in enumerate(prev_open): 147 | if label != 0: 148 | data[li - 1][la + 6] += ")" 149 | prev_open = [] 150 | curr_sent += 1 151 | else: 152 | if len(prev_open) == 0: 153 | for _ in line[6:]: 154 | prev_open.append(0) 155 | for la, label in enumerate(line[6:]): 156 | if pred[line_count + (la * sent_lenghts[curr_sent])] >= len( 157 | idx_to_roles 158 | ): 159 | lb = "O" 160 | else: 161 | lb = idx_to_roles[pred[line_count + (la * sent_lenghts[curr_sent])]] 162 | if lb[0] == "O": 163 | data[li][la + 6] = "*" 164 | if prev_open[la]: 165 | data[li - 1][la + 6] += ")" 166 | prev_open[la] = 0 167 | elif lb[0] == "B": 168 | if prev_open[la]: 169 | data[li - 1][la + 6] += ")" 170 | data[li][la + 6] = "(" + lb[2:] + "*" 171 | prev_open[la] = lb[2:] 172 | elif lb[0] == "I": 173 | if not prev_open[la]: 174 | data[li][la + 6] = "(" + lb[2:] + "*" 175 | prev_open[la] = lb[2:] 176 | elif lb[2:] != prev_open[la]: 177 | data[li - 1][la + 6] += ")" 178 | data[li][la + 6] = "(" + lb[2:] + "*" 179 | prev_open[la] = lb[2:] 180 | else: 181 | data[li][la + 6] = "*" 182 | n_predicates = len(line[6:]) 183 | 184 | if n_predicates > 0: 185 | line_count += 1 186 | return data 187 | 188 | 189 | def _print_conll_predictions(data, name): 190 | with open("data/predictions/" + name, "w") as out: 191 | for line in data: 192 | out.write(" ".join(line[5:]) + "\n") 193 | 194 | 195 | def _evaluate_conll(prediction_file, gold_standard_file): 196 | script = "data/scripts/srl-eval.pl" 197 | 198 | cut_script_args = ["cut", "-d", " ", "-f", "10-100", gold_standard_file] 199 | 200 | eval_script_args = [ 201 | script, 202 | "/tmp/" + gold_standard_file.split("/")[-1], 203 | "data/predictions/" + prediction_file, 204 | ] 205 | 206 | try: 207 | DEVNULL = open(os.devnull, "wb") 208 | cut_out = check_output(cut_script_args, stderr=DEVNULL) 209 | open("/tmp/" + gold_standard_file.split("/")[-1], "wb").write(cut_out) 210 | 211 | out = check_output(eval_script_args, stderr=DEVNULL) 212 | out = out.decode("utf-8") 213 | 214 | out_ = " ".join(out.split()) 215 | all_ = out_.split() 216 | 217 | open("data/predictions/" + prediction_file + "_eval.out", "w").write(out) 218 | prec = all_[27] 219 | rec = all_[28] 220 | f1 = all_[29] 221 | return float(prec), float(rec), float(f1) 222 | except: 223 | raise IOError 224 | -------------------------------------------------------------------------------- /engine/srl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from engine.modules import ConstGCN, DepGCN, BilinearScorer, TransformerEncoder 5 | from engine.utils import _make_VariableLong 6 | 7 | 8 | class SRLer(nn.Module): 9 | def __init__( 10 | self, 11 | hidden_dim, 12 | tagset_size, 13 | num_layers, 14 | dep_tag_vocab_size, 15 | w_c_vocab_size, 16 | c_c_vocab_size, 17 | eln, 18 | use_bert, 19 | params, 20 | gpu_id=-1, 21 | ): 22 | super(SRLer, self).__init__() 23 | if gpu_id > -1: 24 | self.use_gpu = True 25 | else: 26 | self.use_gpu = False 27 | self.num_layers = num_layers 28 | self.vocab_size = w_c_vocab_size 29 | self.eln = eln 30 | self.use_bert = use_bert 31 | self.params = params 32 | self.dropout = nn.Dropout(p=params.gcn_dropout) 33 | self.embedding_dropout = nn.Dropout(p=params.emb_dropout) 34 | 35 | if self.use_bert: 36 | fixed_dim = 768 37 | 38 | else: 39 | fixed_dim = 100 40 | 41 | embedding_dim = self.params.emb_dim 42 | self.indicator_embeddings = nn.Embedding(2, embedding_dim) 43 | 44 | self.tagset_size = tagset_size 45 | 46 | 47 | if self.params.non_linearity == "relu": 48 | self.non_linearity = nn.ReLU() 49 | elif self.params.non_linearity == "tanh": 50 | self.non_linearity = nn.Tanh() 51 | elif self.params.non_linearity == "leakyrelu": 52 | self.non_linearity = nn.LeakyReLU() 53 | elif self.params.non_linearity == "celu": 54 | self.non_linearity = nn.CELU() 55 | elif self.params.non_linearity == "selu": 56 | self.non_linearity = nn.SELU() 57 | else: 58 | raise NotImplementedError 59 | 60 | 61 | self.TrmEncoder = TransformerEncoder(vocab_size=self.vocab_size, 62 | n_layers=self.num_layers, 63 | ) 64 | if self.use_gpu: 65 | self.TrmEncoder.to(self.device) 66 | 67 | self.hidden2predicate = nn.Linear(hidden_dim, hidden_dim) 68 | self.hidden2argument = nn.Linear(hidden_dim, hidden_dim) 69 | self.bilinear_scorer = BilinearScorer( 70 | hidden_dim, tagset_size, params.bilinear_dropout 71 | ) 72 | 73 | # ConstGCN 74 | # boundary bridging 75 | self.const_gcn_w_c = ConstGCN( 76 | hidden_dim, 77 | hidden_dim, 78 | w_c_vocab_size, 79 | in_arcs=True, 80 | out_arcs=True, 81 | use_gates=True, 82 | batch_first=True, 83 | residual=True, 84 | no_loop=True, 85 | dropout=self.params.gcn_dropout, 86 | non_linearity=self.non_linearity, 87 | edge_dropout=self.params.edge_dropout, 88 | ) 89 | 90 | # reverse boundary bridging 91 | self.const_gcn_c_w = ConstGCN( 92 | hidden_dim, 93 | hidden_dim, 94 | w_c_vocab_size, 95 | in_arcs=True, 96 | out_arcs=True, 97 | use_gates=True, 98 | batch_first=True, 99 | residual=True, 100 | no_loop=True, 101 | dropout=self.params.gcn_dropout, 102 | non_linearity=self.non_linearity, 103 | edge_dropout=self.params.edge_dropout, 104 | ) 105 | 106 | # self graph 107 | self.const_gcn_c_c = ConstGCN( 108 | hidden_dim, 109 | hidden_dim, 110 | c_c_vocab_size, 111 | in_arcs=True, 112 | out_arcs=True, 113 | use_gates=True, 114 | batch_first=True, 115 | residual=True, 116 | no_loop=False, 117 | dropout=self.params.gcn_dropout, 118 | non_linearity=self.non_linearity, 119 | edge_dropout=self.params.edge_dropout, 120 | ) 121 | 122 | # DepGCN 123 | self.dep_gcn = DepGCN(dep_tag_vocab_size, hidden_dim + 768, hidden_dim * 2, hidden_dim) 124 | 125 | self.gate = nn.Sigmoid() 126 | 127 | if self.eln: 128 | self.layernorm = nn.LayerNorm(fixed_dim) 129 | 130 | def forward( 131 | self, 132 | sentence, 133 | predicate_flags, 134 | sent_mask, 135 | lengths, 136 | fixed_embs, 137 | dependency_arcs, 138 | dependency_labels, 139 | constituent_labels, 140 | const_GCN_w_c, 141 | const_GCN_c_w, 142 | const_GCN_c_c, 143 | mask_const_batch, 144 | predicate_index, 145 | bert_embs, 146 | ): 147 | 148 | if self.use_bert: 149 | embeds = bert_embs 150 | else: 151 | embeds = fixed_embs 152 | 153 | if self.eln: 154 | embeds = self.layernorm(embeds * sent_mask.unsqueeze(2)) 155 | 156 | embeds = self.embedding_dropout(embeds) 157 | 158 | embeds = torch.cat( 159 | (embeds, self.indicator_embeddings(predicate_flags.long())), 2 160 | ) 161 | 162 | b, t, e = embeds.data.shape 163 | base_out = self.TrmEncoder(embeds) 164 | 165 | const_gcn_in = torch.cat([base_out, constituent_labels], dim=1) 166 | mask_all = torch.cat([sent_mask, mask_const_batch], dim=1) 167 | 168 | # boundary bridging 169 | adj_arc_in_w_c, adj_arc_out_w_c, adj_lab_in_w_c, adj_lab_out_w_c, mask_in_w_c, mask_out_w_c, mask_loop_w_c = ( 170 | const_GCN_w_c 171 | ) 172 | 173 | # inverse-boundary bridging 174 | adj_arc_in_c_w, adj_arc_out_c_w, adj_lab_in_c_w, adj_lab_out_c_w, mask_in_c_w, mask_out_c_w, mask_loop_c_w = ( 175 | const_GCN_c_w 176 | ) 177 | 178 | adj_arc_in_c_c, adj_arc_out_c_c, adj_lab_in_c_c, adj_lab_out_c_c, mask_in_c_c, mask_out_c_c, mask_loop_c_c = ( 179 | const_GCN_c_c 180 | ) 181 | 182 | const_gcn_out = self.const_gcn_w_c( 183 | const_gcn_in, 184 | adj_arc_in_w_c, 185 | adj_arc_out_w_c, 186 | adj_lab_in_w_c, 187 | adj_lab_out_w_c, 188 | mask_in_w_c, 189 | mask_out_w_c, 190 | mask_loop_w_c, 191 | mask_all, 192 | ) 193 | 194 | const_gcn_out = self.const_gcn_c_c( 195 | const_gcn_out, 196 | adj_arc_in_c_c, 197 | adj_arc_out_c_c, 198 | adj_lab_in_c_c, 199 | adj_lab_out_c_c, 200 | mask_in_c_c, 201 | mask_out_c_c, 202 | mask_loop_c_c, 203 | mask_all, 204 | ) 205 | 206 | const_gcn_out = self.const_gcn_c_w( 207 | const_gcn_out, 208 | adj_arc_in_c_w, 209 | adj_arc_out_c_w, 210 | adj_lab_in_c_w, 211 | adj_lab_out_c_w, 212 | mask_in_c_w, 213 | mask_out_c_w, 214 | mask_loop_c_w, 215 | mask_all, 216 | ) 217 | 218 | # const_gcn_out = const_gcn_out.narrow(1, 0, t) 219 | dep_gcn_in = torch.cat([base_out, const_gcn_out], dim=1) 220 | 221 | # learn from dependency 222 | dep_gcn_out = self.dep_gcn(dep_gcn_in, dependency_arcs, dependency_labels) 223 | 224 | # gating 225 | if self.use_gpu: 226 | gpu_id = 1 227 | else: 228 | gpu_id = 0 229 | gate_ = self.gate(torch.cat([dep_gcn_out, const_gcn_out], dim=1)) 230 | all_one = _make_VariableLong(np.zeros((b, t)), gpu_id, False) 231 | hesyfu_out = gate_ * dep_gcn_out + (all_one - gate_) * const_gcn_out 232 | 233 | hesyfu_out_view = hesyfu_out.contiguous().view(b * t, -1) 234 | predicate_index = predicate_index.view(b * t) 235 | 236 | predicates_repr = hesyfu_out_view.index_select(0, predicate_index).view(b, t, -1) 237 | 238 | pred_repr = self.non_linearity( 239 | self.hidden2predicate(self.dropout(predicates_repr)) 240 | ) 241 | arg_repr = self.non_linearity(self.hidden2argument(self.dropout(hesyfu_out_view))) 242 | tag_scores = self.bilinear_scorer(pred_repr, arg_repr) # [b*t, label_size] 243 | 244 | return tag_scores.view(b, t, self.tagset_size) 245 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /run_srl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import time 5 | import argparse 6 | import os 7 | from engine.utils import read_data, create_vocs, create_labels_voc, build_vocab_GCN, get_indexes, get_batch_sup 8 | from evaluation import evaluate 9 | from itertools import chain 10 | from engine.srl import SRLer 11 | from engine.modules import CRF 12 | from pytorch_transformers import RobertaTokenizer, RobertaModel 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser(description="HeSyFu training") 16 | parser.add_argument( 17 | "--outputdir", type=str, default="savedir/", help="Output directory" 18 | ) 19 | parser.add_argument("--outputmodelname", type=str, default="model.pickle") 20 | parser.add_argument("--n-epochs", type=int, default=1000) 21 | parser.add_argument("--batch-size", type=int, default=24) 22 | parser.add_argument("--weight-decay", type=float, default=1e-3, help="weight decay") 23 | parser.add_argument( 24 | "--learning-rate", type=float, default=2e-5, help="learning rate, default 2e-5" 25 | ) 26 | parser.add_argument( 27 | "--gradient-clipping", type=int, default=1, help="gradient clipping, default on" 28 | ) 29 | parser.add_argument( 30 | "--emb-dim", type=int, default=300, help="word embedding dimension" 31 | ) 32 | parser.add_argument( 33 | "--use-roberta", type=int, default=0, help="default do not use RoBERTa embeddings" 34 | ) 35 | parser.add_argument( 36 | "--embedding-layer-norm", 37 | type=int, 38 | default=0, 39 | help="default do not embedding layer norm", 40 | ) 41 | parser.add_argument( 42 | "--rep_dim", type=int, default=350, help="const_gcn/dep_gcn dimension" 43 | ) 44 | parser.add_argument("--n-layers", type=int, default=2, help="encoder num layers") 45 | 46 | parser.add_argument( 47 | "--word-drop", type=float, default=0.0, help="word dropout default 0.0, no drop" 48 | ) 49 | parser.add_argument( 50 | "--corpus", 51 | type=str, 52 | default="conll05", 53 | help="select the corpus, conll05 is default conll12, conll09 are the other options,", 54 | ) 55 | parser.add_argument( 56 | "--train-file", type=str, required=True, help="path of the training file" 57 | ) 58 | parser.add_argument( 59 | "--dev-file", type=str, required=True, help="path of the dev file" 60 | ) 61 | parser.add_argument( 62 | "--glove-path", 63 | type=str, 64 | required=True, 65 | help="path of Glove glove.6B.300d.txt embeddings", 66 | ) 67 | parser.add_argument( 68 | "--bilinear-dropout", 69 | type=float, 70 | default=0.0, 71 | help="dropout at the bilinear module", 72 | ) 73 | parser.add_argument( 74 | "--gcn-dropout", type=float, default=0.1, help="dropout of the const_gcn/dep_gcn input module" 75 | ) 76 | parser.add_argument( 77 | "--emb-dropout", 78 | type=float, 79 | default=0.2, 80 | help="dropout of the embedding , default off", 81 | ) 82 | parser.add_argument( 83 | "--edge-dropout", 84 | type=float, 85 | default=0.1, 86 | help="dropout of the const_gcn/dep_gcn edges , default off", 87 | ) 88 | parser.add_argument( 89 | "--label-dropout", 90 | type=float, 91 | default=0.1, 92 | help="dropout of the const_gcn/dep_gcn labels , default off", 93 | ) 94 | parser.add_argument( 95 | "--non-linearity", 96 | type=str, 97 | default="relu", 98 | help="nonlinearity used, default relu", 99 | ) 100 | 101 | # gpu 102 | parser.add_argument("--gpu-id", type=int, default=-1, help="GPU ID") 103 | parser.add_argument("--seed", type=int, default=3070, help="seed") 104 | 105 | params, _ = parser.parse_known_args() 106 | 107 | # set gpu device 108 | torch.cuda.set_device(params.gpu_id) 109 | 110 | """ 111 | SEED 112 | """ 113 | np.random.seed(params.seed) 114 | torch.manual_seed(params.seed) 115 | torch.cuda.manual_seed(params.seed) 116 | 117 | GLOVE_PATH = params.glove_path 118 | train_file = params.train_file 119 | dev_file = params.dev_file 120 | 121 | train, train_data_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx = read_data(train_file, {}, {}, {}) 122 | print("train examples", len(train)) 123 | dev, dev_data_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx = read_data( 124 | dev_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx 125 | ) 126 | print("dev examples", len(dev)) 127 | 128 | word_to_idx = create_vocs(train) 129 | roles_to_idx, idx_to_roles = create_labels_voc(train + dev) 130 | word_vec = build_vocab_GCN( 131 | [t["text"] for t in train] + [t["text"] for t in dev], GLOVE_PATH 132 | ) 133 | train = get_indexes(train, word_to_idx, roles_to_idx) 134 | dev = get_indexes(dev, word_to_idx, roles_to_idx) 135 | 136 | srl = SRLer( 137 | params.rep_dim, 138 | len(roles_to_idx), 139 | params.n_layers, 140 | len(dep_lb_to_idx), 141 | len(w_c_to_idx), 142 | len(c_c_to_idx), 143 | params.embedding_layer_norm, 144 | params.use_bert, 145 | params, 146 | params.gpu_id, 147 | ) 148 | print(srl) 149 | 150 | crf = CRF( 151 | len(roles_to_idx), None, include_start_end_transitions=True 152 | ) 153 | print(crf) 154 | 155 | model_parameters = filter( 156 | lambda p: p.requires_grad, chain(srl.parameters(), crf.parameters()) 157 | ) 158 | 159 | num_params = sum([np.prod(p.size()) for p in model_parameters]) 160 | print("Total parameters =", num_params) 161 | print(params) 162 | 163 | if params.use_bert: 164 | bert_tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 165 | bert_model = RobertaModel.from_pretrained( 166 | "roberta-base", output_hidden_states=True 167 | ) 168 | if params.gpu_id > -1: 169 | bert_model.cuda() 170 | else: 171 | bert_tokenizer = None 172 | bert_model = None 173 | if params.gpu_id > -1: 174 | srl.cuda() 175 | crf.cuda() 176 | 177 | lr = params.learning_rate 178 | # optimizer 179 | 180 | optimizer = optim.Adam( 181 | chain(srl.parameters(), crf.parameters()), 182 | lr=lr, 183 | weight_decay=params.weight_decay, 184 | ) 185 | 186 | first_optim = True 187 | 188 | val_acc_best = -1.0 189 | adam_stop = False 190 | stop_training = 20 191 | 192 | for epc in range(params.n_epochs): 193 | srl.train() 194 | all_costs = [] 195 | logs = [] 196 | np.random.shuffle(train) 197 | last_time = time.time() 198 | train_len = len(train) 199 | for stidx in range(0, train_len, params.batch_size): 200 | 201 | labels_batch, sentences_batch, predicate_flags_batch, mask_batch, lengths_batch, fixed_embs, \ 202 | dependency_arcs, dependency_labels, constituent_labels, const_GCN_w_c, const_GCN_c_w, const_GCN_c_c, \ 203 | mask_const_batch, predicate_index, bert_embs = get_batch_sup( 204 | train[stidx: stidx + params.batch_size], 205 | word_vec, 206 | params.gpu_id, 207 | params.enc_lstm_dim, 208 | params.word_drop, 209 | bert_tokenizer, 210 | bert_model, 211 | ) 212 | 213 | output = srl( 214 | sentences_batch, 215 | predicate_flags_batch, 216 | mask_batch, 217 | lengths_batch, 218 | fixed_embs, 219 | dependency_arcs, 220 | dependency_labels, 221 | constituent_labels, 222 | const_GCN_w_c, 223 | const_GCN_c_w, 224 | const_GCN_c_c, 225 | mask_const_batch, 226 | predicate_index, 227 | bert_embs, 228 | ) 229 | 230 | # CRF Log-likelihood loss 231 | log_likelihood = crf(output, labels_batch, mask_batch) 232 | all_costs.append(-log_likelihood.item()) 233 | # backward 234 | 235 | optimizer.zero_grad() 236 | 237 | total_loss = -log_likelihood 238 | 239 | # optimizer step 240 | total_loss.backward() 241 | 242 | # Gradient clipping 243 | if params.gradient_clipping: 244 | torch.nn.utils.clip_grad_norm_( 245 | chain(srl.parameters(), crf.parameters()), 1 246 | ) 247 | 248 | optimizer.step() 249 | 250 | if params.corpus == "2005": 251 | stop_every = 10000 // params.batch_size 252 | else: 253 | stop_every = 30000 // params.batch_size 254 | 255 | start_check_at = 2 256 | if ( 257 | len(all_costs) == stop_every 258 | ): 259 | print( 260 | "Training for " + str(stop_every) + " batches took", 261 | str(round((time.time() - last_time) / 60, 2)), 262 | "minutes.", 263 | ) 264 | 265 | logs.append( 266 | "{0} ; loss {1}".format(stidx, round(np.mean(all_costs), 4)) 267 | ) 268 | if epc > start_check_at: 269 | eval_acc, val_acc_best, stop_training, adam_stop = evaluate( 270 | srl, 271 | epc + 1, 272 | dev, 273 | val_acc_best, 274 | word_vec, 275 | idx_to_roles, 276 | params, 277 | params.outputmodelname, 278 | params.outputdir, 279 | crf, 280 | adam_stop, 281 | stop_training, 282 | bert_tokenizer, 283 | bert_model, 284 | eval_type="valid", 285 | final_eval=False, 286 | gold_data_file=dev_data_file, 287 | gold_file_path=dev_file, 288 | ) 289 | 290 | srl.train() 291 | crf.train() 292 | print(logs[-1]) 293 | all_costs = [] 294 | last_time = time.time() 295 | 296 | print("epoch " + str(epc + 1) + " done.") 297 | 298 | eval_acc, val_acc_best, stop_training, adam_stop = evaluate( 299 | srl, 300 | epc + 1, 301 | dev, 302 | val_acc_best, 303 | word_vec, 304 | idx_to_roles, 305 | params, 306 | params.outputmodelname, 307 | params.outputdir, 308 | crf, 309 | adam_stop, 310 | stop_training, 311 | bert_tokenizer, 312 | bert_model, 313 | eval_type="valid", 314 | final_eval=False, 315 | gold_data_file=dev_data_file, 316 | gold_file_path=dev_file, 317 | ) 318 | 319 | srl.train() 320 | crf.train() 321 | 322 | if stop_training: 323 | adam_stop = 20 324 | stop_training = False 325 | lr = lr * 0.5 326 | print("Learning rate reduced to", str(lr)) 327 | for param_group in optimizer.param_groups: 328 | param_group["lr"] = lr 329 | 330 | if lr < 0.000125: 331 | break 332 | 333 | srl.load_state_dict( 334 | torch.load(os.path.join(params.outputdir, params.outputmodelname)) 335 | ) 336 | 337 | crf.load_state_dict( 338 | torch.load(os.path.join(params.outputdir, params.outputmodelname + "crf")) 339 | ) 340 | evaluate( 341 | srl, 342 | 1000, 343 | dev, 344 | val_acc_best, 345 | word_vec, 346 | idx_to_roles, 347 | params, 348 | params.outputmodelname, 349 | params.outputdir, 350 | crf, 351 | adam_stop, 352 | stop_training, 353 | bert_tokenizer, 354 | bert_model, 355 | eval_type="valid", 356 | final_eval=False, 357 | gold_data_file=dev_data_file, 358 | gold_file_path=dev_file, 359 | ) 360 | -------------------------------------------------------------------------------- /engine/utils.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | def get_const_adj_BE(batch, max_batch_len, gpu_id, max_degr_in, max_degr_out, forward): 8 | node1_index = [[word[1] for word in sent] for sent in batch] 9 | node2_index = [[word[2] for word in sent] for sent in batch] 10 | label_index = [[word[0] for word in sent] for sent in batch] 11 | begin_index = [[word[3] for word in sent] for sent in batch] 12 | 13 | batch_size = len(batch) 14 | 15 | _MAX_BATCH_LEN = max_batch_len 16 | _MAX_DEGREE_IN = max_degr_in 17 | _MAX_DEGREE_OUT = max_degr_out 18 | 19 | adj_arc_in = np.zeros( 20 | (batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_IN, 2), dtype="int32" 21 | ) 22 | adj_lab_in = np.zeros( 23 | (batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_IN, 1), dtype="int32" 24 | ) 25 | adj_arc_out = np.zeros( 26 | (batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_OUT, 2), dtype="int32" 27 | ) 28 | adj_lab_out = np.zeros( 29 | (batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_OUT, 1), dtype="int32" 30 | ) 31 | 32 | mask_in = np.zeros((batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_IN), dtype="float32") 33 | mask_out = np.zeros( 34 | (batch_size * _MAX_BATCH_LEN * _MAX_DEGREE_OUT), dtype="float32" 35 | ) 36 | mask_loop = np.ones((batch_size * _MAX_BATCH_LEN, 1), dtype="float32") 37 | 38 | tmp_in = {} 39 | tmp_out = {} 40 | 41 | for d, de in enumerate(node1_index): # iterates over the batch 42 | for a, arc in enumerate(de): 43 | if not forward: 44 | arc_1 = arc 45 | arc_2 = node2_index[d][a] 46 | else: 47 | arc_2 = arc 48 | arc_1 = node2_index[d][a] 49 | 50 | if begin_index[d][a] == 0: # BEGIN 51 | if arc_1 in tmp_in: 52 | tmp_in[arc_1] += 1 53 | else: 54 | tmp_in[arc_1] = 0 55 | 56 | idx_in = ( 57 | (d * _MAX_BATCH_LEN * _MAX_DEGREE_IN) 58 | + arc_1 * _MAX_DEGREE_IN 59 | + tmp_in[arc_1] 60 | ) 61 | 62 | if tmp_in[arc_1] < _MAX_DEGREE_IN: 63 | adj_arc_in[idx_in] = np.array([d, arc_2]) # incoming arcs 64 | adj_lab_in[idx_in] = np.array([label_index[d][a]]) # incoming arcs 65 | mask_in[idx_in] = 1.0 66 | 67 | else: # END 68 | if arc_1 in tmp_out: 69 | tmp_out[arc_1] += 1 70 | else: 71 | tmp_out[arc_1] = 0 72 | 73 | idx_out = ( 74 | (d * _MAX_BATCH_LEN * _MAX_DEGREE_OUT) 75 | + arc_1 * _MAX_DEGREE_OUT 76 | + tmp_out[arc_1] 77 | ) 78 | 79 | if tmp_out[arc_1] < _MAX_DEGREE_OUT: 80 | adj_arc_out[idx_out] = np.array([d, arc_2]) # outgoing arcs 81 | adj_lab_out[idx_out] = np.array( 82 | [label_index[d][a]] 83 | ) # outgoing arcs 84 | mask_out[idx_out] = 1.0 85 | 86 | tmp_in = {} 87 | tmp_out = {} 88 | 89 | adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in).tolist()) 90 | adj_arc_out = torch.LongTensor(np.transpose(adj_arc_out).tolist()) 91 | 92 | adj_lab_in = torch.LongTensor(np.transpose(adj_lab_in).tolist()) 93 | adj_lab_out = torch.LongTensor(np.transpose(adj_lab_out).tolist()) 94 | 95 | mask_in = autograd.Variable( 96 | torch.FloatTensor( 97 | mask_in.reshape((_MAX_BATCH_LEN * batch_size, _MAX_DEGREE_IN)).tolist() 98 | ), 99 | requires_grad=False, 100 | ) 101 | mask_out = autograd.Variable( 102 | torch.FloatTensor( 103 | mask_out.reshape((_MAX_BATCH_LEN * batch_size, _MAX_DEGREE_OUT)).tolist() 104 | ), 105 | requires_grad=False, 106 | ) 107 | mask_loop = autograd.Variable( 108 | torch.FloatTensor(mask_loop.tolist()), requires_grad=False 109 | ) 110 | 111 | if gpu_id > -1: 112 | adj_arc_in = adj_arc_in.cuda() 113 | adj_arc_out = adj_arc_out.cuda() 114 | adj_lab_in = adj_lab_in.cuda() 115 | adj_lab_out = adj_lab_out.cuda() 116 | mask_in = mask_in.cuda() 117 | mask_out = mask_out.cuda() 118 | mask_loop = mask_loop.cuda() 119 | return [ 120 | adj_arc_in, 121 | adj_arc_out, 122 | adj_lab_in, 123 | adj_lab_out, 124 | mask_in, 125 | mask_out, 126 | mask_loop, 127 | ] 128 | 129 | 130 | 131 | def get_indexes(data, word_to_idx, roles_to_idx): 132 | 133 | for d in data: 134 | 135 | d["i_text"] = [ 136 | word_to_idx[x.lower()] if x.lower() in word_to_idx else word_to_idx[""] 137 | for x in d["text"] 138 | ] 139 | d["lower_text"] = [x.lower() for x in d["text"]] 140 | 141 | d["i_labels"] = [ 142 | roles_to_idx[x] if x in roles_to_idx else roles_to_idx["O"] 143 | for x in d["labels"] 144 | ] 145 | 146 | return data 147 | 148 | 149 | def _get_word_dict_GCN(sentences): 150 | # create vocab of words 151 | word_dict = {} 152 | for sent in sentences: 153 | for word in sent: 154 | word = word.lower() 155 | if word not in word_dict: 156 | word_dict[word] = "" 157 | word_dict[""] = "" 158 | word_dict[""] = "" 159 | word_dict[""] = "" 160 | return word_dict 161 | 162 | 163 | def _get_glove_GCN(word_dict, glove_path): 164 | # create word_vec with glove vectors 165 | word_vec = {} 166 | avg = [] 167 | with open(glove_path, encoding="utf8") as f: 168 | for line in f: 169 | word, vec = line.split(" ", 1) 170 | if word in word_dict: 171 | word_vec[word] = np.array(list(map(float, vec.split()))) 172 | for word in word_vec: 173 | if avg == []: 174 | avg = word_vec[word] 175 | count = 1 176 | else: 177 | avg += word_vec[word] 178 | count += 1 179 | word_vec[""] = avg / count 180 | print( 181 | "Found {0}(/{1}) words with glove vectors".format(len(word_vec), len(word_dict)) 182 | ) 183 | return word_vec 184 | 185 | 186 | def build_vocab_GCN(sentences, glove_path): 187 | word_dict = _get_word_dict_GCN(sentences) 188 | word_vec = _get_glove_GCN(word_dict, glove_path) 189 | print("Vocab size : {0}".format(len(word_vec))) 190 | return word_vec 191 | 192 | 193 | def create_labels_voc(all_data): 194 | roles_to_idx = {"O": 0} 195 | idx_to_roles = ["O"] 196 | for example in all_data: 197 | for role in example["labels"]: 198 | if role not in roles_to_idx: 199 | idx_to_roles.append(role) 200 | roles_to_idx[role] = len(roles_to_idx) 201 | return roles_to_idx, idx_to_roles 202 | 203 | 204 | def create_constraints(all_data, roles_to_idx): 205 | constraints = set() 206 | for example in all_data: 207 | prev_role = "" 208 | for r, role in enumerate(example["labels"]): 209 | if r == 0: 210 | constraints.add((len(roles_to_idx), roles_to_idx[role])) 211 | if role[0] == "B": 212 | constraints.add((0, roles_to_idx[role])) 213 | constraints.add((roles_to_idx[role], 0)) 214 | elif role[0] == "I": 215 | constraints.add((roles_to_idx[role], 0)) 216 | 217 | prev_role = roles_to_idx[role] 218 | elif r == len(example["labels"]) - 1: 219 | constraints.add((prev_role, roles_to_idx[role])) 220 | constraints.add((roles_to_idx[role], len(roles_to_idx) + 1)) 221 | 222 | if role[0] == "B": 223 | constraints.add((0, roles_to_idx[role])) 224 | constraints.add((roles_to_idx[role], 0)) 225 | constraints.add((len(roles_to_idx), roles_to_idx[role])) 226 | elif role[0] == "I": 227 | constraints.add((roles_to_idx[role], 0)) 228 | else: 229 | constraints.add((prev_role, roles_to_idx[role])) 230 | if role[0] == "B": 231 | constraints.add((0, roles_to_idx[role])) 232 | constraints.add((roles_to_idx[role], 0)) 233 | constraints.add((len(roles_to_idx), roles_to_idx[role])) 234 | elif role[0] == "I": 235 | constraints.add((roles_to_idx[role], 0)) 236 | constraints.add((roles_to_idx[role], len(roles_to_idx) + 1)) 237 | prev_role = roles_to_idx[role] 238 | 239 | return list(constraints) 240 | 241 | 242 | def create_predicate_constraints(all_data, roles_to_idx): 243 | constraints = {} 244 | for example in all_data: 245 | pred = example["predicates"][0].split("#")[1] 246 | if pred not in constraints: 247 | constraints[pred] = set() 248 | for r, role in enumerate(example["labels"]): 249 | if roles_to_idx[role] not in constraints[pred]: 250 | constraints[pred].add(roles_to_idx[role]) 251 | 252 | return constraints 253 | 254 | 255 | def create_vocs(train): 256 | word_to_idx = {"": 0, "": 1, "": 2, "": 3} 257 | 258 | for example in train: 259 | for word in example["text"]: 260 | word = word.lower() 261 | if word not in word_to_idx: 262 | word_to_idx[word] = len(word_to_idx) 263 | return word_to_idx 264 | 265 | 266 | def read_data(data_file, w_c_to_idx, c_c_to_idx, dep_lb_to_idx): 267 | data = [] 268 | data_file_ = [] 269 | curr_sent = [] 270 | with open(data_file, "r") as f: 271 | text = [] 272 | pos = [] 273 | 274 | predicate_positions = ( 275 | [] 276 | ) # this contains the positions fo predicate in the sentence 277 | predicates = [] # this contains disambiguated predicates 278 | labels = [] # this contains lists of semantic roles, one for each predicate. 279 | line_count = 0 280 | inside = [] 281 | 282 | dep_head = [] 283 | dep_lb = [] 284 | 285 | word_to_constituents = [] 286 | stack_const = [] 287 | stack_num = [] 288 | curr_num = 500 289 | 290 | constituents_to_constituents = [] 291 | 292 | children = {} 293 | 294 | for line in f: 295 | line = " ".join(line.split()) 296 | line_split = line.strip().split() 297 | curr_sent.append(line_split) 298 | 299 | if len(line_split) > 2: 300 | 301 | text.append(line_split[0]) 302 | pos.append(line_split[1]) 303 | 304 | # Adds the word to constituents arcs 305 | if line_split[2].find("(") > -1: 306 | for const in line_split[2].split("(")[1:]: 307 | const = const.replace("*", "").replace(")", "") 308 | 309 | if const == "TOP": 310 | pass 311 | else: 312 | if const not in w_c_to_idx: 313 | w_c_to_idx[const] = len(w_c_to_idx) 314 | word_to_constituents.append( 315 | [w_c_to_idx[const], line_count, curr_num, 0] 316 | ) 317 | 318 | stack_num.append(curr_num) 319 | stack_const.append(const) 320 | curr_num += 1 321 | 322 | if line_split[2].find(")") > -1: 323 | for c in line_split[2]: 324 | if c == ")": 325 | num = stack_num.pop() 326 | const = stack_const.pop() 327 | 328 | if const == "TOP": 329 | pass 330 | else: 331 | if const not in w_c_to_idx: 332 | w_c_to_idx[const] = len(w_c_to_idx) 333 | word_to_constituents.append( 334 | [w_c_to_idx[const], line_count, num, 1] 335 | ) 336 | 337 | if len(stack_num) != 0: 338 | 339 | if stack_const[-1] == "TOP": 340 | pass 341 | else: 342 | 343 | if stack_const[-1] not in c_c_to_idx: 344 | c_c_to_idx[stack_const[-1]] = len(c_c_to_idx) 345 | constituents_to_constituents.append( 346 | [ 347 | c_c_to_idx[stack_const[-1]], 348 | stack_num[-1], 349 | num, 350 | 0, 351 | ] 352 | ) # from super to sub 353 | 354 | if stack_const[-1] not in children: 355 | children[stack_const[-1]] = [const] 356 | else: 357 | children[stack_const[-1]].append(const) 358 | 359 | if const == "TOP": 360 | pass 361 | else: 362 | if const not in c_c_to_idx: 363 | c_c_to_idx[const] = len(c_c_to_idx) 364 | constituents_to_constituents.append( 365 | [c_c_to_idx[const], num, stack_num[-1], 1] 366 | ) 367 | 368 | # Adds the dependency 369 | if line_split[3] != "-": 370 | dep_head.append(int(line_split[3])) 371 | dep_lb.append(str(line_split[4])) 372 | 373 | if str(line_split[4]) not in dep_lb_to_idx.keys(): 374 | dep_lb_to_idx[str(line_split[4])] = len(dep_lb_to_idx) 375 | 376 | 377 | # Adds the predicates 378 | if line_split[5] != "-": 379 | predicate_positions.append(line_count) 380 | predicates.append(line_split[5] + "." + line_split[4]) 381 | 382 | # Adds the roles 383 | if len(labels) != len(line_split[6:]): 384 | for _ in line_split[6:]: 385 | labels.append([]) 386 | inside.append(0) 387 | for l, label in enumerate(line_split[6:]): 388 | if label.find("(") > -1 and label.find(")") > -1: 389 | lab = label.split("*")[0][1:] 390 | labels[l].append("B-" + lab) 391 | 392 | elif label.find("(") > -1: 393 | if inside[l]: 394 | raise OSError("parsing error") 395 | else: 396 | lab = label.split("*")[0][1:] 397 | inside[l] = lab 398 | labels[l].append("B-" + lab) 399 | elif label.find(")") > -1: 400 | if not inside[l]: 401 | raise OSError("parsing error") 402 | else: 403 | labels[l].append("I-" + inside[l]) 404 | inside[l] = 0 405 | else: 406 | if inside[l]: 407 | labels[l].append("I-" + inside[l]) 408 | else: 409 | labels[l].append("O") 410 | 411 | line_count += 1 412 | else: 413 | data_file_ += curr_sent 414 | curr_sent = [] 415 | if len(predicate_positions) > 0: 416 | for p, pred in enumerate(predicates): 417 | data.append( 418 | { 419 | "text": text, 420 | "pos": pos, 421 | "predicate_position": predicate_positions[p], 422 | "predicates": predicates[p], 423 | "labels": labels[p], 424 | "dep_head": dep_head, 425 | "dep_lb": dep_lb, 426 | "word_to_constituents": word_to_constituents, 427 | "constituents_to_constituents": constituents_to_constituents, 428 | "number_constituents": curr_num - 500, 429 | } 430 | ) 431 | text = [] 432 | pos = [] 433 | word_to_constituents = [] 434 | constituents_to_constituents = [] 435 | children = {} 436 | 437 | dep_head = [] 438 | dep_lb = [] 439 | 440 | predicate_positions = [] 441 | predicates = [] 442 | labels = [] 443 | line_count = 0 444 | inside = [] 445 | 446 | stack_const = [] 447 | stack_num = [] 448 | curr_num = 500 449 | return data, data_file_, w_c_to_idx, c_c_to_idx, dep_lb_to_idx 450 | 451 | 452 | 453 | 454 | def read_data_file(data_file): 455 | data = [] 456 | with open(data_file, "r") as f: 457 | for line in f: 458 | line_split = line.strip().split() 459 | data.append(line_split) 460 | return data 461 | 462 | 463 | def get_batch_sup( 464 | batch, word_vec, gpu_id, lstm_hidden_dim, word_drop, bert_tokenizer, bert_model 465 | ): 466 | 467 | max_sent_len = 0 468 | max_const_len = 0 469 | lengths = [] 470 | for d in batch: 471 | max_sent_len = max(len(d["text"]), max_sent_len) 472 | max_const_len = max(d["number_constituents"], max_const_len) 473 | batch_len = len(batch) 474 | sentences = np.zeros((batch_len, max_sent_len)) 475 | predicate_flags = np.zeros((batch_len, max_sent_len)) 476 | labels = np.zeros((batch_len, max_sent_len)) 477 | dependency_arcs = np.zeros((batch_len, max_sent_len, max_sent_len)) 478 | dependency_labels = np.zeros((batch_len, max_sent_len)) 479 | 480 | mask = np.zeros((batch_len, max_sent_len)) 481 | fixed_embs = np.zeros((batch_len, max_sent_len, 100)) 482 | bert_embs = np.zeros((batch_len, max_sent_len, 768)) 483 | predicate_index = np.zeros((batch_len, max_sent_len)) 484 | 485 | constituent_labels = np.zeros((batch_len, max_const_len, lstm_hidden_dim)) 486 | 487 | const_mask = np.zeros((batch_len, max_const_len)) 488 | 489 | plain_sentences = [] 490 | 491 | for d, data in enumerate(batch): 492 | num_const = data["number_constituents"] 493 | const_mask[d][:num_const] = 1.0 494 | 495 | predicate_flags[d, data["predicate_position"]] = 1.0 496 | for w, word in enumerate(data["i_text"]): 497 | predicate_index[d, w] = data["predicate_position"] 498 | if np.random.rand() > word_drop: 499 | sentences[d, w] = word 500 | else: 501 | sentences[d, w] = 1 # UNK 502 | mask[d, w] = 1.0 503 | labels[d, w] = data["i_labels"][w] 504 | word_lower = data["lower_text"][w] 505 | 506 | dependency_labels[d, w] = data["dep_lb"][w] 507 | dependency_arcs[d, w, w] = 1 508 | dependency_arcs[d, w, data["dep_head"][w]] = 1 509 | dependency_arcs[d, data["dep_head"][w], w] = 1 510 | 511 | if word_lower in word_vec: 512 | fixed_embs[d, w, :] = word_vec[word_lower] 513 | else: 514 | fixed_embs[d, w, :] = word_vec[""] 515 | 516 | plain_sentences.append(data["text"]) 517 | lengths.append(len(data["i_text"])) 518 | 519 | batch_w_c = [] 520 | for d in batch: 521 | batch_w_c.append([]) 522 | for i in d["word_to_constituents"]: 523 | batch_w_c[-1].append([]) 524 | for j in i: 525 | batch_w_c[-1][-1].append(j) 526 | 527 | batch_c_c = [] 528 | for d in batch: 529 | batch_c_c.append([]) 530 | for i in d["word_to_constituents"]: 531 | batch_c_c[-1].append([]) 532 | for j in i: 533 | batch_c_c[-1][-1].append(j) 534 | 535 | for d, _ in enumerate(batch): 536 | for t, trip in enumerate(batch_w_c[d]): 537 | for e, elem in enumerate(trip): 538 | if elem > 499: 539 | batch_w_c[d][t][e] = (elem - 500) + max_sent_len 540 | 541 | for t, trip in enumerate(batch_c_c[d]): 542 | for e, elem in enumerate(trip): 543 | if elem > 499: 544 | batch_c_c[d][t][e] = (elem - 500) + max_sent_len 545 | 546 | const_GCN_w_c = get_const_adj_BE( 547 | batch_w_c, max_sent_len + max_const_len, gpu_id, 2, 2, forward=True 548 | ) 549 | const_GCN_c_w = get_const_adj_BE( 550 | batch_w_c, max_sent_len + max_const_len, gpu_id, 5, 20, forward=False 551 | ) 552 | 553 | const_GCN_c_c = get_const_adj_BE( 554 | batch_c_c, max_sent_len + max_const_len, gpu_id, 2, 7, forward=True 555 | ) 556 | 557 | if gpu_id > -1: 558 | cuda = True 559 | else: 560 | cuda = False 561 | 562 | # BERT 563 | if bert_model is not None: 564 | 565 | bert_encoded_sentences = [ 566 | bert_tokenizer.encode(" ".join(sent), add_special_tokens=True) 567 | for sent in plain_sentences 568 | ] 569 | bert_tokenized_sentences = [ 570 | bert_tokenizer.convert_ids_to_tokens(bert_encoded_sentence) 571 | for bert_encoded_sentence in bert_encoded_sentences 572 | ] 573 | input_ids = torch.nn.utils.rnn.pad_sequence( 574 | [ 575 | torch.tensor(bert_encoded_sentence) 576 | for bert_encoded_sentence in bert_encoded_sentences 577 | ], 578 | padding_value=-1, 579 | batch_first=True, 580 | ) 581 | if cuda: 582 | input_ids = input_ids.cuda() 583 | 584 | attention_mask = input_ids == -1 585 | input_ids = input_ids.masked_fill(attention_mask, 2) 586 | attention_mask = (attention_mask.float() - 1).abs() 587 | if cuda: 588 | input_ids = input_ids.cuda() 589 | attention_mask = attention_mask.cuda() 590 | with torch.no_grad(): 591 | bert_last_vectors = bert_model(input_ids, attention_mask=attention_mask)[0] 592 | if cuda: 593 | bert_last_vectors = bert_last_vectors.cuda() 594 | 595 | for s, sent in enumerate(bert_last_vectors): 596 | real_index = 0 597 | for i, bert_token in enumerate(bert_tokenized_sentences[s]): 598 | if ( 599 | bert_token == "" 600 | or bert_token == "" 601 | or not bert_token.startswith("Ġ") 602 | ): 603 | pass 604 | else: 605 | bert_embs[s, real_index, :] = sent[i, :].cpu() 606 | real_index += 1 607 | 608 | labels_batch = _make_VariableLong(labels, cuda, False) 609 | sentences_batch = _make_VariableLong(sentences, cuda, False) 610 | predicate_flags_batch = _make_VariableFloat(predicate_flags, cuda, False) 611 | mask_batch = _make_VariableFloat(mask, cuda, False) 612 | mask_const_batch = _make_VariableFloat(const_mask, cuda, False) 613 | lengths_batch = _make_VariableLong(lengths, cuda, False) 614 | fixed_embs = _make_VariableFloat(fixed_embs, cuda, False) 615 | bert_embs = _make_VariableFloat(bert_embs, cuda, False) 616 | constituent_labels = _make_VariableFloat(constituent_labels, cuda, False) 617 | predicate_index_batch = _make_VariableLong(predicate_index, cuda, False) 618 | 619 | return ( 620 | labels_batch, 621 | sentences_batch, 622 | predicate_flags_batch, 623 | mask_batch, 624 | lengths_batch, 625 | fixed_embs, 626 | dependency_arcs, 627 | dependency_labels, 628 | constituent_labels, 629 | const_GCN_w_c, 630 | const_GCN_c_w, 631 | const_GCN_c_c, 632 | mask_const_batch, 633 | predicate_index_batch, 634 | bert_embs, 635 | ) 636 | 637 | 638 | def _make_VariableFloat(numpy_obj, cuda, requires_grad): 639 | if cuda: 640 | return Variable( 641 | torch.FloatTensor(numpy_obj), requires_grad=requires_grad 642 | ).cuda() 643 | else: 644 | return Variable(torch.FloatTensor(numpy_obj), requires_grad=requires_grad) 645 | 646 | 647 | def _make_VariableLong(numpy_obj, cuda, requires_grad): 648 | if cuda: 649 | return Variable(torch.LongTensor(numpy_obj), requires_grad=requires_grad).cuda() 650 | else: 651 | return Variable(torch.LongTensor(numpy_obj), requires_grad=requires_grad) 652 | -------------------------------------------------------------------------------- /engine/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from typing import List, Tuple, Dict 5 | import numpy as np 6 | from torch.autograd import Variable 7 | 8 | 9 | class DepGCN(nn.Module): 10 | """ 11 | Label-aware Dependency Convolutional Neural Network Layer 12 | """ 13 | 14 | def __init__(self, dep_num, dep_dim, in_features, out_features): 15 | super(DepGCN, self).__init__() 16 | self.dep_dim = dep_dim 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | 20 | self.dep_embedding = nn.Embedding(dep_num, dep_dim, padding_idx=0) 21 | 22 | self.dep_attn = nn.Linear(dep_dim + in_features, out_features) 23 | self.dep_fc = nn.Linear(dep_dim, out_features) 24 | self.relu = nn.ReLU() 25 | 26 | def forward(self, text, dep_mat, dep_labels): 27 | dep_label_embed = self.dep_embedding(dep_labels) 28 | 29 | batch_size, seq_len, feat_dim = text.shape 30 | 31 | val_us = text.unsqueeze(dim=2) 32 | val_us = val_us.repeat(1, 1, seq_len, 1) 33 | 34 | val_sum = torch.cat([val_us, dep_label_embed], dim=-1) 35 | 36 | r = self.dep_attn(val_sum) 37 | 38 | p = torch.sum(r, dim=-1) 39 | mask = (dep_mat == 0).float() * (-1e30) 40 | p = p + mask 41 | p = torch.softmax(p, dim=2) 42 | p_us = p.unsqueeze(3).repeat(1, 1, 1, feat_dim) 43 | 44 | output = val_us + self.dep_fc(dep_label_embed) 45 | output = torch.mul(p_us, output) 46 | 47 | output_sum = torch.sum(output, dim=2) 48 | output_sum = self.relu(output_sum) 49 | 50 | return output_sum 51 | 52 | 53 | class ConstGCN(nn.Module): 54 | """ 55 | Label-aware Constituency Convolutional Neural Network Layer 56 | """ 57 | 58 | def __init__( 59 | self, 60 | num_inputs, 61 | num_units, 62 | num_labels, 63 | dropout=0.0, 64 | in_arcs=True, 65 | out_arcs=True, 66 | batch_first=False, 67 | use_gates=True, 68 | residual=False, 69 | no_loop=False, 70 | non_linearity="relu", 71 | edge_dropout=0.0, 72 | ): 73 | super(ConstGCN, self).__init__() 74 | 75 | self.in_arcs = in_arcs 76 | self.out_arcs = out_arcs 77 | self.no_loop = no_loop 78 | self.retain = 1.0 - edge_dropout 79 | self.num_inputs = num_inputs 80 | self.num_units = num_units 81 | self.num_labels = num_labels 82 | self.batch_first = batch_first 83 | self.non_linearity = non_linearity 84 | self.sigmoid = nn.Sigmoid() 85 | self.use_gates = use_gates 86 | self.residual = residual 87 | self.dropout = nn.Dropout(p=dropout) 88 | self.layernorm = nn.LayerNorm(num_units) 89 | 90 | if in_arcs: 91 | self.V_in = Parameter(torch.Tensor(self.num_inputs, self.num_units)) 92 | nn.init.xavier_normal_(self.V_in) 93 | 94 | self.b_in = Parameter(torch.Tensor(num_labels, self.num_units)) 95 | nn.init.constant_(self.b_in, 0) 96 | 97 | if self.use_gates: 98 | self.V_in_gate = Parameter(torch.Tensor(self.num_inputs, 1)) 99 | nn.init.xavier_normal_(self.V_in_gate) 100 | self.b_in_gate = Parameter(torch.Tensor(num_labels, 1)) 101 | nn.init.constant_(self.b_in_gate, 1) 102 | 103 | if out_arcs: 104 | # self.V_out = autograd.Variable(torch.FloatTensor(self.num_inputs, self.num_units)) 105 | self.V_out = Parameter(torch.Tensor(self.num_inputs, self.num_units)) 106 | nn.init.xavier_normal_(self.V_out) 107 | 108 | # self.b_out = autograd.Variable(torch.FloatTensor(num_labels, self.num_units)) 109 | self.b_out = Parameter(torch.Tensor(num_labels, self.num_units)) 110 | nn.init.constant_(self.b_out, 0) 111 | 112 | if self.use_gates: 113 | self.V_out_gate = Parameter(torch.Tensor(self.num_inputs, 1)) 114 | nn.init.xavier_normal_(self.V_out_gate) 115 | self.b_out_gate = Parameter(torch.Tensor(num_labels, 1)) 116 | nn.init.constant_(self.b_out_gate, 1) 117 | if not self.no_loop: 118 | self.W_self_loop = Parameter(torch.Tensor(self.num_inputs, self.num_units)) 119 | nn.init.xavier_normal_(self.W_self_loop) 120 | 121 | if self.use_gates: 122 | self.W_self_loop_gate = Parameter(torch.Tensor(self.num_inputs, 1)) 123 | nn.init.xavier_normal_(self.W_self_loop_gate) 124 | 125 | def forward( 126 | self, 127 | src, 128 | arc_tensor_in=None, 129 | arc_tensor_out=None, 130 | label_tensor_in=None, 131 | label_tensor_out=None, 132 | mask_in=None, 133 | mask_out=None, 134 | mask_loop=None, 135 | sent_mask=None, 136 | ): 137 | 138 | if not self.batch_first: 139 | encoder_outputs = src.permute(1, 0, 2).contiguous() 140 | else: 141 | encoder_outputs = src.contiguous() 142 | 143 | batch_size = encoder_outputs.size()[0] 144 | seq_len = encoder_outputs.size()[1] 145 | max_degree = 1 146 | input_ = encoder_outputs.view( 147 | (batch_size * seq_len, self.num_inputs) 148 | ) # [b* t, h] 149 | input_ = self.dropout(input_) 150 | if self.in_arcs: 151 | input_in = torch.mm(input_, self.V_in) # [b* t, h] * [h,h] = [b*t, h] 152 | first_in = input_in.index_select( 153 | 0, arc_tensor_in[0] * seq_len + arc_tensor_in[1] 154 | ) # [b* t* degr, h] 155 | second_in = self.b_in.index_select(0, label_tensor_in[0]) # [b* t* degr, h] 156 | in_ = first_in + second_in 157 | degr = int(first_in.size()[0] / batch_size // seq_len) 158 | in_ = in_.view((batch_size, seq_len, degr, self.num_units)) 159 | if self.use_gates: 160 | # compute gate weights 161 | input_in_gate = torch.mm( 162 | input_, self.V_in_gate 163 | ) # [b* t, h] * [h,h] = [b*t, h] 164 | first_in_gate = input_in_gate.index_select( 165 | 0, arc_tensor_in[0] * seq_len + arc_tensor_in[1] 166 | ) # [b* t* mxdeg, h] 167 | second_in_gate = self.b_in_gate.index_select(0, label_tensor_in[0]) 168 | in_gate = (first_in_gate + second_in_gate).view( 169 | (batch_size, seq_len, degr) 170 | ) 171 | 172 | max_degree += degr 173 | 174 | if self.out_arcs: 175 | input_out = torch.mm(input_, self.V_out) # [b* t, h] * [h,h] = [b* t, h] 176 | first_out = input_out.index_select( 177 | 0, arc_tensor_out[0] * seq_len + arc_tensor_out[1] 178 | ) # [b* t* mxdeg, h] 179 | second_out = self.b_out.index_select(0, label_tensor_out[0]) 180 | 181 | degr = int(first_out.size()[0] / batch_size // seq_len) 182 | max_degree += degr 183 | 184 | out_ = (first_out + second_out).view( 185 | (batch_size, seq_len, degr, self.num_units) 186 | ) 187 | 188 | if self.use_gates: 189 | # compute gate weights 190 | input_out_gate = torch.mm( 191 | input_, self.V_out_gate 192 | ) # [b* t, h] * [h,h] = [b* t, h] 193 | first_out_gate = input_out_gate.index_select( 194 | 0, arc_tensor_out[0] * seq_len + arc_tensor_out[1] 195 | ) # [b* t* mxdeg, h] 196 | second_out_gate = self.b_out_gate.index_select(0, label_tensor_out[0]) 197 | out_gate = (first_out_gate + second_out_gate).view( 198 | (batch_size, seq_len, degr) 199 | ) 200 | if self.no_loop: 201 | if self.in_arcs and self.out_arcs: 202 | potentials = torch.cat((in_, out_), dim=2) # [b, t, mxdeg, h] 203 | if self.use_gates: 204 | potentials_gate = torch.cat( 205 | (in_gate, out_gate), dim=2 206 | ) # [b, t, mxdeg, h] 207 | mask_soft = torch.cat((mask_in, mask_out), dim=1) # [b* t, mxdeg] 208 | elif self.out_arcs: 209 | potentials = out_ # [b, t, 2*mxdeg+1, h] 210 | if self.use_gates: 211 | potentials_gate = out_gate # [b, t, mxdeg, h] 212 | mask_soft = mask_out # [b* t, mxdeg] 213 | elif self.in_arcs: 214 | potentials = in_ # [b, t, 2*mxdeg+1, h] 215 | if self.use_gates: 216 | potentials_gate = in_gate # [b, t, mxdeg, h] 217 | mask_soft = mask_in # [b* t, mxdeg] 218 | max_degree -= 1 219 | else: 220 | same_input = torch.mm(input_, self.W_self_loop).view( 221 | encoder_outputs.size(0), encoder_outputs.size(1), -1 222 | ) 223 | same_input = same_input.view( 224 | encoder_outputs.size(0), 225 | encoder_outputs.size(1), 226 | 1, 227 | self.W_self_loop.size(1), 228 | ) 229 | if self.use_gates: 230 | same_input_gate = torch.mm(input_, self.W_self_loop_gate).view( 231 | encoder_outputs.size(0), encoder_outputs.size(1), -1 232 | ) 233 | 234 | if self.in_arcs and self.out_arcs: 235 | potentials = torch.cat( 236 | (in_, out_, same_input), dim=2 237 | ) # [b, t, mxdeg, h] 238 | if self.use_gates: 239 | potentials_gate = torch.cat( 240 | (in_gate, out_gate, same_input_gate), dim=2 241 | ) # [b, t, mxdeg, h] 242 | mask_soft = torch.cat( 243 | (mask_in, mask_out, mask_loop), dim=1 244 | ) # [b* t, mxdeg] 245 | elif self.out_arcs: 246 | potentials = torch.cat( 247 | (out_, same_input), dim=2 248 | ) # [b, t, 2*mxdeg+1, h] 249 | if self.use_gates: 250 | potentials_gate = torch.cat( 251 | (out_gate, same_input_gate), dim=2 252 | ) # [b, t, mxdeg, h] 253 | mask_soft = torch.cat((mask_out, mask_loop), dim=1) # [b* t, mxdeg] 254 | elif self.in_arcs: 255 | potentials = torch.cat( 256 | (in_, same_input), dim=2 257 | ) # [b, t, 2*mxdeg+1, h] 258 | if self.use_gates: 259 | potentials_gate = torch.cat( 260 | (in_gate, same_input_gate), dim=2 261 | ) # [b, t, mxdeg, h] 262 | mask_soft = torch.cat((mask_in, mask_loop), dim=1) # [b* t, mxdeg] 263 | else: 264 | potentials = same_input # [b, t, 2*mxdeg+1, h] 265 | if self.use_gates: 266 | potentials_gate = same_input_gate # [b, t, mxdeg, h] 267 | mask_soft = mask_loop # [b* t, mxdeg] 268 | 269 | potentials_resh = potentials.view( 270 | (batch_size * seq_len, max_degree, self.num_units) 271 | ) # [h, b * t, mxdeg] 272 | 273 | if self.use_gates: 274 | potentials_r = potentials_gate.view( 275 | (batch_size * seq_len, max_degree) 276 | ) # [b * t, mxdeg] 277 | probs_det_ = (self.sigmoid(potentials_r) * mask_soft).unsqueeze( 278 | 2 279 | ) # [b * t, mxdeg] 280 | 281 | potentials_masked = potentials_resh * probs_det_ # [b * t, mxdeg,h] 282 | else: 283 | # NO Gates 284 | potentials_masked = potentials_resh * mask_soft.unsqueeze(2) 285 | 286 | if self.retain == 1 or not self.training: 287 | pass 288 | else: 289 | mat_1 = torch.Tensor(mask_soft.data.size()).uniform_(0, 1) 290 | ret = torch.Tensor([self.retain]) 291 | mat_2 = (mat_1 < ret).float() 292 | drop_mask = Variable(mat_2, requires_grad=False) 293 | if potentials_resh.is_cuda: 294 | drop_mask = drop_mask.cuda() 295 | 296 | potentials_masked *= drop_mask.unsqueeze(2) 297 | 298 | potentials_masked_ = potentials_masked.sum(dim=1) # [b * t, h] 299 | 300 | potentials_masked_ = self.layernorm(potentials_masked_) * sent_mask.view( 301 | batch_size * seq_len 302 | ).unsqueeze(1) 303 | 304 | potentials_masked_ = self.non_linearity(potentials_masked_) # [b * t, h] 305 | 306 | result_ = potentials_masked_.view( 307 | (batch_size, seq_len, self.num_units) 308 | ) # [ b, t, h] 309 | 310 | result_ = result_ * sent_mask.unsqueeze(2) # [b, t, h] 311 | memory_bank = result_ # [t, b, h] 312 | 313 | if self.residual: 314 | memory_bank += src 315 | 316 | return memory_bank 317 | 318 | 319 | class BilinearScorer(nn.Module): 320 | def __init__(self, hidden_dim, role_vocab_size, dropout=0.0, gpu_id=-1): 321 | super(BilinearScorer, self).__init__() 322 | 323 | if gpu_id > -1: 324 | self.use_gpu = True 325 | else: 326 | self.use_gpu = False 327 | self.hidden_dim = hidden_dim 328 | self.role_vocab_size = role_vocab_size 329 | 330 | self.dropout = nn.Dropout(p=dropout) 331 | 332 | self.U = Parameter( 333 | torch.Tensor(self.hidden_dim, self.role_vocab_size, self.hidden_dim) 334 | ) 335 | nn.init.orthogonal_(self.U) 336 | 337 | self.bias1 = Parameter(torch.Tensor(1, self.hidden_dim * self.role_vocab_size)) 338 | nn.init.constant_(self.bias1, 0) 339 | self.bias2 = Parameter(torch.Tensor(1, self.role_vocab_size)) 340 | nn.init.constant_(self.bias2, 0) 341 | 342 | def forward(self, pred_input, args_input): 343 | 344 | b, t, h = pred_input.data.shape 345 | pred_input = self.dropout(pred_input) 346 | args_input = self.dropout(args_input) 347 | 348 | first = ( 349 | torch.mm(pred_input.view(-1, h), self.U.view(h, -1)) + self.bias1 350 | ) # [b*t, h] * [h,r*h] = [b*t,r*h] 351 | 352 | out = torch.bmm( 353 | first.view(-1, self.role_vocab_size, h), args_input.view(-1, h).unsqueeze(2) 354 | ) # [b*t,r,h] [b*t, h, 1] = [b*t, r] 355 | out = out.squeeze(2) + self.bias2 356 | return out 357 | 358 | 359 | class ScaledDotProductAttention(nn.Module): 360 | def __init__(self, d_k): 361 | super(ScaledDotProductAttention, self).__init__() 362 | self.d_k = d_k 363 | 364 | def forward(self, q, k, v, attn_mask): 365 | attn_score = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.d_k) 366 | attn_score.masked_fill_(attn_mask, -1e9) 367 | 368 | attn_weights = nn.Softmax(dim=-1)(attn_score) 369 | 370 | output = torch.matmul(attn_weights, v) 371 | 372 | return output, attn_weights 373 | 374 | 375 | class MultiHeadAttention(nn.Module): 376 | def __init__(self, d_model, n_heads): 377 | super(MultiHeadAttention, self).__init__() 378 | self.n_heads = n_heads 379 | self.d_k = self.d_v = d_model // n_heads 380 | 381 | self.WQ = nn.Linear(d_model, d_model) 382 | self.WK = nn.Linear(d_model, d_model) 383 | self.WV = nn.Linear(d_model, d_model) 384 | self.scaled_dot_product_attn = ScaledDotProductAttention(self.d_k) 385 | self.linear = nn.Linear(n_heads * self.d_v, d_model) 386 | 387 | def forward(self, Q, K, V, attn_mask): 388 | batch_size = Q.size(0) 389 | 390 | q_heads = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) 391 | k_heads = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) 392 | v_heads = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2) 393 | 394 | attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 395 | attn, attn_weights = self.scaled_dot_product_attn(q_heads, k_heads, v_heads, attn_mask) 396 | 397 | attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) 398 | output = self.linear(attn) 399 | 400 | return output, attn_weights 401 | 402 | 403 | class PositionWiseFeedForwardNetwork(nn.Module): 404 | def __init__(self, d_model, d_ff): 405 | super(PositionWiseFeedForwardNetwork, self).__init__() 406 | 407 | self.linear1 = nn.Linear(d_model, d_ff) 408 | self.linear2 = nn.Linear(d_ff, d_model) 409 | self.relu = nn.ReLU() 410 | 411 | def forward(self, inputs): 412 | output = self.relu(self.linear1(inputs)) 413 | output = self.linear2(output) 414 | 415 | return output 416 | 417 | 418 | class EncoderLayer(nn.Module): 419 | def __init__(self, d_model, n_heads, p_drop, d_ff): 420 | super(EncoderLayer, self).__init__() 421 | 422 | self.mha = MultiHeadAttention(d_model, n_heads) 423 | self.dropout1 = nn.Dropout(p_drop) 424 | self.layernorm1 = nn.LayerNorm(d_model, eps=1e-6) 425 | 426 | self.ffn = PositionWiseFeedForwardNetwork(d_model, d_ff) 427 | self.dropout2 = nn.Dropout(p_drop) 428 | self.layernorm2 = nn.LayerNorm(d_model, eps=1e-6) 429 | 430 | def forward(self, inputs, attn_mask): 431 | attn_outputs, attn_weights = self.mha(inputs, inputs, inputs, attn_mask) 432 | attn_outputs = self.dropout1(attn_outputs) 433 | attn_outputs = self.layernorm1(inputs + attn_outputs) 434 | 435 | ffn_outputs = self.ffn(attn_outputs) 436 | ffn_outputs = self.dropout2(ffn_outputs) 437 | ffn_outputs = self.layernorm2(attn_outputs + ffn_outputs) 438 | 439 | return ffn_outputs, attn_weights 440 | 441 | 442 | class TransformerEncoder(nn.Module): 443 | def __init__(self, vocab_size, seq_len=300, d_model=768, n_layers=3, n_heads=8, p_drop=0.1, d_ff=500, pad_id=0): 444 | super(TransformerEncoder, self).__init__() 445 | self.pad_id = pad_id 446 | self.sinusoid_table = self.get_sinusoid_table(seq_len + 1, d_model) # (seq_len+1, d_model) 447 | 448 | self.embedding = nn.Embedding(vocab_size, d_model) 449 | self.pos_embedding = nn.Embedding.from_pretrained(self.sinusoid_table, freeze=True) 450 | self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, p_drop, d_ff) for _ in range(n_layers)]) 451 | 452 | def forward(self, inputs): 453 | positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).repeat(inputs.size(0), 1) + 1 454 | position_pad_mask = inputs.eq(self.pad_id) 455 | positions.masked_fill_(position_pad_mask, 0) 456 | 457 | outputs = self.embedding(inputs) + self.pos_embedding(positions) 458 | 459 | attn_pad_mask = self.get_attention_padding_mask(inputs, inputs, self.pad_id) 460 | 461 | for layer in self.layers: 462 | outputs, attn_weights = layer(outputs, attn_pad_mask) 463 | 464 | return outputs 465 | 466 | def get_attention_padding_mask(self, q, k, pad_id): 467 | attn_pad_mask = k.eq(pad_id).unsqueeze(1).repeat(1, q.size(1), 1) 468 | 469 | return attn_pad_mask 470 | 471 | def get_sinusoid_table(self, seq_len, d_model): 472 | def get_angle(pos, i, d_model): 473 | return pos / np.power(10000, (2 * (i // 2)) / d_model) 474 | 475 | sinusoid_table = np.zeros((seq_len, d_model)) 476 | for pos in range(seq_len): 477 | for i in range(d_model): 478 | if i % 2 == 0: 479 | sinusoid_table[pos, i] = np.sin(get_angle(pos, i, d_model)) 480 | else: 481 | sinusoid_table[pos, i] = np.cos(get_angle(pos, i, d_model)) 482 | 483 | return torch.FloatTensor(sinusoid_table) 484 | 485 | 486 | def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: 487 | """ 488 | Given labels and a constraint type, returns the allowed transitions. It will 489 | additionally include transitions for the start and end states, which are used 490 | by the conditional random field. 491 | 492 | Parameters 493 | ---------- 494 | constraint_type : ``str``, required 495 | Indicates which constraint to apply. Current choices are 496 | "BIO", "IOB1", "BIOUL", and "BMES". 497 | labels : ``Dict[int, str]``, required 498 | A mapping {label_id -> label}. Most commonly this would be the value from 499 | Vocabulary.get_index_to_token_vocabulary() 500 | 501 | Returns 502 | ------- 503 | ``List[Tuple[int, int]]`` 504 | The allowed transitions (from_label_id, to_label_id). 505 | """ 506 | num_labels = len(labels) 507 | start_tag = num_labels 508 | end_tag = num_labels + 1 509 | labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] 510 | 511 | allowed = [] 512 | for from_label_index, from_label in labels_with_boundaries: 513 | if from_label in ("START", "END"): 514 | from_tag = from_label 515 | from_entity = "" 516 | else: 517 | from_tag = from_label[0] 518 | from_entity = from_label[1:] 519 | for to_label_index, to_label in labels_with_boundaries: 520 | if to_label in ("START", "END"): 521 | to_tag = to_label 522 | to_entity = "" 523 | else: 524 | to_tag = to_label[0] 525 | to_entity = to_label[1:] 526 | if is_transition_allowed(constraint_type, from_tag, from_entity, 527 | to_tag, to_entity): 528 | allowed.append((from_label_index, to_label_index)) 529 | return allowed 530 | 531 | 532 | def is_transition_allowed(constraint_type: str, 533 | from_tag: str, 534 | from_entity: str, 535 | to_tag: str, 536 | to_entity: str): 537 | """ 538 | Given a constraint type and strings ``from_tag`` and ``to_tag`` that 539 | represent the origin and destination of the transition, return whether 540 | the transition is allowed under the given constraint type. 541 | 542 | Parameters 543 | ---------- 544 | constraint_type : ``str``, required 545 | Indicates which constraint to apply. Current choices are 546 | "BIO", "IOB1", "BIOUL", and "BMES". 547 | from_tag : ``str``, required 548 | The tag that the transition originates from. For example, if the 549 | label is ``I-PER``, the ``from_tag`` is ``I``. 550 | from_entity: ``str``, required 551 | The entity corresponding to the ``from_tag``. For example, if the 552 | label is ``I-PER``, the ``from_entity`` is ``PER``. 553 | to_tag : ``str``, required 554 | The tag that the transition leads to. For example, if the 555 | label is ``I-PER``, the ``to_tag`` is ``I``. 556 | to_entity: ``str``, required 557 | The entity corresponding to the ``to_tag``. For example, if the 558 | label is ``I-PER``, the ``to_entity`` is ``PER``. 559 | 560 | Returns 561 | ------- 562 | ``bool`` 563 | Whether the transition is allowed under the given ``constraint_type``. 564 | """ 565 | # pylint: disable=too-many-return-statements 566 | if to_tag == "START" or from_tag == "END": 567 | return False 568 | 569 | if constraint_type == "BIOUL": 570 | if from_tag == "START": 571 | return to_tag in ('O', 'B', 'U') 572 | if to_tag == "END": 573 | return from_tag in ('O', 'L', 'U') 574 | return any([ 575 | from_tag in ('O', 'L', 'U') and to_tag in ('O', 'B', 'U'), 576 | from_tag in ('B', 'I') and to_tag in ('I', 'L') and from_entity == to_entity 577 | ]) 578 | elif constraint_type == "BIO": 579 | if from_tag == "START": 580 | return to_tag in ('O', 'B') 581 | if to_tag == "END": 582 | return from_tag in ('O', 'B', 'I') 583 | return any([ 584 | to_tag in ('O', 'B'), 585 | to_tag == 'I' and from_tag in ('B', 'I') and from_entity == to_entity 586 | ]) 587 | elif constraint_type == "IOB1": 588 | if from_tag == "START": 589 | return to_tag in ('O', 'I') 590 | if to_tag == "END": 591 | return from_tag in ('O', 'B', 'I') 592 | return any([ 593 | to_tag in ('O', 'I'), 594 | to_tag == 'B' and from_tag in ('B', 'I') and from_entity == to_entity 595 | ]) 596 | elif constraint_type == "BMES": 597 | if from_tag == "START": 598 | return to_tag in ('B', 'S') 599 | if to_tag == "END": 600 | return from_tag in ('E', 'S') 601 | return any([ 602 | to_tag in ('B', 'S') and from_tag in ('E', 'S'), 603 | to_tag == 'M' and from_tag == 'B' and from_entity == to_entity, 604 | to_tag == 'E' and from_tag in ('B', 'M') and from_entity == to_entity, 605 | ]) 606 | else: 607 | raise IOError("Unknown constraint type: {constraint_type}") 608 | 609 | 610 | class CRF(torch.nn.Module): 611 | def __init__(self, 612 | num_tags: int, 613 | constraints: List[Tuple[int, int]] = None, 614 | include_start_end_transitions: bool = True) -> None: 615 | super().__init__() 616 | self.num_tags = num_tags 617 | 618 | self.transitions = torch.nn.Parameter(torch.Tensor(num_tags, num_tags)) 619 | 620 | if constraints is None: 621 | constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.) 622 | else: 623 | constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(0.) 624 | for i, j in constraints: 625 | constraint_mask[i, j] = 1. 626 | 627 | self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False) 628 | 629 | self.include_start_end_transitions = include_start_end_transitions 630 | if include_start_end_transitions: 631 | self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) 632 | self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) 633 | 634 | self.reset_parameters() 635 | 636 | def reset_parameters(self): 637 | torch.nn.init.xavier_normal_(self.transitions) 638 | if self.include_start_end_transitions: 639 | torch.nn.init.normal_(self.start_transitions) 640 | torch.nn.init.normal_(self.end_transitions) 641 | 642 | def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 643 | """ 644 | Computes the (batch_size,) denominator term for the log-likelihood, which is the 645 | sum of the likelihoods across all possible state sequences. 646 | """ 647 | batch_size, sequence_length, num_tags = logits.size() 648 | 649 | mask = mask.float().transpose(0, 1).contiguous() 650 | logits = logits.transpose(0, 1).contiguous() 651 | 652 | if self.include_start_end_transitions: 653 | alpha = self.start_transitions.view(1, num_tags) + logits[0] 654 | else: 655 | alpha = logits[0] 656 | 657 | for i in range(1, sequence_length): 658 | emit_scores = logits[i].view(batch_size, 1, num_tags) 659 | transition_scores = self.transitions.view(1, num_tags, num_tags) 660 | broadcast_alpha = alpha.view(batch_size, num_tags, 1) 661 | 662 | inner = broadcast_alpha + emit_scores + transition_scores 663 | 664 | alpha = (logsumexp(inner, 1) * mask[i].view(batch_size, 1) + 665 | alpha * (1 - mask[i]).view(batch_size, 1)) 666 | 667 | if self.include_start_end_transitions: 668 | stops = alpha + self.end_transitions.view(1, num_tags) 669 | else: 670 | stops = alpha 671 | 672 | return logsumexp(stops) 673 | 674 | def _joint_likelihood(self, 675 | logits: torch.Tensor, 676 | tags: torch.Tensor, 677 | mask: torch.LongTensor) -> torch.Tensor: 678 | """ 679 | Computes the numerator term for the log-likelihood, which is just score(inputs, tags) 680 | """ 681 | batch_size, sequence_length, _ = logits.data.shape 682 | 683 | logits = logits.transpose(0, 1).contiguous() 684 | mask = mask.float().transpose(0, 1).contiguous() 685 | tags = tags.transpose(0, 1).contiguous() 686 | 687 | if self.include_start_end_transitions: 688 | score = self.start_transitions.index_select(0, tags[0]) 689 | else: 690 | score = 0.0 691 | 692 | for i in range(sequence_length - 1): 693 | current_tag, next_tag = tags[i], tags[i + 1] 694 | 695 | transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)] 696 | 697 | emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) 698 | 699 | score = score + transition_score * mask[i + 1] + emit_score * mask[i] 700 | 701 | last_tag_index = mask.sum(0).long() - 1 702 | last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) 703 | 704 | if self.include_start_end_transitions: 705 | last_transition_score = self.end_transitions.index_select(0, last_tags) 706 | else: 707 | last_transition_score = 0.0 708 | 709 | last_inputs = logits[-1] # (batch_size, num_tags) 710 | last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) 711 | last_input_score = last_input_score.squeeze() # (batch_size,) 712 | 713 | score = score + last_transition_score + last_input_score * mask[-1] 714 | 715 | return score 716 | 717 | def forward(self, 718 | inputs: torch.Tensor, 719 | tags: torch.Tensor, 720 | mask: torch.ByteTensor = None) -> torch.Tensor: 721 | """ 722 | Computes the log likelihood. 723 | """ 724 | if mask is None: 725 | mask = torch.ones(*tags.size(), dtype=torch.long) 726 | 727 | log_denominator = self._input_likelihood(inputs, mask) 728 | log_numerator = self._joint_likelihood(inputs, tags, mask) 729 | 730 | return torch.sum(log_numerator - log_denominator) 731 | 732 | def viterbi_tags(self, 733 | logits: torch.Tensor, 734 | mask: torch.Tensor) -> List[Tuple[List[int], float]]: 735 | 736 | _, max_seq_length, num_tags = logits.size() 737 | 738 | logits, mask = logits.data, mask.data 739 | 740 | start_tag = num_tags 741 | end_tag = num_tags + 1 742 | transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.) 743 | 744 | constrained_transitions = ( 745 | self.transitions * self._constraint_mask[:num_tags, :num_tags] + 746 | -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags]) 747 | ) 748 | transitions[:num_tags, :num_tags] = constrained_transitions.data 749 | 750 | if self.include_start_end_transitions: 751 | transitions[start_tag, :num_tags] = ( 752 | self.start_transitions.detach() * self._constraint_mask[start_tag, :num_tags].data + 753 | -10000.0 * (1 - self._constraint_mask[start_tag, :num_tags].detach()) 754 | ) 755 | transitions[:num_tags, end_tag] = ( 756 | self.end_transitions.detach() * self._constraint_mask[:num_tags, end_tag].data + 757 | -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) 758 | ) 759 | else: 760 | transitions[start_tag, :num_tags] = (-10000.0 * 761 | (1 - self._constraint_mask[start_tag, :num_tags].detach())) 762 | transitions[:num_tags, end_tag] = -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) 763 | 764 | best_paths = [] 765 | tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2) 766 | 767 | for prediction, prediction_mask in zip(logits, mask): 768 | sequence_length = (torch.sum(prediction_mask)).int() 769 | 770 | tag_sequence.fill_(-10000.) 771 | tag_sequence[0, start_tag] = 0. 772 | tag_sequence[1:(sequence_length + 1), :num_tags] = prediction[:sequence_length] 773 | tag_sequence[sequence_length + 1, end_tag] = 0. 774 | 775 | viterbi_path, viterbi_score = viterbi_decode(tag_sequence[:(sequence_length + 2)], transitions) 776 | viterbi_path = viterbi_path[1:-1] 777 | best_paths.append((viterbi_path, viterbi_score.item())) 778 | 779 | return best_paths 780 | 781 | 782 | def logsumexp(tensor: torch.Tensor, 783 | dim: int = -1, 784 | keepdim: bool = False) -> torch.Tensor: 785 | max_score, _ = tensor.max(dim, keepdim=keepdim) 786 | if keepdim: 787 | stable_vec = tensor - max_score 788 | else: 789 | stable_vec = tensor - max_score.unsqueeze(dim) 790 | return max_score + (stable_vec.exp().sum(dim, keepdim=keepdim)).log() 791 | 792 | 793 | def viterbi_decode(tag_sequence: torch.Tensor, 794 | transition_matrix: torch.Tensor, 795 | tag_observations=None): 796 | sequence_length, num_tags = list(tag_sequence.size()) 797 | if tag_observations: 798 | if len(tag_observations) != sequence_length: 799 | raise IOError("Observations were provided, but they were not the same length " 800 | "as the sequence. Found sequence of length: {} and evidence: {}" 801 | .format(sequence_length, tag_observations)) 802 | else: 803 | tag_observations = [-1 for _ in range(sequence_length)] 804 | 805 | path_scores = [] 806 | path_indices = [] 807 | 808 | if tag_observations[0] != -1: 809 | one_hot = torch.zeros(num_tags) 810 | one_hot[tag_observations[0]] = 100000. 811 | path_scores.append(one_hot) 812 | else: 813 | path_scores.append(tag_sequence[0, :]) 814 | 815 | for timestep in range(1, sequence_length): 816 | summed_potentials = path_scores[timestep - 1].unsqueeze(-1) + transition_matrix 817 | scores, paths = torch.max(summed_potentials, 0) 818 | 819 | observation = tag_observations[timestep] 820 | if tag_observations[timestep - 1] != -1: 821 | if transition_matrix[tag_observations[timestep - 1], observation] < -10000: 822 | print("The pairwise potential between tags you have passed as " 823 | "observations is extremely unlikely. Double check your evidence " 824 | "or transition potentials!") 825 | if observation != -1: 826 | one_hot = torch.zeros(num_tags) 827 | one_hot[observation] = 100000. 828 | path_scores.append(one_hot) 829 | else: 830 | path_scores.append(tag_sequence[timestep, :] + scores.squeeze()) 831 | path_indices.append(paths.squeeze()) 832 | 833 | viterbi_score, best_path = torch.max(path_scores[-1], 0) 834 | viterbi_path = [int(best_path.numpy())] 835 | for backward_timestep in reversed(path_indices): 836 | viterbi_path.append(int(backward_timestep[viterbi_path[-1]])) 837 | viterbi_path.reverse() 838 | return viterbi_path, viterbi_score 839 | -------------------------------------------------------------------------------- /data_house/scripts/srl-eval.pl: -------------------------------------------------------------------------------- 1 | #! /usr/bin/perl 2 | 3 | ################################################################## 4 | # 5 | # srl-eval.pl : evaluation program for the CoNLL-2005 Shared Task 6 | # 7 | # Authors : Xavier Carreras and Lluis Marquez 8 | # Contact : carreras@lsi.upc.edu 9 | # 10 | # Created : January 2004 11 | # Modified: 12 | # 2005/04/21 minor update; for perl-5.8 the table in LateX 13 | # did not print correctly 14 | # 2005/02/05 minor updates for CoNLL-2005 15 | # 16 | ################################################################## 17 | 18 | 19 | use strict; 20 | 21 | 22 | 23 | ############################################################ 24 | # A r g u m e n t s a n d H e l p 25 | 26 | use Getopt::Long; 27 | my %options; 28 | GetOptions(\%options, 29 | "latex", # latex output 30 | "C", # confusion matrix 31 | "noW" 32 | ); 33 | 34 | 35 | my $script = "srl-eval.pl"; 36 | my $help = << "end_of_help;"; 37 | Usage: srl-eval.pl 38 | Options: 39 | -latex Produce a results table in LaTeX 40 | -C Produce a confusion matrix of gold vs. predicted argments, wrt. their role 41 | 42 | end_of_help; 43 | 44 | 45 | ############################################################ 46 | # M A I N P R O G R A M 47 | 48 | 49 | my $ns = 0; # number of sentence 50 | my $ntargets = 0; # number of target verbs 51 | my %E; # evaluation results 52 | my %C; # confusion matrix 53 | 54 | my %excluded = ( V => 1); 55 | 56 | ## 57 | 58 | # open files 59 | 60 | if (@ARGV != 2) { 61 | print $help; 62 | exit; 63 | } 64 | 65 | my $goldfile = shift @ARGV; 66 | my $predfile = shift @ARGV; 67 | 68 | if ($goldfile =~ /\.gz/) { 69 | open GOLD, "gunzip -c $goldfile |" or die "$script: could not open gzipped file of gold props ($goldfile)! $!\n"; 70 | } 71 | else { 72 | open GOLD, $goldfile or die "$script: could not open file of gold props ($goldfile)! $!\n"; 73 | } 74 | if ($predfile =~ /\.gz/) { 75 | open PRED, "gunzip -c $predfile |" or die "$script: could not open gzipped file of predicted props ($predfile)! $!\n"; 76 | } 77 | else { 78 | open PRED, $predfile or die "$script: could not open file of predicted props ($predfile)! $!\n"; 79 | } 80 | 81 | 82 | ## 83 | # read and evaluate propositions, sentence by sentence 84 | 85 | my $s = SRL::sentence->read_props($ns, GOLD => \*GOLD, PRED => \*PRED); 86 | 87 | while ($s) { 88 | 89 | my $prop; 90 | 91 | my (@G, @P, $i); 92 | 93 | map { $G[$_->position] = $_ } $s->gold_props; 94 | map { $P[$_->position] = $_ } $s->pred_props; 95 | 96 | for($i=0; $i<@G; $i++) { 97 | my $gprop = $G[$i]; 98 | my $pprop = $P[$i]; 99 | 100 | if ($pprop and !$gprop) { 101 | !$options{noW} and print STDERR "WARNING : sentence $ns : verb ", $pprop->verb, 102 | " at position ", $pprop->position, " : found predicted prop without its gold reference! Skipping prop!\n"; 103 | } 104 | elsif ($gprop) { 105 | if (!$pprop) { 106 | !$options{noW} and print STDERR "WARNING : sentence $ns : verb ", $gprop->verb, 107 | " at position ", $gprop->position, " : missing predicted prop! Counting all arguments as missed!\n"; 108 | $pprop = SRL::prop->new($gprop->verb, $gprop->position); 109 | } 110 | elsif ($gprop->verb ne $pprop->verb) { 111 | !$options{noW} and print STDERR "WARNING : sentence $ns : props do not match : expecting ", 112 | $gprop->verb, " at position ", $gprop->position, 113 | ", found ", $pprop->verb, " at position ", $pprop->position, "! Counting all gold arguments as missed!\n"; 114 | $pprop = SRL::prop->new($gprop->verb, $gprop->position); 115 | } 116 | 117 | $ntargets++; 118 | my %e = evaluate_proposition($gprop, $pprop); 119 | 120 | 121 | # Update global evaluation results 122 | 123 | $E{ok} += $e{ok}; 124 | $E{op} += $e{op}; 125 | $E{ms} += $e{ms}; 126 | $E{ptv} += $e{ptv}; 127 | 128 | my $t; 129 | foreach $t ( keys %{$e{T}} ) { 130 | $E{T}{$t}{ok} += $e{T}{$t}{ok}; 131 | $E{T}{$t}{op} += $e{T}{$t}{op}; 132 | $E{T}{$t}{ms} += $e{T}{$t}{ms}; 133 | } 134 | foreach $t ( keys %{$e{E}} ) { 135 | $E{E}{$t}{ok} += $e{E}{$t}{ok}; 136 | $E{E}{$t}{op} += $e{E}{$t}{op}; 137 | $E{E}{$t}{ms} += $e{E}{$t}{ms}; 138 | } 139 | 140 | if ($options{C}) { 141 | update_confusion_matrix(\%C, $gprop, $pprop); 142 | } 143 | } 144 | } 145 | 146 | $ns++; 147 | $s = SRL::sentence->read_props($ns, GOLD => \*GOLD, PRED => \*PRED); 148 | 149 | } 150 | 151 | 152 | # Print Evaluation results 153 | my $t; 154 | 155 | if ($options{latex}) { 156 | print '\begin{table}[t]', "\n"; 157 | print '\centering', "\n"; 158 | print '\begin{tabular}{|l|r|r|r|}\cline{2-4}', "\n"; 159 | print '\multicolumn{1}{l|}{}', "\n"; 160 | print ' & Precision & Recall & F$_{\beta=1}$', '\\\\', "\n", '\hline', "\n"; #' 161 | 162 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", "Overall", precrecf1($E{ok}, $E{op}, $E{ms})); 163 | print '\hline', "\n"; 164 | 165 | foreach $t ( sort keys %{$E{T}} ) { 166 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", $t, precrecf1($E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms})); 167 | } 168 | print '\hline', "\n"; 169 | 170 | if (%excluded) { 171 | print '\hline', "\n"; 172 | foreach $t ( sort keys %{$E{E}} ) { 173 | printf("%-10s & %6.2f\\%% & %6.2f\\%% & %6.2f\\\\\n", $t, precrecf1($E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms})); 174 | } 175 | print '\hline', "\n"; 176 | } 177 | 178 | print '\end{tabular}', "\n"; 179 | print '\end{table}', "\n"; 180 | } 181 | else { 182 | printf("Number of Sentences : %6d\n", $ns); 183 | printf("Number of Propositions : %6d\n", $ntargets); 184 | printf("Percentage of perfect props : %6.2f\n",($ntargets>0 ? 100*$E{ptv}/$ntargets : 0)); 185 | print "\n"; 186 | 187 | printf("%10s %6s %6s %6s %6s %6s %6s\n", "", "corr.", "excess", "missed", "prec.", "rec.", "F1"); 188 | print "------------------------------------------------------------\n"; 189 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 190 | "Overall", $E{ok}, $E{op}, $E{ms}, precrecf1($E{ok}, $E{op}, $E{ms})); 191 | # print "------------------------------------------------------------\n"; 192 | print "----------\n"; 193 | 194 | # printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 195 | # "all - {V}", $O2{ok}, $O2{op}, $O2{ms}, precrecf1($O2{ok}, $O2{op}, $O2{ms})); 196 | # print "------------------------------------------------------------\n"; 197 | 198 | foreach $t ( sort keys %{$E{T}} ) { 199 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 200 | $t, $E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms}, precrecf1($E{T}{$t}{ok}, $E{T}{$t}{op}, $E{T}{$t}{ms})); 201 | } 202 | print "------------------------------------------------------------\n"; 203 | 204 | foreach $t ( sort keys %{$E{E}} ) { 205 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f\n", 206 | $t, $E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms}, precrecf1($E{E}{$t}{ok}, $E{E}{$t}{op}, $E{E}{$t}{ms})); 207 | } 208 | print "------------------------------------------------------------\n"; 209 | } 210 | 211 | 212 | # print confusion matrix 213 | if ($options{C}) { 214 | 215 | my $k; 216 | 217 | # Evaluation of Unlabelled arguments 218 | my ($uok, $uop, $ums, $uacc) = (0,0,0,0); 219 | foreach $k ( grep { $_ ne "-NONE-" && $_ ne "V" } keys %C ) { 220 | map { $uok += $C{$k}{$_} } grep { $_ ne "-NONE-" && $_ ne "V" } keys %{$C{$k}}; 221 | $uacc += $C{$k}{$k}; 222 | $ums += $C{$k}{"-NONE-"}; 223 | } 224 | map { $uop += $C{"-NONE-"}{$_} } grep { $_ ne "-NONE-" && $_ ne "V" } keys %{$C{"-NONE-"}}; 225 | 226 | print "--------------------------------------------------------------------\n"; 227 | printf("%10s %6s %6s %6s %6s %6s %6s %6s\n", "", "corr.", "excess", "missed", "prec.", "rec.", "F1", "lAcc"); 228 | printf("%10s %6d %6d %6d %6.2f %6.2f %6.2f %6.2f\n", 229 | "Unlabeled", $uok, $uop, $ums, precrecf1($uok, $uop, $ums), 100*$uacc/$uok); 230 | print "--------------------------------------------------------------------\n"; 231 | 232 | 233 | 234 | print "\n---- Confusion Matrix: (one row for each correct role, with the distribution of predictions)\n"; 235 | 236 | my %AllKeys; 237 | map { $AllKeys{$_} = 1 } map { $_, keys %{$C{$_}} } keys %C; 238 | my @AllKeys = sort keys %AllKeys; 239 | 240 | 241 | 242 | my $i = -1; 243 | print " "; 244 | map { printf("%4d ", $i); $i++} @AllKeys; 245 | print "\n"; 246 | $i = -1; 247 | foreach $k ( @AllKeys ) { 248 | printf("%2d: %-8s ", $i++, $k); 249 | map { printf("%4d ", $C{$k}{$_}) } @AllKeys; 250 | print "\n"; 251 | } 252 | 253 | 254 | my ($t1,$t2); 255 | foreach $t1 ( sort keys %C ) { 256 | foreach $t2 ( sort keys %{$C{$t1}} ) { 257 | # printf(" %-6s vs %-6s : %-5d\n", $t1, $t2, $C{$t1}{$t2}); 258 | } 259 | } 260 | } 261 | 262 | # end of main program 263 | ##################### 264 | 265 | ############################################################ 266 | # S U B R O U T I N E S 267 | 268 | 269 | # evaluates a predicted proposition wrt the gold correct proposition 270 | # returns a hash with the following keys 271 | # ok : number of correctly predicted args 272 | # ms : number of missed args 273 | # op : number of over-predicted args 274 | # T : a hash indexed by argument types, where 275 | # each value is in turn a hash of {ok,ms,op} numbers 276 | # E : a hash indexed by excluded argument types, where 277 | # each value is in turn a hash of {ok,ms,op} numbers 278 | sub evaluate_proposition { 279 | my ($gprop, $pprop) = @_; 280 | 281 | my $o = $gprop->discriminate_args($pprop); 282 | 283 | my %e; 284 | 285 | my $a; 286 | foreach $a (@{$o->{ok}}) { 287 | if (!$excluded{$a->type}) { 288 | $e{ok}++; 289 | $e{T}{$a->type}{ok}++; 290 | } 291 | else { 292 | $e{E}{$a->type}{ok}++; 293 | } 294 | } 295 | foreach $a (@{$o->{op}}) { 296 | if (!$excluded{$a->type}) { 297 | $e{op}++; 298 | $e{T}{$a->type}{op}++; 299 | } 300 | else { 301 | $e{E}{$a->type}{op}++; 302 | } 303 | } 304 | foreach $a (@{$o->{ms}}) { 305 | if (!$excluded{$a->type}) { 306 | $e{ms}++; 307 | $e{T}{$a->type}{ms}++; 308 | } 309 | else { 310 | $e{E}{$a->type}{ms}++; 311 | } 312 | } 313 | 314 | $e{ptv} = (!$e{op} and !$e{ms}) ? 1 : 0; 315 | 316 | return %e; 317 | } 318 | 319 | 320 | # computes precision, recall and F1 measures 321 | sub precrecf1 { 322 | my ($ok, $op, $ms) = @_; 323 | 324 | my $p = ($ok + $op > 0) ? 100*$ok/($ok+$op) : 0; 325 | my $r = ($ok + $ms > 0) ? 100*$ok/($ok+$ms) : 0; 326 | 327 | my $f1 = ($p+$r>0) ? (2*$p*$r)/($p+$r) : 0; 328 | 329 | return ($p,$r,$f1); 330 | } 331 | 332 | 333 | 334 | 335 | sub update_confusion_matrix { 336 | my ($C, $gprop, $pprop) = @_; 337 | 338 | my $o = $gprop->discriminate_args($pprop, 0); 339 | 340 | my $a; 341 | foreach $a ( @{$o->{ok}} ) { 342 | my $g = shift @{$o->{eq}}; 343 | $C->{$g->type}{$a->type}++; 344 | } 345 | foreach $a ( @{$o->{ms}} ) { 346 | $C->{$a->type}{"-NONE-"}++; 347 | } 348 | foreach $a ( @{$o->{op}} ) { 349 | $C->{"-NONE-"}{$a->type}++; 350 | } 351 | } 352 | 353 | 354 | # end of script 355 | ############### 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | ################################################################################ 378 | # 379 | # Package s e n t e n c e 380 | # 381 | # February 2004 382 | # 383 | # Stores information of a sentence, namely words, chunks, clauses, 384 | # named entities and propositions (gold and predicted). 385 | # 386 | # Provides access methods. 387 | # Provides methods for reading/writing sentences from/to files in 388 | # CoNLL-2004/CoNLL-2005 formats. 389 | # 390 | # 391 | ################################################################################ 392 | 393 | 394 | package SRL::sentence; 395 | use strict; 396 | 397 | 398 | 399 | sub new { 400 | my ($pkg, $id) = @_; 401 | 402 | my $s = []; 403 | 404 | $s->[0] = $id; # sentence number 405 | $s->[1] = undef; # words (the list or the number of words) 406 | $s->[2] = []; # gold props 407 | $s->[3] = []; # predicted props 408 | $s->[4] = undef; # chunks 409 | $s->[5] = undef; # clauses 410 | $s->[6] = undef; # full syntactic tree 411 | $s->[7] = undef; # named entities 412 | 413 | return bless $s, $pkg; 414 | } 415 | 416 | #----- 417 | 418 | sub id { 419 | my $s = shift; 420 | return $s->[0]; 421 | } 422 | 423 | #----- 424 | 425 | sub length { 426 | my $s = shift; 427 | if (ref($s->[1])) { 428 | return scalar(@{$s->[1]}); 429 | } 430 | else { 431 | return $s->[1]; 432 | } 433 | } 434 | 435 | sub set_length { 436 | my $s = shift; 437 | $s->[1] = shift; 438 | } 439 | 440 | #----- 441 | 442 | # returns the i-th word of the sentence 443 | sub word { 444 | my ($s, $i) = @_; 445 | return $s->[1][$i]; 446 | } 447 | 448 | 449 | # returns the list of words of the sentence 450 | sub words { 451 | my $s = shift; 452 | if (@_) { 453 | return map { $s->[1][$_] } @_; 454 | } 455 | else { 456 | return @{$s->[1]}; 457 | } 458 | } 459 | 460 | sub ref_words { 461 | my $s = shift; 462 | return $s->[1]; 463 | } 464 | 465 | 466 | sub chunking { 467 | my $s = shift; 468 | return $s->[4]; 469 | } 470 | 471 | sub clausing { 472 | my $s = shift; 473 | return $s->[5]; 474 | } 475 | 476 | sub syntree { 477 | my $s = shift; 478 | return $s->[6]; 479 | } 480 | 481 | sub named_entities { 482 | my $s = shift; 483 | return $s->[7]; 484 | } 485 | 486 | #----- 487 | 488 | sub add_gold_props { 489 | my $s = shift; 490 | push @{$s->[2]}, @_; 491 | } 492 | 493 | sub gold_props { 494 | my $s = shift; 495 | return @{$s->[2]}; 496 | } 497 | 498 | sub add_pred_props { 499 | my $s = shift; 500 | push @{$s->[3]}, @_; 501 | } 502 | 503 | sub pred_props { 504 | my $s = shift; 505 | return @{$s->[3]}; 506 | } 507 | 508 | 509 | #------------------------------------------------------------ 510 | # I/O F U N C T I O N S 511 | #------------------------------------------------------------ 512 | 513 | # Reads a complete (words, synt, props) sentence from a stream 514 | # Returns: the reference to the sentence object or 515 | # undef if no sentence found 516 | # The propositions in the file are stored as gold props 517 | # For each gold prop, an empty predicted prop is created 518 | # 519 | # The %C hash contains the column number for each annotation of 520 | # the datafile. 521 | # 522 | sub read_from_stream { 523 | my ($pkg, $id, $fh, %C) = @_; 524 | 525 | if (!%C) { 526 | %C = ( words => 0, 527 | pos => 1, 528 | chunks => 2, 529 | clauses => 3, 530 | syntree => 4, 531 | ne => 5, 532 | props => 6 533 | ) 534 | } 535 | 536 | # my $k; 537 | # foreach $k ( "words", "pos", "props" ) { 538 | # if (!exists($C{$k}) { 539 | # die "sentence->read_from_stream :: undefined column number for $k.\n"; 540 | # } 541 | # } 542 | 543 | my $cols = read_columns($fh); 544 | 545 | if (!@$cols) { 546 | return undef; 547 | } 548 | 549 | my $s = $pkg->new($id); 550 | 551 | # words and PoS 552 | my $words = $cols->[$C{words}]; 553 | my $pos = $cols->[$C{pos}]; 554 | 555 | # initialize list of words 556 | $s->[1] = []; 557 | my $i; 558 | for ($i=0;$i<@$words;$i++) { 559 | push @{$s->[1]}, SRL::word->new($i, $words->[$i], $pos->[$i]); 560 | } 561 | 562 | my $c; 563 | 564 | # chunks 565 | if (exists($C{chunks})) { 566 | $c = $cols->[$C{chunks}]; 567 | # initialize chunking 568 | $s->[4] = SRL::phrase_set->new(); 569 | $s->[4]->load_SE_tagging(@$c); 570 | } 571 | 572 | # clauses 573 | if (exists($C{clauses})) { 574 | $c = $cols->[$C{clauses}]; 575 | # initialize clauses 576 | $s->[5] = SRL::phrase_set->new(); 577 | $s->[5]->load_SE_tagging(@$c); 578 | } 579 | 580 | # syntree 581 | if (exists($C{syntree})) { 582 | $c = $cols->[$C{syntree}]; 583 | # initialize syntree 584 | $s->[6] = SRL::syntree->new(); 585 | $s->[6]->load_SE_tagging($s->[1], @$c); 586 | } 587 | 588 | # named entities 589 | if (exists($C{ne})) { 590 | $c = $cols->[$C{ne}]; 591 | $s->[7] = SRL::phrase_set->new(); 592 | $s->[7]->load_SE_tagging(@$c); 593 | } 594 | 595 | 596 | my $i = 0; 597 | while ($i<$C{props}) { 598 | shift @$cols; 599 | $i++; 600 | } 601 | 602 | # gold props 603 | my $targets = shift @$cols or die "error :: reading sentence $id :: no targets found!\n"; 604 | if (@$cols) { 605 | $s->load_props($s->[2], $targets, $cols); 606 | } 607 | 608 | # initialize predicted props 609 | foreach $i ( grep { $targets->[$_] ne "-" } ( 0 .. scalar(@$targets)-1 ) ) { 610 | push @{$s->[3]}, SRL::prop->new($targets->[$i], $i); 611 | } 612 | 613 | return $s; 614 | } 615 | 616 | 617 | 618 | #------------------------------------------------------------ 619 | 620 | 621 | # reads the propositions of a sentence from files 622 | # allows to store propositions as gold and/or predicted, 623 | # by specifying filehandles as values in the %FILES hash 624 | # indexed by {GOLD,PRED} keys 625 | # expects: each prop file: first column specifying target verbs, 626 | # and remaining columns specifying arguments 627 | # returns a new sentence, containing the list of prop 628 | # objects, one for each column, in gold/pred contexts 629 | # returns undef when EOF 630 | sub read_props { 631 | my ($pkg, $id, %FILES) = @_; 632 | 633 | my $s = undef; 634 | my $length = undef; 635 | 636 | if (exists($FILES{GOLD})) { 637 | my $cols = read_columns($FILES{GOLD}); 638 | 639 | # end of file 640 | if (!@$cols) { 641 | return undef; 642 | } 643 | 644 | $s = $pkg->new($id); 645 | my $targets = shift @$cols; 646 | $length = scalar(@$targets); 647 | $s->set_length($length); 648 | $s->load_props($s->[2], $targets, $cols); 649 | } 650 | if (exists($FILES{PRED})) { 651 | my $cols = read_columns($FILES{PRED}); 652 | 653 | if (!defined($s)) { 654 | # end of file 655 | if (!@$cols) { 656 | return undef; 657 | } 658 | $s = $pkg->new($id); 659 | } 660 | my $targets = shift @$cols; 661 | 662 | if (defined($length)) { 663 | ($length != scalar(@$targets)) and 664 | die "ERROR : sentence $id : gold and pred sentences do not align correctly!\n"; 665 | } 666 | else { 667 | $length = scalar(@$targets); 668 | $s->set_length($length); 669 | } 670 | $s->load_props($s->[3], $targets, $cols); 671 | } 672 | 673 | return $s; 674 | } 675 | 676 | 677 | sub load_props { 678 | my ($s, $where, $targets, $cols) = @_; 679 | 680 | my $i; 681 | for ($i=0; $i<@$targets; $i++) { 682 | if ($targets->[$i] ne "-") { 683 | my $prop = SRL::prop->new($targets->[$i], $i); 684 | 685 | my $col = shift @$cols; 686 | if (defined($col)) { 687 | # print "SE Tagging: ", join(" ", @$col), "\n"; 688 | $prop->load_SE_tagging(@$col); 689 | } 690 | else { 691 | print STDERR "WARNING : sentence ", $s->id, " : can't find column of args for prop ", $prop->verb, "!\n"; 692 | } 693 | push @$where, $prop; 694 | } 695 | } 696 | } 697 | 698 | 699 | # writes a sentence to an output stream 700 | # allows to specify which parts of the sentence are written 701 | # by giving true values to the %WHAT hash, indexed by 702 | # {WORDS,SYNT,GOLD,PRED} keys 703 | sub write_to_stream { 704 | my ($s, $fh, %WHAT) = @_; 705 | 706 | if (!%WHAT) { 707 | %WHAT = ( WORDS => 1, 708 | PSYNT => 1, 709 | FSYNT => 1, 710 | GOLD => 0, 711 | PRED => 1 712 | ); 713 | } 714 | 715 | my @columns; 716 | 717 | if ($WHAT{WORDS}) { 718 | my @words = map { $_->form } $s->words; 719 | push @columns, \@words; 720 | } 721 | if ($WHAT{PSYNT}) { 722 | my @pos = map { $_->pos } $s->words; 723 | push @columns, \@pos; 724 | my @chunks = $s->chunking->to_SE_tagging($s->length); 725 | push @columns, \@chunks; 726 | my @clauses = $s->clausing->to_SE_tagging($s->length); 727 | push @columns, \@clauses; 728 | } 729 | if ($WHAT{FSYNT}) { 730 | my @pos = map { $_->pos } $s->words; 731 | push @columns, \@pos; 732 | my @sttags = $s->syntree->to_SE_tagging(); 733 | push @columns, \@sttags; 734 | } 735 | if ($WHAT{GOLD}) { 736 | push @columns, $s->props_to_columns($s->[2]); 737 | } 738 | if ($WHAT{PRED}) { 739 | push @columns, $s->props_to_columns($s->[3]); 740 | } 741 | if ($WHAT{PROPS}) { 742 | push @columns, $s->props_to_columns($WHAT{PROPS}); 743 | } 744 | 745 | 746 | reformat_columns(\@columns); 747 | 748 | # finally, print columns word by word 749 | my $i; 750 | for ($i=0;$i<$s->length;$i++) { 751 | print $fh join(" ", map { $_->[$i] } @columns), "\n"; 752 | } 753 | print $fh "\n"; 754 | 755 | 756 | } 757 | 758 | # turns a set of propositions (target verbs + args for each one) into a set of 759 | # columns in the CoNLL Start-End format 760 | sub props_to_columns { 761 | my ($s, $Pref) = @_; 762 | 763 | my @props = sort { $a->position <=> $b->position } @{$Pref}; 764 | 765 | my $l = $s->length; 766 | my $verbs = []; 767 | my @cols = ( $verbs ); 768 | my $p; 769 | 770 | foreach $p ( @props ) { 771 | defined($verbs->[$p->position]) and die "sentence->preds_to_columns: already defined verb at sentence ", $s->id, " position ", $p->position, "!\n"; 772 | $verbs->[$p->position] = sprintf("%-15s", $p->verb); 773 | 774 | my @tags = $p->to_SE_tagging($l); 775 | push @cols, \@tags; 776 | } 777 | 778 | # finally, define empty verb positions 779 | my $i; 780 | for ($i=0;$i<$l;$i++) { 781 | if (!defined($verbs->[$i])) { 782 | $verbs->[$i] = sprintf("%-15s", "-"); 783 | } 784 | } 785 | 786 | return @cols; 787 | } 788 | 789 | 790 | 791 | # Writes the predicted propositions of the sentence to an output file handler ($fh) 792 | # Specifically, writes a column of target verbs, and a column of arguments 793 | # for each target verb 794 | # OBSOLETE : the same can be done with write_to_stream($s, PRED => 1) 795 | sub write_pred_props { 796 | my ($s, $fh) = @_; 797 | 798 | my @props = sort { $a->position <=> $b->position } $s->pred_props; 799 | 800 | my $l = $s->length; 801 | my @verbs = (); 802 | my @cols = (); 803 | my $p; 804 | 805 | foreach $p ( @props ) { 806 | defined($verbs[$p->position]) and die "prop->write_pred_props: already defined verb at sentence ", $s->id, " position ", $p->position, "!\n"; 807 | $verbs[$p->position] = $p->verb; 808 | 809 | my @tags = $p->to_SE_tagging($l); 810 | push @cols, \@tags; 811 | } 812 | 813 | # finally, print columns word by word 814 | my $i; 815 | for ($i=0;$i<$l;$i++) { 816 | printf $fh "%-15s %s\n", (defined($verbs[$i])? $verbs[$i] : "-"), 817 | join(" ", map { $_->[$i] } @cols); 818 | } 819 | print "\n"; 820 | } 821 | 822 | 823 | 824 | # reads columns until blank line or EOF 825 | # returns an array of columns (each column is a reference to an array containing the column) 826 | # each column in the returned array should be the same size 827 | sub read_columns { 828 | my $fh = shift; 829 | 830 | # read columns until blank line or eof 831 | my @cols; 832 | my $i; 833 | my @line = split(" ", <$fh>); 834 | while (@line) { 835 | for ($i=0; $i<@line; $i++) { 836 | push @{$cols[$i]}, $line[$i]; 837 | } 838 | @line = split(" ", <$fh>); 839 | } 840 | 841 | return \@cols; 842 | } 843 | 844 | 845 | 846 | # reformats the tags of a list of columns, so that each 847 | # column has a fixed width along all tags 848 | # 849 | # 850 | sub reformat_columns { 851 | my $cols = shift; # a reference to the list of columns of a sentence 852 | 853 | my $i; 854 | for ($i=0;$i[$i]); 856 | } 857 | } 858 | 859 | 860 | 861 | # reformats the tags of a column, so that each 862 | # tag has the same width 863 | # 864 | # tag sequences are left justified 865 | # start-end annotations are centered at the asterisk 866 | # 867 | sub column_pretty_format { 868 | my $col = shift; # a reference to the column (array) of tags 869 | 870 | (!@$col) and return undef; 871 | 872 | my ($i); 873 | if ($col->[0] =~ /\*/) { 874 | 875 | # Start-End 876 | my $ok = 1; 877 | 878 | my (@s,@e,$t,$ms,$me); 879 | $ms = 2; $me = 2; 880 | $i = 0; 881 | while ($ok and $i<@$col) { 882 | if ($col->[$i] =~ /^(.*\*)(.*)$/) { 883 | $s[$i] = $1; 884 | $e[$i] = $2; 885 | if (length($s[$i]) > $ms) { 886 | $ms = length($s[$i]); 887 | } 888 | if (length($e[$i]) > $me) { 889 | $me = length($e[$i]); 890 | } 891 | } 892 | else { 893 | # In this case, the current token is not compliant with SE format 894 | # So, we treat format the column as a sequence of tags 895 | $ok = 0; 896 | } 897 | $i++; 898 | } 899 | # print "M $ms $me\n"; 900 | 901 | if ($ok) { 902 | my $f = "%".($ms+1)."s%-".($me+1)."s"; 903 | for ($i=0; $i<@$col; $i++) { 904 | $col->[$i] = sprintf($f, $s[$i], $e[$i]); 905 | } 906 | return; 907 | } 908 | } 909 | 910 | # Tokens 911 | my $l=0; 912 | map { (length($_)>$l) and ($l=length($_)) } @$col; 913 | my $f = "%-".($l+1)."s"; 914 | for ($i=0; $i<@$col; $i++) { 915 | $col->[$i] = sprintf($f,$col->[$i]); 916 | } 917 | 918 | } 919 | 920 | 921 | 922 | 1; 923 | 924 | 925 | 926 | 927 | 928 | 929 | 930 | 931 | ################################################################## 932 | # 933 | # Package p r o p : A proposition (verb + args) 934 | # 935 | # January 2004 936 | # 937 | ################################################################## 938 | 939 | 940 | package SRL::prop; 941 | 942 | use strict; 943 | 944 | 945 | # Constructor: creates a new prop, with empty arguments 946 | # Parameters: verb form, position of verb 947 | sub new { 948 | my ($pkg, $v, $position) = @_; 949 | 950 | my $p = []; 951 | 952 | $p->[0] = $v; # the verb 953 | $p->[1] = $position; # verb position 954 | $p->[2] = undef; # verb sense 955 | $p->[3] = []; # args, empty by default 956 | 957 | return bless $p, $pkg; 958 | } 959 | 960 | ## Accessor/Initialization methods 961 | 962 | # returns the verb form of the prop 963 | sub verb { 964 | my $p = shift; 965 | return $p->[0]; 966 | } 967 | 968 | # returns the verb position of the verb in the prop 969 | sub position { 970 | my $p = shift; 971 | return $p->[1]; 972 | } 973 | 974 | # returns the verb sense of the verb in the prop 975 | sub sense { 976 | my $p = shift; 977 | return $p->[2]; 978 | } 979 | 980 | # initializes the verb sense of the verb in the prop 981 | sub set_sense { 982 | my $p = shift; 983 | $p->[2] = shift; 984 | } 985 | 986 | 987 | # returns the list of arguments of the prop 988 | sub args { 989 | my $p = shift; 990 | return @{$p->[3]}; 991 | } 992 | 993 | # initializes the list of arguments of the prop 994 | sub set_args { 995 | my $p = shift; 996 | @{$p->[3]} = @_; 997 | } 998 | 999 | # adds arguments to the prop 1000 | sub add_args { 1001 | my $p = shift; 1002 | push @{$p->[3]}, @_; 1003 | } 1004 | 1005 | # Returns the list of phrases of the prop 1006 | # Each argument corresponds to one phrase, except for 1007 | # discontinuous arguments, where each piece forms a phrase 1008 | sub phrases { 1009 | my $p = shift; 1010 | return map { $_->single ? $_ : $_->phrases} @{$p->[3]}; 1011 | } 1012 | 1013 | 1014 | ###### Methods 1015 | 1016 | # Adds arguments represented in Start-End tagging 1017 | # Receives a list of Start-End tags (one per word in the sentence) 1018 | # Creates an arg object for each argument in the taggging 1019 | # and modifies the prop so that the arguments are part of it 1020 | # Takes into account special treatment for discontinuous arguments 1021 | sub load_SE_tagging { 1022 | my ($prop, @tags) = @_; 1023 | 1024 | # auxiliar phrase set 1025 | my $set = SRL::phrase_set->new(); 1026 | $set->load_SE_tagging(@tags); 1027 | 1028 | # store args per type, to be able to continue them 1029 | my %ARGS; 1030 | my $a; 1031 | 1032 | # add each phrase as an argument, with special treatment for multi-phrase arguments (A C-A C-A) 1033 | foreach $a ( $set->phrases ) { 1034 | 1035 | # the phrase continues a started arg 1036 | if ($a->type =~ /^C\-/) { 1037 | my $type = $'; # ' 1038 | if (exists($ARGS{$type})) { 1039 | my $pc = $a; 1040 | $a = $ARGS{$type}; 1041 | if ($a->single) { 1042 | # create the head phrase, considered arg until now 1043 | my $ph = SRL::phrase->new($a->start, $a->end, $type); 1044 | $a->add_phrases($ph); 1045 | } 1046 | $a->add_phrases($pc); 1047 | $a->set_end($pc->end); 1048 | } 1049 | else { 1050 | # print STDERR "WARNING : found continuation phrase \"C-$type\" without heading phrase: turned into regular $type argument.\n"; 1051 | # turn the phrase into arg 1052 | bless $a, "SRL::arg"; 1053 | $a->set_type($type); 1054 | push @{$prop->[3]}, $a; 1055 | $ARGS{$a->type} = $a; 1056 | } 1057 | } 1058 | else { 1059 | # turn the phrase into arg 1060 | bless $a, "SRL::arg"; 1061 | push @{$prop->[3]}, $a; 1062 | $ARGS{$a->type} = $a; 1063 | } 1064 | } 1065 | 1066 | } 1067 | 1068 | 1069 | ## discriminates the args of prop $pb wrt the args of prop $pa, returning intersection(a^b), a-b and b-a 1070 | # returns a hash reference containing three lists: 1071 | # $out->{ok} : args in $pa and $pb 1072 | # $out->{ms} : args in $pa and not in $pb 1073 | # $out->{op} : args in $pb and not in $pa 1074 | sub discriminate_args { 1075 | my $pa = shift; 1076 | my $pb = shift; 1077 | my $check_type = @_ ? shift : 1; 1078 | 1079 | my $out = {}; 1080 | !$check_type and @{$out->{eq}} = (); 1081 | @{$out->{ok}} = (); 1082 | @{$out->{ms}} = (); 1083 | @{$out->{op}} = (); 1084 | 1085 | my $a; 1086 | my %ok; 1087 | 1088 | my %ARGS; 1089 | 1090 | foreach $a ($pa->args) { 1091 | $ARGS{$a->start}{$a->end} = $a; 1092 | } 1093 | 1094 | foreach $a ($pb->args) { 1095 | my $s = $a->start; 1096 | my $e = $a->end; 1097 | 1098 | my $gold = $ARGS{$s}{$e}; 1099 | if (!defined($gold)) { 1100 | push @{$out->{op}}, $a; 1101 | } 1102 | elsif ($gold->single and $a->single) { 1103 | if (!$check_type or ($gold->type eq $a->type)) { 1104 | !$check_type and push @{$out->{eq}}, $gold; 1105 | push @{$out->{ok}}, $a; 1106 | delete($ARGS{$s}{$e}); 1107 | } 1108 | else { 1109 | push @{$out->{op}}, $a; 1110 | } 1111 | } 1112 | elsif (!$gold->single and $a->single) { 1113 | push @{$out->{op}}, $a; 1114 | } 1115 | elsif ($gold->single and !$a->single) { 1116 | push @{$out->{op}}, $a; 1117 | } 1118 | else { 1119 | # Check phrases of arg 1120 | my %P; 1121 | my $ok = (!$check_type or ($gold->type eq $a->type)); 1122 | $ok and map { $P{ $_->start.".".$_->end } = 1 } $gold->phrases; 1123 | my @P = $a->phrases; 1124 | while ($ok and @P) { 1125 | my $p = shift @P; 1126 | if ($P{ $p->start.".".$p->end }) { 1127 | delete $P{ $p->start.".".$p->end } 1128 | } 1129 | else { 1130 | $ok = 0; 1131 | } 1132 | } 1133 | if ($ok and !(values %P)) { 1134 | !$check_type and push @{$out->{eq}}, $gold; 1135 | push @{$out->{ok}}, $a; 1136 | delete $ARGS{$s}{$e} 1137 | } 1138 | else { 1139 | push @{$out->{op}}, $a; 1140 | } 1141 | } 1142 | } 1143 | 1144 | my ($s); 1145 | foreach $s ( keys %ARGS ) { 1146 | foreach $a ( values %{$ARGS{$s}} ) { 1147 | push @{$out->{ms}}, $a; 1148 | } 1149 | } 1150 | 1151 | return $out; 1152 | } 1153 | 1154 | 1155 | # Generates a Start-End tagging for the prop arguments 1156 | # Expects the prop object, and l=length of the sentence 1157 | # Returns a list of l tags 1158 | sub to_SE_tagging { 1159 | my $prop = shift; 1160 | my $l = shift; 1161 | my @tags = (); 1162 | 1163 | my ($a, $p); 1164 | foreach $a ( $prop->args ) { 1165 | my $t = $a->type; 1166 | my $cont = 0; 1167 | foreach $p ( $a->single ? $a : $a->phrases ) { 1168 | if (defined($tags[$p->start])) { 1169 | die "prop->to_SE_tagging: Already defined tag in position ", $p->start, "! Prop phrases overlap or embed!\n"; 1170 | } 1171 | if ($p->start != $p->end) { 1172 | $tags[$p->start] = sprintf("%7s", "(".$t)."* "; 1173 | if (defined($tags[$p->end])) { 1174 | die "prop->to_SE_tagging: Already defined tag in position ", $p->end, "! Prop phrases overlap or embed!\n"; 1175 | } 1176 | # $tags[$p->end] = " *".sprintf("%-7s", $t.")"); 1177 | $tags[$p->end] = " *".sprintf("%-3s", ")"); 1178 | } 1179 | else { 1180 | # $tags[$p->start] = sprintf("%7s", "(".$t)."*".sprintf("%-7s", $t.")"); 1181 | $tags[$p->start] = sprintf("%7s", "(".$t)."*".sprintf("%-3s",")"); 1182 | } 1183 | 1184 | if (!$cont) { 1185 | $cont = 1; 1186 | $t = "C-".$t; 1187 | } 1188 | } 1189 | } 1190 | 1191 | my $i; 1192 | for ($i=0; $i<$l; $i++) { 1193 | if (!defined($tags[$i])) { 1194 | $tags[$i] = " * "; 1195 | } 1196 | } 1197 | 1198 | return @tags; 1199 | } 1200 | 1201 | 1202 | # generates a string representing the proposition 1203 | sub to_string { 1204 | my $p = shift; 1205 | 1206 | my $s = "[". $p->verb . "@" . $p->position . ": "; 1207 | $s .= join(" ", map { $_->to_string } $p->args); 1208 | $s .= " ]"; 1209 | 1210 | return $s; 1211 | } 1212 | 1213 | 1214 | 1; 1215 | 1216 | 1217 | ################################################################################ 1218 | # 1219 | # Package p h r a s e _ s e t 1220 | # 1221 | # A set of phrases 1222 | # Each phrase is indexed by (start,end) positions 1223 | # 1224 | # Holds non-overlapping phrase sets. 1225 | # Embedding of phrases allowed and exploited in class methods 1226 | # 1227 | # Brings useful functions on phrase sets, such as: 1228 | # - Load phrases from tag sequences in IOB1, IOB2, Start-End formats 1229 | # - Retrieve a phrase given its (start,end) positions 1230 | # - List phrases found within a given (s,e) segment 1231 | # - Discriminate a predicted set of phrases with respect to the gold set 1232 | # 1233 | ################################################################################ 1234 | 1235 | use strict; 1236 | 1237 | 1238 | package SRL::phrase_set; 1239 | 1240 | ## $phrase_types global variable 1241 | # If defined, contains a hash table specifying the phrase types to be considered 1242 | # If undefined, any phrase type is considered 1243 | my $phrase_types = undef; 1244 | sub set_phrase_types { 1245 | $phrase_types = {}; 1246 | my $t; 1247 | foreach $t ( @_ ) { 1248 | $phrase_types->{$t} = 1; 1249 | } 1250 | } 1251 | 1252 | # Constructor: creates a new phrase set 1253 | # Arguments: an initial set of phrases, which are added to the set 1254 | sub new { 1255 | my ($pkg, @P) = @_; 1256 | my $s = []; 1257 | @{$s->[0]} = (); # NxN half-matrix, storing phrases 1258 | $s->[1] = 0; # N (length of the sentence) 1259 | bless $s, $pkg; 1260 | 1261 | $s->add_phrases(@P); 1262 | 1263 | return $s; 1264 | } 1265 | 1266 | 1267 | # Adds phrases represented in IOB2 tagging 1268 | # Receives a list of IOB2 tags (one per word in the sentence) 1269 | # Creates a phrase object for each phrase in the taggging 1270 | # and modifies the set so that the phrases are part of it 1271 | sub load_IOB2_tagging { 1272 | my ($set, @tags) = @_; 1273 | 1274 | my $wid = 0; # word id 1275 | my $phrase = undef; # current phrase 1276 | my $t; 1277 | foreach $t (@tags) { 1278 | if ($phrase and $t !~ /^I/) { 1279 | $phrase->set_end($wid-1); 1280 | $set->add_phrases($phrase); 1281 | $phrase = undef; 1282 | } 1283 | if ($t =~ /^B-/) { 1284 | my $type = $'; 1285 | if (!defined($phrase_types) or $phrase_types->{$type}) { 1286 | $phrase = SRL::phrase->new($wid); 1287 | $phrase->set_type($type); 1288 | } 1289 | } 1290 | $wid++; 1291 | } 1292 | if ($phrase) { 1293 | $phrase->set_end($wid-1); 1294 | $set->add_phrases($phrase); 1295 | } 1296 | } 1297 | 1298 | 1299 | # Adds phrases represented in IOB1 tagging 1300 | # Receives a list of IOB1 tags (one per word in the sentence) 1301 | # Creates a phrase object for each phrase in the taggging 1302 | # and modifies the set so that the phrases are part of it 1303 | sub load_IOB1_tagging { 1304 | my ($set, @tags) = @_; 1305 | 1306 | my $wid = 0; # word id 1307 | my $phrase = undef; # current phrase 1308 | my $t = shift @tags; 1309 | while (defined($t)) { 1310 | if ($t =~ /^[BI]-/) { 1311 | my $type = $'; 1312 | if (!defined($phrase_types) or $phrase_types->{$type}) { 1313 | $phrase = SRL::phrase->new($wid); 1314 | $phrase->set_type($type); 1315 | my $tag = "I-".$type; 1316 | $t = shift @tags; 1317 | $wid++; 1318 | while ($t eq $tag) { 1319 | $t = shift @tags; 1320 | $wid++; 1321 | } 1322 | $phrase->set_end($wid-1); 1323 | $set->add_phrases($phrase); 1324 | } 1325 | else { 1326 | $t = shift @tags; 1327 | $wid++; 1328 | } 1329 | } 1330 | else { 1331 | $t = shift @tags; 1332 | $wid++; 1333 | } 1334 | } 1335 | } 1336 | 1337 | # Adds phrases represented in Start-End tagging 1338 | # Receives a list of Start-End tags (one per word in the sentence) 1339 | # Creates a phrase object for each phrase in the taggging 1340 | # and modifies the set so that the phrases are part of it 1341 | sub load_SE_tagging { 1342 | my ($set, @tags) = @_; 1343 | 1344 | my (@SP); # started phrases 1345 | my $wid = 0; 1346 | my ($tag, $p); 1347 | foreach $tag ( @tags ) { 1348 | while ($tag !~ /^\*/) { 1349 | $tag =~ /^\(((\\\*|[^*(])+)/ or die "phrase_set->load_SE_tagging: opening nodes -- bad format in $tag at $wid-th position!\n"; 1350 | my $type = $1; 1351 | $tag = $'; 1352 | if (!defined($phrase_types) or $phrase_types->{$type}) { 1353 | $p = SRL::phrase->new($wid); 1354 | $p->set_type($type); 1355 | push @SP, $p; 1356 | } 1357 | } 1358 | $tag =~ s/^\*//; 1359 | while ($tag ne "") { 1360 | $tag =~ /^([^\)]*)\)/ or die "phrase_set->load_SE_tagging: closing phrases -- bad format in $tag!\n"; 1361 | my $type = $1; 1362 | $tag = $'; 1363 | if (!$type or !defined($phrase_types) or $phrase_types->{$type}) { 1364 | $p = pop @SP; 1365 | (!$type) or ($type eq $p->type) or die "phrase_set->load_SE_tagging: types do not match!\n"; 1366 | $p->set_end($wid); 1367 | 1368 | if (@SP) { 1369 | $SP[$#SP]->add_phrases($p); 1370 | } 1371 | else { 1372 | $set->add_phrases($p); 1373 | } 1374 | } 1375 | } 1376 | $wid++; 1377 | } 1378 | (!@SP) or die "phrase_set->load_SE_tagging: some phrases are unclosed!\n"; 1379 | } 1380 | 1381 | 1382 | sub refs_start_end_tags { 1383 | my ($s, $l) = @_; 1384 | 1385 | my (@S,@E,$i); 1386 | for ($i=0; $i<$l; $i++) { 1387 | $S[$i] = ""; 1388 | $E[$i] = ""; 1389 | } 1390 | 1391 | my $p; 1392 | foreach $p ( $s->phrases ) { 1393 | $S[$p->start] .= "(".$p->type; 1394 | # $E[$p->end] = $E[$p->end].$p->type.")"; 1395 | $E[$p->end] .= ")"; 1396 | } 1397 | 1398 | return (\@S,\@E); 1399 | } 1400 | 1401 | 1402 | sub to_SE_tagging { 1403 | my ($s, $l) = @_; 1404 | 1405 | # my (@S,@E,$i); 1406 | # for ($i=0; $i<$l; $i++) { 1407 | # $S[$i] = ""; 1408 | # $E[$i] = ""; 1409 | # } 1410 | 1411 | # my $p; 1412 | # foreach $p ( $s->phrases ) { 1413 | # $S[$p->start] .= "(".$p->type; 1414 | # # $E[$p->end] = $E[$p->end].$p->type.")"; 1415 | # $E[$p->end] .= ")"; 1416 | # } 1417 | 1418 | my ($S,$E) = refs_start_end_tags($s,$l); 1419 | 1420 | my $i; 1421 | my @tags; 1422 | for ($i=0; $i<$l; $i++) { 1423 | # $tags[$i] = sprintf("%8s*%-12s", $S->[$i], $E->[$i]); 1424 | $tags[$i] = sprintf("%8s*%-5s", $S->[$i], $E->[$i]); 1425 | } 1426 | return @tags; 1427 | } 1428 | 1429 | 1430 | sub to_IOB2_tagging { 1431 | my ($s, $l) = @_; 1432 | 1433 | my (@tags,$p,$i); 1434 | 1435 | foreach $p ( $s->phrases ) { 1436 | my $tag = $p->type; 1437 | $i = $p->start; 1438 | $tags[$i] and $tags[$i] .= "/"; 1439 | $tags[$i] .= "B-".$tag; 1440 | $i++; 1441 | while ($i<=$p->end) { 1442 | $tags[$i] and $tags[$i] .= "/"; 1443 | $tags[$i] .= "I-".$tag; 1444 | $i++; 1445 | } 1446 | } 1447 | for ($i=0; $i<$l; $i++) { 1448 | if (!defined($tags[$i])) { 1449 | $tags[$i] = "O "; 1450 | } 1451 | else { 1452 | $tags[$i] = sprintf("%-6s", $tags[$i]); 1453 | } 1454 | } 1455 | return @tags; 1456 | } 1457 | 1458 | 1459 | # ------------------------------------------------------------ 1460 | 1461 | # Adds phrases in the set, recursively (ie. internal phrases are also added) 1462 | sub add_phrases { 1463 | my ($s, @P) = @_; 1464 | my $ph; 1465 | foreach $ph ( map { $_->dfs } @P ) { 1466 | $s->[0][$ph->start][$ph->end] = $ph; 1467 | if ($ph->end >= $s->[1]) { 1468 | $s->[1] = $ph->end +1; 1469 | } 1470 | } 1471 | } 1472 | 1473 | # returns the number of phrases in the set 1474 | sub size { 1475 | my $set = shift; 1476 | 1477 | my ($i,$j); 1478 | my $n; 1479 | for ($i=0; $i<@{$set->[0]}; $i++) { 1480 | if (defined($set->[0][$i])) { 1481 | for ($j=$i; $j<@{$set->[0][$i]}; $j++) { 1482 | if (defined($set->[0][$i][$j])) { 1483 | $n++; 1484 | } 1485 | } 1486 | } 1487 | } 1488 | return $n; 1489 | } 1490 | 1491 | # returns the phrase starting at word position $s and ending at $e 1492 | # or undef if it doesn't exist 1493 | sub phrase { 1494 | my ($set, $s, $e) = @_; 1495 | return $set->[0][$s][$e]; 1496 | } 1497 | 1498 | 1499 | # Returns phrases in the set, recursively in depth first search order 1500 | # that is, if a phrase is returned, all its subphrases are also returned 1501 | # If no parameters, returns all phrases 1502 | # If a pair of positions is given ($s,$e), returns phrases included 1503 | # within the $s and $e positions 1504 | sub phrases { 1505 | my $set = shift; 1506 | my ($s, $e); 1507 | if (!@_) { 1508 | $s = 0; 1509 | $e = $set->[1]-1; 1510 | } 1511 | else { 1512 | ($s,$e) = @_; 1513 | } 1514 | my ($i,$j); 1515 | my @P = (); 1516 | for ($i=$s;$i<=$e;$i++) { 1517 | if (defined($set->[0][$i])) { 1518 | for ($j=$e;$j>=$i;$j--) { 1519 | if (defined($set->[0][$i][$j])) { 1520 | push @P, $set->[0][$i][$j]; 1521 | } 1522 | } 1523 | } 1524 | } 1525 | return @P; 1526 | } 1527 | 1528 | 1529 | # Returns phrases in the set, non-recursively in sequential order 1530 | # that is, if a phrase is returned, its subphrases are not returned 1531 | # If no parameters, returns all phrases 1532 | # If a pair of positions is given ($s,$e), returns phrases included 1533 | # within the $s and $e positions 1534 | sub top_phrases { 1535 | my $set = shift; 1536 | my ($s, $e); 1537 | if (!@_) { 1538 | $s = 0; 1539 | $e = $set->[1]-1; 1540 | } 1541 | else { 1542 | ($s,$e) = @_; 1543 | } 1544 | my ($i,$j); 1545 | my @P = (); 1546 | $i = $s; 1547 | while ($i<=$e) { 1548 | $j=$e; 1549 | while ($j>=$s) { 1550 | if (defined($set->[0][$i][$j])) { 1551 | push @P, $set->[0][$i][$j]; 1552 | $i=$j; 1553 | $j=-1; 1554 | } 1555 | else { 1556 | $j--; 1557 | } 1558 | } 1559 | $i++; 1560 | } 1561 | return @P; 1562 | } 1563 | 1564 | 1565 | # returns the phrases which contain the terminal $wid, in bottom-up order 1566 | sub ancestors { 1567 | my ($set, $wid) = @_; 1568 | 1569 | my @A; 1570 | my $N = $set->[1]; 1571 | 1572 | my ($s,$e); 1573 | 1574 | for ($s = $wid; $s>=0; $s--) { 1575 | if (defined($set->[0][$s])) { 1576 | for ($e = $wid; $e<$N; $e++) { 1577 | if (defined($set->[0][$s][$e])) { 1578 | push @A, $set->[0][$s][$e]; 1579 | } 1580 | } 1581 | } 1582 | } 1583 | 1584 | return @A; 1585 | } 1586 | 1587 | 1588 | # returns a TRUE value if the phrase $p ovelaps with some phrase in 1589 | # the set; the returned value is the reference to the conflicting phrase 1590 | # returns FALSE otherwise 1591 | sub check_overlapping { 1592 | my ($set, $p) = @_; 1593 | 1594 | my ($s,$e); 1595 | for ($s=0; $s<$p->start; $s++) { 1596 | if (defined($set->[0][$s])) { 1597 | for ($e=$p->start; $e<$p->end; $e++) { 1598 | if (defined($set->[0][$s][$e])) { 1599 | return $set->[0][$s][$e]; 1600 | } 1601 | } 1602 | } 1603 | } 1604 | for ($s=$p->start+1; $s<=$p->end; $s++) { 1605 | if (defined($set->[0][$s])) { 1606 | for ($e=$p->end+1; $e<$set->[1]; $e++) { 1607 | if (defined($set->[0][$s][$e])) { 1608 | return $set->[0][$s][$e]; 1609 | } 1610 | } 1611 | } 1612 | } 1613 | 1614 | return 0; 1615 | } 1616 | 1617 | 1618 | ## ---------------------------------------- 1619 | 1620 | # Discriminates a set of phrases (s1) wrt the current set (s0), returning 1621 | # intersection (s0^s1), over-predicted (s1-s0) and missed (s0-s1) 1622 | # Returns a hash reference containing three lists: 1623 | # $out->{ok} : phrases in $s0 and $1 1624 | # $out->{op} : phrases in $s1 and not in $0 1625 | # $out->{ms} : phrases in $s0 and not in $1 1626 | sub discriminate { 1627 | my ($s0, $s1) = @_; 1628 | 1629 | my $out; 1630 | @{$out->{ok}} = (); 1631 | @{$out->{ms}} = (); 1632 | @{$out->{op}} = (); 1633 | 1634 | my $ph; 1635 | my %ok; 1636 | 1637 | foreach $ph ($s1->phrases) { 1638 | my $s = $ph->start; 1639 | my $e = $ph->end; 1640 | 1641 | my $gph = $s0->phrase($s,$e); 1642 | if ($gph and $gph->type eq $ph->type) { 1643 | # correct 1644 | $ok{$s}{$e} = 1; 1645 | push @{$out->{ok}}, $ph; 1646 | } 1647 | else { 1648 | # overpredicted 1649 | push @{$out->{op}}, $ph; 1650 | } 1651 | } 1652 | 1653 | foreach $ph ($s0->phrases) { 1654 | my $s = $ph->start; 1655 | my $e = $ph->end; 1656 | 1657 | if (!$ok{$s}{$e}) { 1658 | # missed 1659 | push @{$out->{ms}}, $ph; 1660 | } 1661 | } 1662 | return $out; 1663 | } 1664 | 1665 | 1666 | # compares the current set (s0) to another set (s1) 1667 | # returns the number of correct, missed an over-predicted phrases 1668 | sub evaluation { 1669 | my ($s0, $s1) = @_; 1670 | 1671 | my $o = $s0->discriminate($s1); 1672 | 1673 | my %e; 1674 | $e{ok} = scalar(@{$o->{ok}}); 1675 | $e{op} = scalar(@{$o->{op}}); 1676 | $e{ms} = scalar(@{$o->{ms}}); 1677 | 1678 | return %e; 1679 | } 1680 | 1681 | 1682 | # generates a string representing the phrase set, 1683 | # for printing purposes 1684 | sub to_string { 1685 | my $s = shift; 1686 | return join(" ", map { $_->to_string } $s->top_phrases); 1687 | } 1688 | 1689 | 1690 | 1; 1691 | 1692 | 1693 | 1694 | 1695 | 1696 | 1697 | 1698 | 1699 | 1700 | 1701 | 1702 | 1703 | 1704 | 1705 | 1706 | 1707 | ################################################################## 1708 | # 1709 | # Package p h r a s e : a generic phrase 1710 | # 1711 | # January 2004 1712 | # 1713 | # This class represents generic phrases. 1714 | # A phrase is a sequence of contiguous words in a sentence. 1715 | # A phrase is identified by the positions of the start/end words 1716 | # of the sequence that the phrase spans. 1717 | # A phrase has a type. 1718 | # A phrase may contain a list of internal subphrases, that is, 1719 | # phrases found within the phrase. Thus, a phrase object is seen 1720 | # eventually as a hierarchical structure. 1721 | # 1722 | # A syntactic base chunk is a phrase with no internal phrases. 1723 | # A clause is a phrase which may have internal phrases 1724 | # A proposition argument is implemented as a special class which 1725 | # inherits from the phrase class. 1726 | # 1727 | ################################################################## 1728 | 1729 | use strict; 1730 | 1731 | package SRL::phrase; 1732 | 1733 | # Constructor: creates a new phrase 1734 | # Parameters: start position, end position and type 1735 | sub new { 1736 | my $pkg = shift; 1737 | 1738 | my $ph = []; 1739 | 1740 | # start word index 1741 | $ph->[0] = (@_) ? shift : undef; 1742 | # end word index 1743 | $ph->[1] = (@_) ? shift : undef; 1744 | # phrase type 1745 | $ph->[2] = (@_) ? shift : undef; 1746 | # 1747 | @{$ph->[3]} = (); 1748 | 1749 | return bless $ph, $pkg; 1750 | } 1751 | 1752 | # returns the start position of the phrase 1753 | sub start { 1754 | my $ph = shift; 1755 | return $ph->[0]; 1756 | } 1757 | 1758 | # initializes the start position of the phrase 1759 | sub set_start { 1760 | my $ph = shift; 1761 | $ph->[0] = shift; 1762 | } 1763 | 1764 | # returns the end position of the phrase 1765 | sub end { 1766 | my $ph = shift; 1767 | return $ph->[1]; 1768 | } 1769 | 1770 | # initializes the end position of the phrase 1771 | sub set_end { 1772 | my $ph = shift; 1773 | $ph->[1] = shift; 1774 | } 1775 | 1776 | # returns the type of the phrase 1777 | sub type { 1778 | my $ph = shift; 1779 | return $ph->[2]; 1780 | } 1781 | 1782 | # initializes the type of the phrase 1783 | sub set_type { 1784 | my $ph = shift; 1785 | $ph->[2] = shift; 1786 | } 1787 | 1788 | # returns the subphrases of the current phrase 1789 | sub phrases { 1790 | my $ph = shift; 1791 | return @{$ph->[3]}; 1792 | } 1793 | 1794 | # adds phrases as subphrases 1795 | sub add_phrases { 1796 | my $ph = shift; 1797 | push @{$ph->[3]}, @_; 1798 | } 1799 | 1800 | # initializes the set of subphrases 1801 | sub set_phrases { 1802 | my $ph = shift; 1803 | @{$ph->[3]} = @_; 1804 | } 1805 | 1806 | 1807 | # depth first search 1808 | # returns the phrases rooted int the current phrase in dfs order 1809 | sub dfs { 1810 | my $ph = shift; 1811 | return ($ph, map { $_->dfs } $ph->phrases); 1812 | } 1813 | 1814 | 1815 | # generates a string representing the phrase (and subphrases if arg is a TRUE value), for printing 1816 | sub to_string { 1817 | my $ph = shift; 1818 | my $rec = ( @_ ) ? shift : 1; 1819 | 1820 | my $str = "(" . $ph->start . " "; 1821 | 1822 | $rec and map { $str .= $_->to_string." " } $ph->phrases; 1823 | 1824 | $str .= $ph->end . ")"; 1825 | if (defined($ph->type)) { 1826 | $str .= "_".$ph->type; 1827 | } 1828 | return $str; 1829 | } 1830 | 1831 | 1832 | 1; 1833 | 1834 | ################################################################## 1835 | # 1836 | # Package a r g : An argument 1837 | # 1838 | # January 2004 1839 | # 1840 | # This class inherits from the class "phrase". 1841 | # An argument is identified by start-end positions of the 1842 | # string spanned by the argument in the sentence. 1843 | # An argument has a type. 1844 | # 1845 | # Most of the arguments consist of a single phrase; in this 1846 | # case the argument and the phrase objects are the same. 1847 | # 1848 | # In the special case of discontinuous arguments, the argument 1849 | # is an "arg" object which contains a number of phrases (one 1850 | # for each discontinuous piece). Then, the argument spans from 1851 | # the start word of its first phrase to the end word of its last 1852 | # phrase. As for the composing phrases, the type of the first one 1853 | # is the type of the argument, say A, whereas the type of the 1854 | # subsequent phrases is "C-A" (continuation tag). 1855 | # 1856 | ################################################################## 1857 | 1858 | package SRL::arg; 1859 | 1860 | use strict; 1861 | 1862 | #push @SRL::arg::ISA, 'SRL::phrase'; 1863 | use base qw(SRL::phrase); 1864 | 1865 | 1866 | # Constructor "new" inherited from SRL::phrase 1867 | 1868 | # Checks whether the argument is single (returning true) 1869 | # or discontinuous (returning false) 1870 | sub single { 1871 | my ($a) = @_; 1872 | return scalar(@{$a->[3]}==0); 1873 | } 1874 | 1875 | # Generates a string representing the argument 1876 | sub to_string { 1877 | my $a = shift; 1878 | 1879 | my $s = $a->type."_(" . $a->start . " "; 1880 | map { $s .= $_->to_string." " } $a->phrases; 1881 | $s .= $a->end . ")"; 1882 | 1883 | return $s; 1884 | } 1885 | 1886 | 1887 | 1; 1888 | 1889 | 1890 | 1891 | 1892 | 1893 | 1894 | 1895 | 1896 | 1897 | ################################################################## 1898 | # 1899 | # Package w o r d : a word 1900 | # 1901 | # April 2004 1902 | # 1903 | # A word, containing id (position in sentence), form and PoS tag 1904 | # 1905 | ################################################################## 1906 | 1907 | use strict; 1908 | 1909 | package SRL::word; 1910 | 1911 | # Constructor: creates a new word 1912 | # Parameters: id (position), form and PoS tag 1913 | sub new { 1914 | my ($pkg, @fields) = @_; 1915 | 1916 | my $w = []; 1917 | 1918 | $w->[0] = shift @fields; # id (position in sentence) 1919 | $w->[1] = shift @fields; # form 1920 | $w->[2] = shift @fields; # PoS 1921 | 1922 | return bless $w, $pkg; 1923 | } 1924 | 1925 | # returns the id of the word 1926 | sub id { 1927 | my $w = shift; 1928 | return $w->[0]; 1929 | } 1930 | 1931 | # returns the form of the word 1932 | sub form { 1933 | my $w = shift; 1934 | return $w->[1]; 1935 | } 1936 | 1937 | # returns the PoS tag of the word 1938 | sub pos { 1939 | my $w = shift; 1940 | return $w->[2]; 1941 | } 1942 | 1943 | sub to_string { 1944 | my $w = shift; 1945 | return "w@".$w->[0].":".$w->[1].":".$w->[2]; 1946 | } 1947 | 1948 | 1; 1949 | 1950 | 1951 | 1952 | 1953 | 1954 | --------------------------------------------------------------------------------