├── LICENSE
├── README.md
├── datasets
│   └── FewFC
│       ├── cascading_sampled
│       │   ├── dev.json
│       │   ├── shared_args_list.json
│       │   ├── test.json
│       │   ├── train.json
│       │   └── ty_args.json
│       └── data
│           ├── dev.json
│           ├── test.json
│           └── train.json
├── logs
│   └── model.log
├── main.py
├── models
│   ├── layers.py
│   └── model.py
├── pre_cascading.py
└── utils
    ├── data_loader.py
    ├── framework.py
    ├── metric.py
    ├── params.py
    ├── predict_with_oracle.py
    ├── predict_without_oracle.py
    ├── utils_io_data.py
    └── utils_io_model.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Jiawei Sheng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CasEE
Source code for the Findings of ACL 2021 paper: [*CasEE: A Joint Learning Framework with Cascade Decoding for Overlapping Event Extraction*](https://aclanthology.org/2021.findings-acl.14/).

Event extraction (EE) is a crucial information extraction task that aims to extract event information from texts. This work studies the realistic event overlapping problem, where a word may serve as a trigger of several event types, or as an argument with different roles. To tackle this problem, this work proposes CasEE, a joint learning framework with cascade decoding for overlapping event extraction. Specifically, CasEE sequentially performs type detection, trigger extraction and argument extraction, where overlapped targets are extracted separately, each conditioned on a specific prediction from the former subtask. All subtasks are learned jointly in one framework to capture the dependencies among them. The evaluation demonstrates that CasEE achieves significant improvements on overlapping event extraction over previous competitive methods.

# Requirements

We conducted our experiments in the following environment:

```
python 3.6
CUDA: 9.0
GPU: Tesla T4
pytorch == 1.1.0
transformers == 4.9.1
```

# Datasets

Since the ACE 2005 dataset contains few overlapping events, we adopt the Chinese Financial Event Extraction dataset (FewFC) as our evaluation dataset.
The original dataset can be accessed at [this repo](https://github.com/TimeBurningFish/FewFC).
Here we re-split the train/dev/test data, since the original literature uses different experimental settings.
Note that the re-split data is available at ``./datasets/FewFC/data``, and we adjusted its format to simplify the data loader.
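
To illustrate the data format, here is an unofficial, simplified sketch of the expansion that `pre_cascading.py` performs. It is not the repository's code. The sample record in `pre_cascading.py` suggests the following layout:

- Each line of `data/*.json` holds an `id`, the raw `content` and a list of `events`.
- Spans appear to be `[start, end)` character offsets into `content`.

The sketch first checks that every span slices out its annotated `word`. It then emits one instance per (event type, trigger) pair and merges the arguments of events that share a trigger. The record and its id `demo-0001` are made up. Type order follows first appearance rather than the fixed `TYPES` list used in `pre_cascading.py`.

```python
# Unofficial sketch of cascade sampling (simplified from what pre_cascading.py does).
# Assumption: spans are [start, end) character offsets; the record below is hypothetical.

record = {
    "id": "demo-0001",  # hypothetical id
    "content": "隆鑫控股向中信证券和国民信托质押股票",
    "events": [
        {"type": "质押", "trigger": {"span": [14, 16], "word": "质押"},
         "args": {"sub-org": [{"span": [0, 4], "word": "隆鑫控股"}],
                  "obj-org": [{"span": [5, 9], "word": "中信证券"}],
                  "collateral": [{"span": [16, 18], "word": "股票"}]}},
        {"type": "质押", "trigger": {"span": [14, 16], "word": "质押"},
         "args": {"sub-org": [{"span": [0, 4], "word": "隆鑫控股"}],
                  "obj-org": [{"span": [10, 14], "word": "国民信托"}],
                  "collateral": [{"span": [16, 18], "word": "股票"}]}},
    ],
}


def check_spans(rec):
    """Assert that every trigger/argument span slices out its annotated word."""
    text = rec["content"]
    for event in rec["events"]:
        items = [event["trigger"]] + [a for args in event["args"].values() for a in args]
        for item in items:
            start, end = item["span"]
            assert text[start:end] == item["word"], (item, text[start:end])


def cascade_sample(rec):
    """Emit one instance per (event type, trigger), merging args of events sharing a trigger."""
    events = rec["events"]
    types = []
    for event in events:  # event types in order of first appearance
        if event["type"] not in types:
            types.append(event["type"])
    instances = []
    for ty in types:
        triggers = []      # unique trigger spans of this type, in order of appearance
        trigger_args = {}  # str(trigger span) -> {role: [argument spans]}
        for event in (e for e in events if e["type"] == ty):
            span = event["trigger"]["span"]
            if span not in triggers:
                triggers.append(span)
            roles = trigger_args.setdefault(str(span), {})
            for role, args in event["args"].items():
                role_spans = roles.setdefault(role, [])
                for arg in args:
                    if arg["span"] not in role_spans:  # drop duplicate argument spans
                        role_spans.append(arg["span"])
        for index, span in enumerate(triggers):
            instances.append({
                "id": rec["id"], "content": rec["content"], "occur": types,
                "type": ty, "triggers": triggers, "index": index,
                "args": trigger_args[str(span)],
            })
    return instances


check_spans(record)
for inst in cascade_sample(record):
    print(inst["type"], inst["index"], inst["args"])
```

Running it prints a single 质押 (pledge) instance whose `obj-org` list holds both spans, `[[5, 9], [10, 14]]`. The two overlapping events therefore collapse into one trigger-conditioned training instance.
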
To run the code on another dataset, convert that data into the same format as the files in ``./datasets/FewFC/data``.

# How to run

To run the code, execute the following steps in order:

1. Data preprocessing: generate the cascading-sampled training data, which implements the cascading learning strategy of the framework (this data has already been generated at ``./datasets/FewFC/cascading_sampled``):

```
python pre_cascading.py
```

2. Train/Dev/Test the model: run the following command to train the model, evaluate it on the dev set and test it:

```
CUDA_VISIBLE_DEVICES=0 nohup python -u main.py --output_model_path ./models_save/model.bin --do_train True --do_eval True --do_test True > logs/model.log &
```

The hyper-parameters are recorded in ``./utils/params.py``.
We adopt ``bert-base-chinese`` as our pretrained language model. As an extension, you could also tune the hyper-parameters further for potentially better performance.

# Citation

If you find this code useful, please cite our work:

```
@inproceedings{Sheng2021:CasEE,
    title = "{C}as{EE}: {A} Joint Learning Framework with Cascade Decoding for Overlapping Event Extraction",
    author = "Sheng, Jiawei and
      Guo, Shu and
      Yu, Bowen and
      Li, Qian and
      Hei, Yiming and
      Wang, Lihong and
      Liu, Tingwen and
      Xu, Hongbo",
    booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.findings-acl.14",
    doi = "10.18653/v1/2021.findings-acl.14",
    pages = "164--174",
}
```

--------------------------------------------------------------------------------
/datasets/FewFC/cascading_sampled/shared_args_list.json:
--------------------------------------------------------------------------------
["collateral", "obj-per", "sub-per", "sub-org", "share-per", "title", "way", "money", "obj-org", "number", "amount", "proportion", "target-company", "date", "sub", "share-org", "obj", "institution"]
--------------------------------------------------------------------------------
/datasets/FewFC/cascading_sampled/ty_args.json:
--------------------------------------------------------------------------------
{"质押": ["number", "obj-per", "collateral", "date", "sub-per", "obj-org", "proportion", "sub-org", "money"], "股份股权转让": ["number", "target-company", "obj-per", "date", "collateral", "sub-per", "obj-org", "proportion", "sub-org", "money"], "投资": ["money", "date", "sub", "obj"], "减持": ["share-per", "share-org", "obj", "date", "sub", "title"], "起诉": ["obj-per", "date", "sub-per", "obj-org", "sub-org"], "收购": ["number", "date", "sub-per", "obj-org", "proportion", "sub-org", "money", "way"], "判决": ["obj-per", "date", "sub-per", "obj-org", "sub-org", "institution", "money"], "签署合同": ["obj-per", "date", "amount", "sub-per", "obj-org", "sub-org"], "担保": ["date", "amount", "obj-org", "sub-per", "sub-org", "way"], "中标": ["amount", "date", "sub", "obj"]}

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from tqdm import tqdm

from torch.utils.data import DataLoader

from utils.params import parse_args
from models.model import CasEE
from sklearn.metrics import *
import transformers
from transformers import *
from utils.framework import Framework
from utils.data_loader import get_dict, collate_fn_dev, collate_fn_train, collate_fn_test, Data
import torch
import os
from utils.metric import gen_idx_event_dict
from utils.utils_io_data import read_jsonl, write_jsonl

MODEL_CLASSES = {'bert': (BertConfig, BertModel, BertTokenizer), 'albert-zh': (AlbertConfig, AlbertModel, BertTokenizer), 'auto': (AutoConfig, AutoModel, AutoTokenizer)}


def main():
    if not os.path.exists('plm'):
        os.makedirs('plm')
    if not os.path.exists('models_save'):
        os.makedirs('models_save')
    if not os.path.exists('logs'):
        os.makedirs('logs')

    config = parse_args()
    config.type_id, config.id_type, config.args_id, config.id_args, config.ty_args, config.ty_args_id, config.args_s_id, config.args_e_id = get_dict(config.data_path)

    config.args_num = len(config.args_s_id.keys())
    config.type_num = len(config.type_id.keys())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.device = device
    config.model_type = 'bert'

    config_class, model_class, tokenizer_class = MODEL_CLASSES[config.model_type]
    config_plm = config_class.from_pretrained(config.model_name_or_path, cache_dir=config.cache_dir if config.cache_dir else None)
    config.hidden_size = config_plm.hidden_size
    tokenizer = tokenizer_class.from_pretrained(config.model_name_or_path, do_lower_case=config.do_lower_case, cache_dir=config.cache_dir if config.cache_dir else None)
    model_weight = model_class.from_pretrained(config.model_name_or_path, from_tf=bool('.ckpt' in config.model_name_or_path), cache_dir=config.cache_dir if config.cache_dir else None)

    model = CasEE(config, model_weight, pos_emb_size=config.rp_size)
    framework = Framework(config, model)

    if config.do_train:
        train_set = Data(task='train', fn=config.data_path + '/cascading_sampled/train.json', tokenizer=tokenizer, seq_len=config.seq_length, args_s_id=config.args_s_id, args_e_id=config.args_e_id, type_id=config.type_id)
        train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn_train)
        dev_set = Data(task='eval_with_oracle', fn=config.data_path + '/cascading_sampled/dev.json', tokenizer=tokenizer, seq_len=config.seq_length, args_s_id=config.args_s_id, args_e_id=config.args_e_id, type_id=config.type_id)
        dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, collate_fn=collate_fn_dev)
        framework.train(train_loader, dev_loader)

    if config.do_eval:
        framework.load_model(config.output_model_path)
        print("Dev set evaluation with oracle.")
        dev_set = Data(task='eval_with_oracle', fn=config.data_path + '/cascading_sampled/dev.json', tokenizer=tokenizer, seq_len=config.seq_length, args_s_id=config.args_s_id, args_e_id=config.args_e_id, type_id=config.type_id)
        dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, collate_fn=collate_fn_dev)
        c_ps, c_rs, c_fs, t_ps, t_rs, t_fs, a_ps, a_rs, a_fs = framework.evaluate_with_oracle(config, model, dev_loader, config.device, config.ty_args_id, config.id_type)
        f1_mean_all = (c_fs + t_fs + a_fs) / 3
        print('Evaluate on all types:')
        print("Type P: {:.3f}, Type R: {:.3f}, Type F: {:.3f}".format(c_ps, c_rs, c_fs))
        print("Trigger P: {:.3f}, Trigger R: {:.3f}, Trigger F: {:.3f}".format(t_ps, t_rs, t_fs))
        print("Args P: {:.3f}, Args R: {:.3f}, Args F: {:.3f}".format(a_ps, a_rs, a_fs))
        print("F1 Mean All: {:.3f}".format(f1_mean_all))

    if config.do_test:
        if config.batch_size != 1:
            print('For simplicity, reset batch_size=1 to extract each sentence')
            config.batch_size = 1
        framework.load_model(config.output_model_path)
        # Evaluation on test set with oracle: each subtask is given the ground-truth results of the former subtasks.
        print("Test set evaluation with oracle.")
        dev_set = Data(task='eval_with_oracle', fn=config.data_path + '/cascading_sampled/test.json', tokenizer=tokenizer, seq_len=config.seq_length, args_s_id=config.args_s_id, args_e_id=config.args_e_id, type_id=config.type_id)
        dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, collate_fn=collate_fn_dev)
        c_ps, c_rs, c_fs, t_ps, t_rs, t_fs, a_ps, a_rs, a_fs = framework.evaluate_with_oracle(config, model, dev_loader, config.device, config.ty_args_id, config.id_type)
        f1_mean_all = (c_fs + t_fs + a_fs) / 3
        print('Evaluate on all types:')
        print("Type P: {:.3f}, Type R: {:.3f}, Type F: {:.3f}".format(c_ps, c_rs, c_fs))
        print("Trigger P: {:.3f}, Trigger R: {:.3f}, Trigger F: {:.3f}".format(t_ps, t_rs, t_fs))
        print("Args P: {:.3f}, Args R: {:.3f}, Args F: {:.3f}".format(a_ps, a_rs, a_fs))
        print("F1 Mean All: {:.3f}".format(f1_mean_all))

        # Evaluation on test set without oracle: each subtask is conditioned on the model's own predictions for the former subtasks.
        print("Test set evaluation.")
        dev_set = Data(task='eval_without_oracle', fn=config.test_path, tokenizer=tokenizer, seq_len=config.seq_length, args_s_id=config.args_s_id, args_e_id=config.args_e_id, type_id=config.type_id)
        dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, collate_fn=collate_fn_test)
        print("The number of testing instances:", len(dev_set))
        prf_s, pred_records = framework.evaluate_without_oracle(config, model, dev_loader, config.device, config.seq_length, config.id_type, config.id_args, config.ty_args_id)
        metric_names = ['TI', 'TC', 'AI', 'AC']
        for i, prf in enumerate(prf_s):
            print('{}: P:{:.1f}, R:{:.1f}, F:{:.1f}'.format(metric_names[i], prf[0] * 100, prf[1] * 100, prf[2] * 100))

        write_jsonl(pred_records, config.output_result_path)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
from torch import nn
import math
import torch


def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class ConditionalLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super(ConditionalLayerNorm, self).__init__()
        self.eps = eps
        self.gamma_dense = nn.Linear(hidden_size, hidden_size, bias=False)
        self.beta_dense = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

        nn.init.zeros_(self.gamma_dense.weight)
        nn.init.zeros_(self.beta_dense.weight)

    def forward(self, x, condition):
        '''

        :param x: [b, t, e]
        :param condition: [b, e]
        :return: [b, t, e]
        '''
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)

        condition = condition.unsqueeze(1).expand_as(x)
        gamma = self.gamma_dense(condition) + self.gamma
        beta = self.beta_dense(condition) + self.beta
        x = gamma * (x - mean) / (std + self.eps) + beta
        return x


class AdaptiveAdditionPredictor(nn.Module):
    def __init__(self, hidden_size, dropout_rate=0.0):
        super(AdaptiveAdditionPredictor, self).__init__()
        self.v = nn.Linear(hidden_size * 4, 1)
        self.hidden = nn.Linear(hidden_size * 4, hidden_size * 4)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query, context, mask):
        '''
        :param query: [c, e]
        :param context: [b, t, e]
        :param mask: [b, t], 0 if masked
        :return: [b, c]
        '''

        context_ = context.unsqueeze(1).expand(context.size(0), query.size(0), context.size(1), context.size(2))  # [b, c, t, e]
        query_ = query.unsqueeze(0).unsqueeze(2).expand_as(context_)  # [b, c, t, e]

        scores = self.v(torch.tanh(self.hidden(torch.cat([query_, context_, torch.abs(query_ - context_), query_ * context_], dim=-1))))  # [b, c, t, 1]
        scores = self.dropout(scores)
        mask = (mask < 1).unsqueeze(1).unsqueeze(3).expand_as(scores)  # [b, c, t, 1]
        scores = scores.masked_fill_(mask, -1e10)
        scores = scores.transpose(-1, -2)  # [b, c, 1, t]
        scores = torch.softmax(scores, dim=-1)  # [b, c, 1, t]
        g = torch.matmul(scores, context_).squeeze(2)  # [b, c, e]
        query = query.unsqueeze(0).expand_as(g)  # [b, c, e]

        pred = self.v(torch.tanh(self.hidden(torch.cat([query, g, torch.abs(query - g), query * g], dim=-1)))).squeeze(-1)  # [b, c]
        return pred


class MultiHeadedAttention(nn.Module):
    """
    Each head is a self-attention operation.
    self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
    """
    def __init__(self, hidden_size, heads_num, dropout):
        super(MultiHeadedAttention, self).__init__()
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.per_head_size = hidden_size // heads_num

        self.linear_layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(3)])

        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, key, value, query, mask):
        """
        Args:
            key: [batch_size x seq_length x hidden_size]
            value: [batch_size x seq_length x hidden_size]
            query: [batch_size x seq_length x hidden_size]
            mask: [batch_size x seq_length]
                mask is 0 if it is masked

        Returns:
            output: [batch_size x seq_length x hidden_size]
        """
        batch_size, seq_length, hidden_size = key.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        def shape(x):
            return x. \
                contiguous(). \
                view(batch_size, seq_length, heads_num, per_head_size). \
                transpose(1, 2)

        def unshape(x):
            return x. \
                transpose(1, 2). \
                contiguous(). \
                view(batch_size, seq_length, hidden_size)

        query, key, value = [l(x).view(batch_size, -1, heads_num, per_head_size).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]

        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        mask = mask. \
            unsqueeze(1). \
            repeat(1, seq_length, 1). \
            unsqueeze(1)
        mask = mask.float()
        mask = (1.0 - mask) * -10000.0
        scores = scores + mask
        probs = nn.Softmax(dim=-1)(scores)
        probs = self.dropout(probs)
        output = unshape(torch.matmul(probs, value))
        output = self.final_linear(output)
        return output

--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
from models.layers import *


class TypeCls(nn.Module):
    def __init__(self, config):
        super(TypeCls, self).__init__()
        self.type_emb = nn.Embedding(config.type_num, config.hidden_size)
        self.register_buffer('type_indices', torch.arange(0, config.type_num, 1).long())
        self.dropout = nn.Dropout(config.decoder_dropout)

        self.config = config
        self.Predictor = AdaptiveAdditionPredictor(config.hidden_size, dropout_rate=config.decoder_dropout)

    def forward(self, text_rep, mask):
        type_emb = self.type_emb(self.type_indices)
        pred = self.Predictor(type_emb, text_rep, mask)  # [b, c]
        p_type = torch.sigmoid(pred)
        return p_type, type_emb


class TriggerRec(nn.Module):
    def __init__(self, config, hidden_size):
        super(TriggerRec, self).__init__()
        self.ConditionIntegrator = ConditionalLayerNorm(hidden_size)
        self.SA = MultiHeadedAttention(hidden_size, heads_num=config.decoder_num_head, dropout=config.decoder_dropout)

        self.hidden = nn.Linear(hidden_size, hidden_size)
        self.head_cls = nn.Linear(hidden_size, 1, bias=True)
        self.tail_cls = nn.Linear(hidden_size, 1, bias=True)

        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(config.decoder_dropout)
        self.config = config

    def forward(self, query_emb, text_emb, mask):
        '''

        :param query_emb: [b, e]
        :param text_emb: [b, t, e]
        :param mask: 0 if masked
        :return: [b, t, 1], [b, t, 1], [b, t, e]
        '''

        h_cln = self.ConditionIntegrator(text_emb, query_emb)

        h_cln = self.dropout(h_cln)
        h_sa = self.SA(h_cln, h_cln, h_cln, mask)
        h_sa = self.dropout(h_sa)
        inp = self.layer_norm(h_sa + h_cln)
        inp = gelu(self.hidden(inp))
        inp = self.dropout(inp)
        p_s = torch.sigmoid(self.head_cls(inp))  # [b, t, 1]
        p_e = torch.sigmoid(self.tail_cls(inp))  # [b, t, 1]
        return p_s, p_e, h_cln


class ArgsRec(nn.Module):
    def __init__(self, config, hidden_size, num_labels, seq_len, pos_emb_size):
        super(ArgsRec, self).__init__()
        self.relative_pos_embed = nn.Embedding(seq_len * 2, pos_emb_size)
        self.ConditionIntegrator = ConditionalLayerNorm(hidden_size)
        self.SA = MultiHeadedAttention(hidden_size, heads_num=config.decoder_num_head, dropout=config.decoder_dropout)
        self.hidden = nn.Linear(hidden_size + pos_emb_size, hidden_size)

        self.head_cls = nn.Linear(hidden_size, num_labels, bias=True)
        self.tail_cls = nn.Linear(hidden_size, num_labels, bias=True)

        self.gate_hidden = nn.Linear(hidden_size, hidden_size)
        self.gate_linear = nn.Linear(hidden_size, num_labels)

        self.seq_len = seq_len
        self.dropout = nn.Dropout(config.decoder_dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.config = config

    def forward(self, text_emb, relative_pos, trigger_mask, mask, type_emb):
        '''
        :param text_emb: [b, t, e]
        :param relative_pos: [b, t], relative-position indices
        :param trigger_mask: [b, t], 1 at the trigger's start and end positions
        :param mask: [b, t], 0 if masked
        :param type_emb: [b, e]
        :return: [b, t, l], [b, t, l], [b, t, l]
        '''
        # Sum of the start- and end-token representations (marked by trigger_mask), halved.
        trigger_emb = torch.bmm(trigger_mask.unsqueeze(1).float(), text_emb).squeeze(1)  # [b, e]
        trigger_emb = trigger_emb / 2

        h_cln = self.ConditionIntegrator(text_emb, trigger_emb)
        h_cln = self.dropout(h_cln)
        h_sa = self.SA(h_cln, h_cln, h_cln, mask)
        h_sa = self.dropout(h_sa)
        h_sa = self.layer_norm(h_sa + h_cln)

        rp_emb = self.relative_pos_embed(relative_pos)
        rp_emb = self.dropout(rp_emb)

        inp = torch.cat([h_sa, rp_emb], dim=-1)

        inp = gelu(self.hidden(inp))
        inp = self.dropout(inp)

        p_s = torch.sigmoid(self.head_cls(inp))  # [b, t, l]
        p_e = torch.sigmoid(self.tail_cls(inp))

        type_soft_constrain = torch.sigmoid(self.gate_linear(type_emb))  # [b, l]
        type_soft_constrain = type_soft_constrain.unsqueeze(1).expand_as(p_s)
        p_s = p_s * type_soft_constrain
        p_e = p_e * type_soft_constrain

        return p_s, p_e, type_soft_constrain


class CasEE(nn.Module):
    def __init__(self, config, model_weight, pos_emb_size):
        super(CasEE, self).__init__()
        self.bert = model_weight

        self.config = config
        self.args_num = config.args_num
        self.text_seq_len = config.seq_length

        self.type_cls = TypeCls(config)
        self.trigger_rec = TriggerRec(config, config.hidden_size)
        self.args_rec = ArgsRec(config, config.hidden_size, self.args_num, self.text_seq_len, pos_emb_size)
        self.dropout = nn.Dropout(config.decoder_dropout)

        self.loss_0 = nn.BCELoss(reduction='none')
        self.loss_1 = nn.BCELoss(reduction='none')
        self.loss_2 = nn.BCELoss(reduction='none')

    def forward(self, tokens, segment, mask, type_id, type_vec, trigger_s_vec, trigger_e_vec, relative_pos, trigger_mask, args_s_vec, args_e_vec, args_mask):
        '''

        :param tokens: [b, t]
        :param segment: [b, t]
        :param mask: [b, t], 0 if masked
        :param type_id: [b]
        :param type_vec: [b, c]
        :param trigger_s_vec: [b, t]
        :param trigger_e_vec: [b, t]
        :param relative_pos: [b, t]
        :param trigger_mask: [b, t], e.g. [0000011000000]
        :param args_s_vec: [b, l, t]
        :param args_e_vec: [b, l, t]
        :param args_mask: [b, k] (not used in this method)
        :return: loss, type_loss, trigger_loss, args_loss
        '''

        outputs = self.bert(
            tokens,
            attention_mask=mask,
            token_type_ids=segment,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=None,
        )

        output_emb = outputs[0]
        p_type, type_emb = self.type_cls(output_emb, mask)
        p_type = p_type.pow(self.config.pow_0)
        type_loss = self.loss_0(p_type, type_vec)
        type_loss = torch.sum(type_loss)

        type_rep = type_emb[type_id, :]
        p_s, p_e, text_rep_type = self.trigger_rec(type_rep, output_emb, mask)
        p_s = p_s.pow(self.config.pow_1)
        p_e = p_e.pow(self.config.pow_1)
        p_s = p_s.squeeze(-1)
        p_e = p_e.squeeze(-1)
        trigger_loss_s = self.loss_1(p_s, trigger_s_vec)
        trigger_loss_e = self.loss_1(p_e, trigger_e_vec)
        mask_t = mask.float()  # [b, t]
        trigger_loss_s = torch.sum(trigger_loss_s.mul(mask_t))
        trigger_loss_e = torch.sum(trigger_loss_e.mul(mask_t))

        p_s, p_e, type_soft_constrain = self.args_rec(text_rep_type, relative_pos, trigger_mask, mask, type_rep)
        p_s = p_s.pow(self.config.pow_2)
        p_e = p_e.pow(self.config.pow_2)
        args_loss_s = self.loss_2(p_s, args_s_vec.transpose(1, 2))  # [b, t, l]
        args_loss_e = self.loss_2(p_e, args_e_vec.transpose(1, 2))
        mask_a = mask.unsqueeze(-1).expand_as(args_loss_s).float()  # [b, t, l]
        args_loss_s = torch.sum(args_loss_s.mul(mask_a))
        args_loss_e = torch.sum(args_loss_e.mul(mask_a))

        trigger_loss = trigger_loss_s + trigger_loss_e
        args_loss = args_loss_s + args_loss_e

        type_loss = self.config.w1 * type_loss
        trigger_loss = self.config.w2 * trigger_loss
        args_loss = self.config.w3 * args_loss
        loss = type_loss + trigger_loss + args_loss
        return loss, type_loss, trigger_loss, args_loss

    def plm(self, tokens, segment, mask):
        assert tokens.size(0) == 1

        outputs = self.bert(
            tokens,
            attention_mask=mask,
            token_type_ids=segment,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=None,
        )
        output_emb = outputs[0]
        return output_emb

    def predict_type(self, text_emb, mask):
        assert text_emb.size(0) == 1
        p_type, type_emb = self.type_cls(text_emb, mask)
        p_type = p_type.view(self.config.type_num).data.cpu().numpy()
        return p_type, type_emb

    def predict_trigger(self, type_rep, text_emb, mask):
        assert text_emb.size(0) == 1
        p_s, p_e, text_rep_type = self.trigger_rec(type_rep, text_emb, mask)
        p_s = p_s.squeeze(-1)  # [b, t]
        p_e = p_e.squeeze(-1)
        mask = mask.float()  # [1, t]
        p_s = p_s.mul(mask)
        p_e = p_e.mul(mask)
        p_s = p_s.view(self.text_seq_len).data.cpu().numpy()  # [t]
        p_e = p_e.view(self.text_seq_len).data.cpu().numpy()
        return p_s, p_e, text_rep_type

    def predict_args(self, text_rep_type, relative_pos, trigger_mask, mask, type_rep):
        assert text_rep_type.size(0) == 1
        p_s, p_e, type_soft_constrain = self.args_rec(text_rep_type, relative_pos, trigger_mask, mask, type_rep)
        mask = mask.unsqueeze(-1).expand_as(p_s).float()  # [b, t, l]
        p_s = p_s.mul(mask)
        p_e = p_e.mul(mask)
        p_s = p_s.view(self.text_seq_len, self.args_num).data.cpu().numpy()  # [t, l]
        p_e = p_e.view(self.text_seq_len, self.args_num).data.cpu().numpy()
        return p_s, p_e, type_soft_constrain

--------------------------------------------------------------------------------
/pre_cascading.py:
--------------------------------------------------------------------------------
import json


def load_data(file):
    with open(file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    records = []
    for line in lines:
        record = json.loads(line)
        records.append(record)
    return records


def write(data, fn):
    with open(fn, 'w', encoding='utf-8') as f:
        for line in data:
            line = json.dumps(line, ensure_ascii=False)
            f.write(line + '\n')


# FewFC event types: lawsuit, investment, share reduction, equity transfer, pledge,
# acquisition, judgment, contract signing, guarantee, bid winning.
TYPES = ['起诉', '投资', '减持', '股份股权转让', '质押', '收购', '判决', '签署合同', '担保', '中标']


def process(records):
    data_len = len(records)
    data_new = []
    for i in range(data_len):
        record = records[i]
        data_id = record['id']
        events = record['events']
        content = record['content']

        # collect all event types that occur in this record
        type_occur = []
        for TYPE in TYPES:
            for event in events:
                event_type = event['type']
                if event_type == TYPE:
                    type_occur.append(TYPE)
        type_occur = list(set(type_occur))

        # label triggers and arguments
        for TYPE in TYPES:
            events_typed = []
            for event in events:
                event_type = event['type']
                if event_type == TYPE:
                    events_typed.append(event)
            # label triggers
            if len(events_typed) != 0:
                triggers = []
                trigger_args = {}
                for event in events_typed:
                    trigger = event['trigger']['span']
                    if trigger not in triggers:
                        triggers.append(trigger)
                    trigger_args[str(trigger)] = trigger_args.get(str(trigger), {})
                    for arg_role in event['args']:
                        trigger_args[str(trigger)][arg_role] = trigger_args[str(trigger)].get(arg_role, [])
                        args_roled_spans = [item['span'] for item in event['args'][arg_role]]
                        for args_roled_span in args_roled_spans:
                            if args_roled_span not in trigger_args[str(trigger)][arg_role]:
                                trigger_args[str(trigger)][arg_role].append(args_roled_span)
                # write one JSON record per trigger, following the trigger order
                triggers_str = [str(trigger) for trigger in triggers]  # with order
                for trigger_str in trigger_args:
                    index = triggers_str.index(trigger_str)
                    data_dict = {}
                    data_dict['id'] = data_id
                    data_dict['content'] = content
                    data_dict['occur'] = type_occur
                    data_dict['type'] = TYPE
                    data_dict['triggers'] = triggers
                    data_dict['index'] = index
                    data_dict['args'] = trigger_args[trigger_str]
                    data_new.append(data_dict)
    return data_new


# Example input record (format of datasets/FewFC/data/*.json): four 质押 (pledge) events share the
# trigger span [57, 59]; spans are end-exclusive character offsets into "content".
# {"id": "9e573f697633ad200c9aa86d70bc2103", "content": "此外,隆鑫通用的公告还透露,除了中信证券,隆鑫控股与四川信托有限公司、中国国际金融有限公司、国民信托有限公司的股票质押合同均已到期,但“经与质权人友好协商,目前暂不会做违约处置”。", "events": [
# {"type": "质押", "trigger": {"span": [57, 59], "word": "质押"}, "args": {"sub-org": [{"span": [3, 7], "word": "隆鑫通用"}], "obj-org": [{"span": [16, 20], "word": "中信证券"}], "collateral": [{"span": [55, 57], "word": "股票"}]}},
# {"type": "质押", "trigger": {"span": [57, 59], "word": "质押"}, "args": {"sub-org": [{"span": [3, 7], "word": "隆鑫通用"}], "obj-org": [{"span": [26, 34], "word": "四川信托有限公司"}], "collateral": [{"span": [55, 57], "word": "股票"}]}},
# {"type": "质押", "trigger": {"span": [57, 59], "word": "质押"}, "args": {"sub-org": [{"span": [3, 7], "word": "隆鑫通用"}], "obj-org": [{"span": [35, 45], "word": "中国国际金融有限公司"}], "collateral": [{"span": [55, 57], "word": "股票"}]}},
# {"type": "质押", "trigger": {"span": [57, 59], "word": "质押"}, "args": {"sub-org": [{"span": [3, 7], "word": "隆鑫通用"}], "obj-org": [{"span": [46, 54], "word": "国民信托有限公司"}], "collateral": [{"span": [55, 57], "word": "股票"}]}}]}


def main():
    train = load_data('./datasets/FewFC/data/train.json')
    train = process(train)
    write(train, './datasets/FewFC/cascading_sampled/train.json')

    dev = load_data('./datasets/FewFC/data/dev.json')
    dev = process(dev)
    write(dev, './datasets/FewFC/cascading_sampled/dev.json')

    test = load_data('./datasets/FewFC/data/test.json')
    test = process(test)
    write(test, './datasets/FewFC/cascading_sampled/test.json')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils/data_loader.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import json
import numpy as np
import os


def get_dict(fn):
    with open(fn + '/cascading_sampled/ty_args.json', 'r', encoding='utf-8') as f:
        ty_args = json.load(f)
    if not os.path.exists(fn + '/cascading_sampled/shared_args_list.json'):
        args_list = set()
        for ty in ty_args:
            for arg in ty_args[ty]:
                args_list.add(arg)
        args_list = list(args_list)
        with open(fn + '/cascading_sampled/shared_args_list.json', 'w', encoding='utf-8') as f:
            json.dump(args_list, f, ensure_ascii=False)
    else:
        with open(fn + '/cascading_sampled/shared_args_list.json', 'r', encoding='utf-8') as f:
            args_list = json.load(f)

    args_s_id = {}
    args_e_id = {}
    for i in range(len(args_list)):
        s = args_list[i] + '_s'
        args_s_id[s] = i
        e = args_list[i] + '_e'
        args_e_id[e] = i

    id_type = {i: item for i, item in enumerate(ty_args)}
    type_id = {item: i for i, item in enumerate(ty_args)}

    id_args = {i: item for i, item in enumerate(args_list)}
    args_id = {item: i for i, item in enumerate(args_list)}
    ty_args_id = {}
    for ty in ty_args:
        args = ty_args[ty]
        tmp = [args_id[a] for a in args]
        ty_args_id[type_id[ty]] = tmp
    return type_id, id_type, args_id, id_args, ty_args, ty_args_id, args_s_id, args_e_id


def read_labeled_data(fn):
    ''' Read Train Data / Dev Data (cascading-sampled, one JSON record per line) '''
    with open(fn, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    data_ids = []
    data_content = []
    data_type = []
    data_occur = []
    data_triggers = []
    data_index = []
    data_args = []
    for line in lines:
        line_dict = json.loads(line.strip())
        data_ids.append(line_dict.get('id', 0))
        data_occur.append(line_dict['occur'])
        data_type.append(line_dict['type'])
        data_content.append(line_dict['content'])
        data_index.append(line_dict['index'])
        data_triggers.append(line_dict['triggers'])
        data_args.append(line_dict['args'])
    return data_ids, data_occur, data_type, data_content, data_triggers, data_index, data_args


def read_unlabeled_data(fn):
    ''' Read Test Data '''
    with open(fn, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    data_ids = []
    data_content = []
    for line in lines:
        line_dict = json.loads(line.strip())
        data_ids.append(line_dict['id'])
        data_content.append(line_dict['content'])
    return data_ids, data_content


def get_relative_pos(start_idx, end_idx, length):
    '''
    Return the position of every token relative to the span [start_idx, end_idx] (inclusive):
    negative before the span, 0 inside it, positive after it.
    '''
    pos = list(range(-start_idx, 0)) + [0] * (end_idx - start_idx + 1) + list(range(1, length - end_idx))
    return pos


def get_trigger_mask(start_idx, end_idx, length):
    '''
    Generate the trigger mask, where the elements at the start/end positions are 1, e.g.
    [000010100000]
    '''
    mask = [0] * length
    mask[start_idx] = 1
    mask[end_idx] = 1
    return mask


class Data(Dataset):
    def __init__(self, task, fn, tokenizer=None, seq_len=None, args_s_id=None, args_e_id=None, type_id=None):
        assert task in ['train', 'eval_with_oracle', 'eval_without_oracle']
        self.task = task
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.args_s_id = args_s_id
        self.args_e_id = args_e_id
        self.args_num = len(args_s_id.keys())
        self.type_id = type_id
        self.type_num = len(type_id.keys())

        if self.task == 'eval_without_oracle':
            data_ids, data_content = read_unlabeled_data(fn)
            self.data_ids
= data_ids 114 | self.data_content = data_content 115 | tokens_ids, segs_ids, masks_ids = self.data_to_id(data_content) 116 | 117 | self.token = tokens_ids 118 | self.seg = segs_ids 119 | self.mask = masks_ids 120 | 121 | else: 122 | data_ids, data_occur, data_type, data_content, data_triggers, data_index, data_args = read_labeled_data(fn) 123 | self.data_ids = data_ids 124 | self.data_occur = data_occur 125 | self.data_triggers = data_triggers 126 | self.data_args = data_args 127 | 128 | self.data_content = data_content 129 | tokens_ids, segs_ids, masks_ids = self.data_to_id(data_content) 130 | self.token = tokens_ids 131 | self.seg = segs_ids 132 | self.mask = masks_ids 133 | 134 | data_type_id_s, type_vec_s = self.type_to_id(data_type, data_occur) 135 | self.data_type_id_s = data_type_id_s 136 | self.type_vec_s = type_vec_s 137 | 138 | self.r_pos, self.t_m = self.get_rp_tm(data_triggers, data_index) 139 | self.t_index = data_index 140 | 141 | if self.task == 'train': 142 | t_s, t_e = self.trigger_seq_id(data_triggers) 143 | self.t_s = t_s 144 | self.t_e = t_e 145 | a_s, a_e, a_m = self.args_seq_id(data_args) 146 | self.a_s = a_s 147 | self.a_e = a_e 148 | self.a_m = a_m 149 | 150 | if self.task == 'eval_with_oracle': 151 | self.data_content = data_content 152 | self.data_args = data_args 153 | self.data_triggers = data_triggers 154 | triggers_truth_s, args_truth_s = self.results_for_eval() 155 | self.triggers_truth = triggers_truth_s 156 | self.args_truth = args_truth_s 157 | 158 | def __len__(self): 159 | return len(self.data_ids) 160 | 161 | def __getitem__(self, index): 162 | if self.task == 'train': 163 | return self.data_ids[index], \ 164 | self.data_type_id_s[index], \ 165 | self.type_vec_s[index], \ 166 | self.token[index], \ 167 | self.seg[index], \ 168 | self.mask[index], \ 169 | self.t_index[index], \ 170 | self.r_pos[index], \ 171 | self.t_m[index], \ 172 | self.t_s[index], \ 173 | self.t_e[index], \ 174 | self.a_s[index], \ 175 | self.a_e[index], \ 
                self.a_m[index]
        elif self.task == 'eval_with_oracle':
            return self.data_ids[index], \
                self.data_type_id_s[index], \
                self.type_vec_s[index], \
                self.token[index], \
                self.seg[index], \
                self.mask[index], \
                self.t_index[index], \
                self.r_pos[index], \
                self.t_m[index], \
                self.triggers_truth[index], \
                self.args_truth[index]
        elif self.task == 'eval_without_oracle':
            return self.data_ids[index], \
                self.data_content[index], \
                self.token[index], \
                self.seg[index], \
                self.mask[index]
        else:
            raise Exception('task not defined!')

    def data_to_id(self, data_contents):
        tokens_ids = []
        segs_ids = []
        masks_ids = []
        for i in range(len(self.data_ids)):
            data_content = data_contents[i]
            # default uncased
            data_content = [token.lower() for token in data_content]
            data_content = list(data_content)
            # Here we add the [CLS] and [SEP] tokens for BERT input
            # transformers == 4.9.1
            inputs = self.tokenizer.encode_plus(data_content, add_special_tokens=True, max_length=self.seq_len, truncation=True, padding='max_length')
            tokens, segs, masks = inputs["input_ids"], inputs["token_type_ids"], inputs['attention_mask']
            tokens_ids.append(tokens)
            segs_ids.append(segs)
            masks_ids.append(masks)
        return tokens_ids, segs_ids, masks_ids

    def type_to_id(self, data_type, data_occur):
        data_type_id_s, type_vec_s = [], []
        for i in range(len(self.data_ids)):
            data_type_id = self.type_id[data_type[i]]
            type_vec = np.array([0] * self.type_num)
            for occ in data_occur[i]:
                idx = self.type_id[occ]
                type_vec[idx] = 1
            data_type_id_s.append(data_type_id)
            type_vec_s.append(type_vec)
        return data_type_id_s, type_vec_s

    def trigger_seq_id(self, data_triggers):
        '''
        given trigger span, return ground truth trigger matrix, for bce loss
        t_s: trigger start sequence, 1 at trigger start positions and 0 elsewhere
        t_e: trigger end sequence, 1 at trigger end positions and 0 elsewhere
        '''
        trigger_s = []
        trigger_e = []
        for i in range(len(self.data_ids)):
            data_trigger = data_triggers[i]
            t_s = [0] * self.seq_len
            t_e = [0] * self.seq_len

            for t in data_trigger:
                # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
                t_s[t[0] + 1] = 1
                t_e[t[1] + 1 - 1] = 1

            trigger_s.append(t_s)
            trigger_e.append(t_e)
        return trigger_s, trigger_e

    def args_seq_id(self, data_args_list):
        '''
        given argument span, return ground truth argument matrix, for bce loss
        '''
        args_s_lines = []
        args_e_lines = []
        arg_masks = []
        for i in range(len(self.data_ids)):
            args_s = np.zeros(shape=[self.args_num, self.seq_len])
            args_e = np.zeros(shape=[self.args_num, self.seq_len])
            data_args_dict = data_args_list[i]
            arg_mask = [0] * self.args_num
            for args_name in data_args_dict:
                s_r_i = self.args_s_id[args_name + '_s']
                e_r_i = self.args_e_id[args_name + '_e']
                arg_mask[s_r_i] = 1
                for span in data_args_dict[args_name]:
                    # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
                    args_s[s_r_i][span[0] + 1] = 1
                    args_e[e_r_i][span[1] + 1 - 1] = 1
            args_s_lines.append(args_s)
            args_e_lines.append(args_e)
            arg_masks.append(arg_mask)
        return args_s_lines, args_e_lines, arg_masks

    def results_for_eval(self):
        '''
        read structured ground truth, for evaluating model performance
        '''
        triggers_truth_s = []
        args_truth_s = []
        for i in range(len(self.data_ids)):
            triggers = self.data_triggers[i]
            args = self.data_args[i]
            # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
            triggers_truth = [(span[0] + 1, span[1] + 1 - 1) for span in triggers]
            args_truth = {i: [] for i in range(self.args_num)}
            for args_name in args:
                s_r_i = self.args_s_id[args_name + '_s']
                for span in args[args_name]:
                    # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
                    args_truth[s_r_i].append((span[0] + 1, span[1] + 1 - 1))
            triggers_truth_s.append(triggers_truth)
            args_truth_s.append(args_truth)
        return triggers_truth_s, args_truth_s

    def get_rp_tm(self, triggers, data_index):
        '''
        get relative position ids and trigger mask, according to the trigger span
        r_pos: relative position ids, shifted by seq_len so that they are non-negative
        t_m: trigger mask, used for mean pooling
        '''
        r_pos = []
        t_m = []
        for i in range(len(self.data_ids)):
            trigger = triggers[i]
            index = data_index[i]
            span = trigger[index]
            # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
            pos = get_relative_pos(span[0] + 1, span[1] + 1 - 1, self.seq_len)
            pos = [p + self.seq_len for p in pos]
            # plus 1 for the leading [CLS] token; minus 1 because span ends are exclusive
            mask = get_trigger_mask(span[0] + 1, span[1] + 1 - 1, self.seq_len)
            r_pos.append(pos)
            t_m.append(mask)
        return r_pos, t_m


def collate_fn_train(data):
    '''
    :param data: list of samples returned by Data.__getitem__ for task 'train'
    :return:
        idx: the id of data record
        dt: the id of the event type of this record (int)
        t_v: multi-hot vector of all event types occurring in the record
        token: token sequence
        seg: segment sequence
        mask: mask sequence
        t_index: unused in training; index of the trigger whose arguments are labeled in this record
        r_pos: relative position ids
        t_m: trigger mask, where 1 marks the start/end positions
        t_s: ground truth of trigger start
        t_e: ground truth of trigger end
        a_s: ground truth of argument start
        a_e: ground truth of argument end
        a_m: unused; used to indicate the correlation between argument role and event type.
    '''
    idx, dt, t_v, token, seg, mask, t_index, r_pos, t_m, t_s, t_e, a_s, a_e, a_m = zip(*data)
    return idx, dt, t_v, token, seg, mask, t_index, r_pos, t_m, t_s, t_e, a_s, a_e, a_m


def collate_fn_dev(data):
    '''
    :param data: list of samples returned by Data.__getitem__ for task 'eval_with_oracle'
    :return: the sample fields regrouped into one tuple per field
    '''
    idx, dt, t_v, token, seg, mask, t_index, r_pos, t_m, t_t, a_t = zip(*data)
    return idx, dt, t_v, token, seg, mask, t_index, r_pos, t_m, t_t, a_t


def collate_fn_test(data):
    '''
    :param data: list of samples returned by Data.__getitem__ for task 'eval_without_oracle'
    :return: the sample fields regrouped into one tuple per field
    '''
    idx, dc, token, seg, mask = zip(*data)
    return idx, dc, token, seg, mask

--------------------------------------------------------------------------------
/utils/framework.py:
--------------------------------------------------------------------------------
import os
import time
import torch.nn as nn
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from utils.utils_io_model import load_model, save_model
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from utils.predict_without_oracle import extract_all_items_without_oracle
from utils.predict_with_oracle import predict_one
from utils.metric import score, gen_idx_event_dict, cal_scores, cal_scores_ti_tc_ai_ac
from utils.utils_io_data import read_jsonl, write_jsonl


class Framework(object):
    def __init__(self, config, model):
        self.config = config
        self.model = model.to(config.device)

    def load_model(self, model_path):
        self.model = load_model(self.model, model_path)

    def set_learning_setting(self, config, train_loader, dev_loader, model):
        instances_num = len(train_loader.dataset)
        train_steps = int(instances_num * config.epochs_num / config.batch_size) + 1

        print("Batch size: ", config.batch_size)
        print("The number of training instances:", instances_num)
        print("The number of evaluating instances:", len(dev_loader.dataset))

        bert_params = list(map(id, model.bert.parameters()))

        other_params = filter(lambda p: id(p) not in bert_params, model.parameters())
        optimizer_grouped_parameters = [{'params': model.bert.parameters()}, {'params': other_params, 'lr': config.lr_task}]

        optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr_bert, correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=train_steps * config.warmup, num_training_steps=train_steps)

        if config.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16_opt_level)

        if torch.cuda.device_count() > 1:
            print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(model)

        return scheduler, optimizer

    def train(self, train_loader, dev_loader):
        scheduler, optimizer = self.set_learning_setting(self.config, train_loader, dev_loader, self.model)
        if self.config.fp16:
            # amp is imported only locally in set_learning_setting; import it here for amp.scale_loss below
            from apex import amp
        # going to train
        total_loss = 0.0
        ed_loss = 0.0
        te_loss = 0.0
        ae_loss = 0.0
        best_f1 = 0.0
        best_epoch = 0
        for epoch in range(1, self.config.epochs_num + 1):
            print('Training...')
            self.model.train()
            for i, (idx, d_t, t_v, token, seg, mask, t_index, r_pos, t_m, t_s, t_e, a_s, a_e, a_m) in enumerate(train_loader):
                self.model.zero_grad()
                d_t = torch.LongTensor(d_t).to(self.config.device)
                t_v = torch.FloatTensor(t_v).to(self.config.device)
                token = torch.LongTensor(token).to(self.config.device)
                seg = torch.LongTensor(seg).to(self.config.device)
                mask = torch.LongTensor(mask).to(self.config.device)
                r_pos = torch.LongTensor(r_pos).to(self.config.device)
                t_m = torch.LongTensor(t_m).to(self.config.device)
                t_s = torch.FloatTensor(t_s).to(self.config.device)
                t_e = torch.FloatTensor(t_e).to(self.config.device)
                a_s = torch.FloatTensor(a_s).to(self.config.device)
                a_e = torch.FloatTensor(a_e).to(self.config.device)
                a_m = torch.LongTensor(a_m).to(self.config.device)
                loss, type_loss, trigger_loss, args_loss = self.model(token, seg, mask, d_t, t_v, t_s, t_e, r_pos, t_m, a_s, a_e, a_m)
                if torch.cuda.device_count() > 1:
                    # DataParallel gathers one loss value per GPU
                    loss = torch.mean(loss)
                    type_loss = torch.mean(type_loss)
                    trigger_loss = torch.mean(trigger_loss)
                    args_loss = torch.mean(args_loss)

                total_loss += loss.item()
                ed_loss += type_loss.item()
                te_loss += trigger_loss.item()
                ae_loss += args_loss.item()

                if (i + 1) % self.config.report_steps == 0:
                    print("Epoch id: {}, Training steps: {}, ED loss:{:.6f},TE loss:{:.6f}, AE loss:{:.6f}, Avg loss: {:.6f}".format(
                        epoch, i + 1, ed_loss / self.config.report_steps, te_loss / self.config.report_steps,
                        ae_loss / self.config.report_steps, total_loss / self.config.report_steps))
                    total_loss = 0.0
                    ed_loss = 0.0
                    te_loss = 0.0
                    ae_loss = 0.0
                if self.config.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
                scheduler.step()

            print('Evaluating...')
            c_ps, c_rs, c_fs, t_ps, t_rs, t_fs, a_ps, a_rs, a_fs = self.evaluate_with_oracle(self.config, self.model, dev_loader, self.config.device, self.config.ty_args_id, self.config.id_type)
            f1_mean_all = (c_fs + t_fs + a_fs) / 3
            print('Evaluate on all types:')
            print("Epoch id: {}, Type P: {:.3f}, Type R: {:.3f}, Type F: {:.3f}".format(epoch, c_ps, c_rs, c_fs))
            print("Epoch id: {}, Trigger P: {:.3f}, Trigger R: {:.3f}, Trigger F: {:.3f}".format(epoch, t_ps, t_rs, t_fs))
            print("Epoch id: {}, Args P: {:.3f}, Args R: {:.3f}, Args F: {:.3f}".format(epoch, a_ps, a_rs, a_fs))
            print("Epoch id: {}, F1 Mean All: {:.3f}".format(epoch, f1_mean_all))

            if f1_mean_all > best_f1:
                best_f1 = f1_mean_all
                best_epoch = epoch
                save_model(self.model, self.config.output_model_path)
            print("The Best F1 Is: {:.3f}, When Epoch Is: {}".format(best_f1, best_epoch))

    def evaluate_with_oracle(self, config, model, dev_data_loader, device, ty_args_id, id2type):
        if hasattr(model, "module"):
            model = model.module
        model.eval()
        # one idx may correspond to several records (one per trigger), so we use dicts to merge the results
        type_pred_dict = {}
        type_truth_dict = {}
        trigger_pred_tuples_dict = {}
        trigger_truth_tuples_dict = {}
        args_pred_tuples_dict = {}
        args_truth_tuples_dict = {}

        for i, (idx, typ_oracle, typ_truth, token, seg, mask, t_index, r_p, t_m, tri_truth, args_truth) in tqdm(enumerate(dev_data_loader)):
            typ_oracle = torch.LongTensor(typ_oracle).to(device)
            typ_truth = torch.FloatTensor(typ_truth).to(device)
            token = torch.LongTensor(token).to(device)
            seg = torch.LongTensor(seg).to(device)
            mask = torch.LongTensor(mask).to(device)
            r_p = torch.LongTensor(r_p).to(device)
            t_m = torch.LongTensor(t_m).to(device)

            tri_oracle = tri_truth[0][t_index[0]]
            type_pred, type_truth, trigger_pred_tuples, trigger_truth_tuples, args_pred_tuples, args_truth_tuples = predict_one(model, config, typ_truth, token, seg, mask, r_p, t_m, tri_truth, args_truth, ty_args_id, typ_oracle, tri_oracle)

            idx = idx[0]
            # collect type predictions
            if idx not in type_pred_dict:
                type_pred_dict[idx] = type_pred
            if idx not in type_truth_dict:
                type_truth_dict[idx] = type_truth

            # collect trigger predictions
            if idx not in trigger_pred_tuples_dict:
                trigger_pred_tuples_dict[idx] = []
            trigger_pred_tuples_dict[idx].extend(trigger_pred_tuples)
            if idx not in trigger_truth_tuples_dict:
                trigger_truth_tuples_dict[idx] = []
            trigger_truth_tuples_dict[idx].extend(trigger_truth_tuples)

            # collect argument predictions
            if idx not in args_pred_tuples_dict:
                args_pred_tuples_dict[idx] = []
            args_pred_tuples_dict[idx].extend(args_pred_tuples)
            if idx not in args_truth_tuples_dict:
                args_truth_tuples_dict[idx] = []
            args_truth_tuples_dict[idx].extend(args_truth_tuples)

        # Here we calculate event detection metric (macro).
        type_pred_s, type_truth_s = [], []
        for idx in type_truth_dict.keys():
            type_pred_s.append(type_pred_dict[idx])
            type_truth_s.append(type_truth_dict[idx])
        type_pred_s = np.array(type_pred_s)
        type_truth_s = np.array(type_truth_s)
        c_ps = precision_score(type_truth_s, type_pred_s, average='macro')
        c_rs = recall_score(type_truth_s, type_pred_s, average='macro')
        c_fs = f1_score(type_truth_s, type_pred_s, average='macro')

        # Here we calculate TC and AC metric with oracle inputs.
        t_p, t_r, t_f = score(trigger_pred_tuples_dict, trigger_truth_tuples_dict)
        a_p, a_r, a_f = score(args_pred_tuples_dict, args_truth_tuples_dict)
        return c_ps, c_rs, c_fs, t_p, t_r, t_f, a_p, a_r, a_f

    def evaluate_without_oracle(self, config, model, data_loader, device, seq_len, id_type, id_args, ty_args_id):
        # unwrap DataParallel only if the model is actually wrapped (it is not when training is skipped)
        if hasattr(model, "module"):
            model = model.module
        model.eval()
        results = []
        for i, (idx, content, token, seg, mask) in tqdm(enumerate(data_loader)):
            idx = idx[0]
            token = torch.LongTensor(token).to(device)
            seg = torch.LongTensor(seg).to(device)
            mask = torch.LongTensor(mask).to(device)
            result = extract_all_items_without_oracle(model, device, idx, content, token, seg, mask, seq_len, config.threshold_0, config.threshold_1, config.threshold_2, config.threshold_3, config.threshold_4, id_type, id_args, ty_args_id)
            results.append(result)
        pred_records = results
        pred_dict = gen_idx_event_dict(pred_records)
        gold_records = read_jsonl(self.config.test_path)
        gold_dict = gen_idx_event_dict(gold_records)
        prf_s = cal_scores_ti_tc_ai_ac(pred_dict, gold_dict)
        return prf_s, pred_records

--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
import json


def score(preds_tuple, golds_tuple):
    '''
    Micro-averaged precision, recall and F1 over deduplicated mention tuples, keyed by sentence id.
    Modified from https://github.com/xinyadu/eeqa
    '''
    gold_mention_n, pred_mention_n, true_positive_n = 0, 0, 0
    for sentence_id in golds_tuple:
        gold_sentence_mentions = golds_tuple[sentence_id]
        pred_sentence_mentions = preds_tuple[sentence_id]
        gold_sentence_mentions = set(gold_sentence_mentions)
        pred_sentence_mentions = set(pred_sentence_mentions)
        for mention in pred_sentence_mentions:
            pred_mention_n += 1
        for mention in gold_sentence_mentions:
            gold_mention_n += 1
        for mention in pred_sentence_mentions:
            if mention in gold_sentence_mentions:
                true_positive_n += 1
    prec_c, recall_c, f1_c = 0, 0, 0
    if pred_mention_n != 0:
        prec_c = true_positive_n / pred_mention_n
    else:
        prec_c = 0
    if gold_mention_n != 0:
        recall_c = true_positive_n / gold_mention_n
    else:
        recall_c = 0
    if prec_c or recall_c:
        f1_c = 2 * prec_c * recall_c / (prec_c + recall_c)
    else:
        f1_c = 0
    return prec_c, recall_c, f1_c


def gen_tuples(record):
    if record:
        ti, tc, ai, ac = [], [], [], []
        for event in record:
            typ, trigger_span = event['type'], event['trigger']['span']
            ti_one = (trigger_span[0], trigger_span[1])
            tc_one = (typ, trigger_span[0], trigger_span[1])
            ti.append(ti_one)
            tc.append(tc_one)
            for arg_role in event['args']:
                for arg_role_one in event['args'][arg_role]:
                    ai_one = (typ, arg_role_one['span'][0], arg_role_one['span'][1])
                    ac_one = (typ, arg_role_one['span'][0], arg_role_one['span'][1], arg_role)

                    ai.append(ai_one)
                    ac.append(ac_one)
        return ti, tc, ai, ac
    else:
        return [], [], [], []


def cal_scores_ti_tc_ai_ac(preds, golds):
    '''
    :param preds: {id: [{'type': '', 'trigger': {'span': [], 'word': ''}, 'args': {role1: [], role2: [], ...}}, ...]}
    :param golds: same format as preds
    :return: a list of (P, R, F) triples for TI, TC, AI and AC
    '''
    # assert len(preds) == len(golds)
    tuples_pred = [{}, {}, {}, {}]  # ti, tc, ai, ac
    tuples_gold = [{}, {}, {}, {}]  # ti, tc, ai, ac

    for idx in golds:
        if idx not in preds:
            pred = None
        else:
            pred = preds[idx]
        gold = golds[idx]

        ti, tc, ai, ac = gen_tuples(pred)
        tuples_pred[0][idx] = ti
        tuples_pred[1][idx] = tc
        tuples_pred[2][idx] = ai
        tuples_pred[3][idx] = ac

        ti, tc, ai, ac = gen_tuples(gold)
        tuples_gold[0][idx] = ti
        tuples_gold[1][idx] = tc
        tuples_gold[2][idx] = ai
        tuples_gold[3][idx] = ac

    prf_s = []
    for i in range(4):
        prf = score(tuples_pred[i], tuples_gold[i])
        prf_s.append(prf)
    return prf_s


def cal_scores(pred_dict, gold_dict, print_tab=False):
    prf_s = cal_scores_ti_tc_ai_ac(pred_dict, gold_dict)
    metric_names = ['TI', 'TC', 'AI', 'AC']
    for i, prf in enumerate(prf_s):
        if not print_tab:
            print('{}: P:{:.1f}, R:{:.1f}, F:{:.1f}'.format(metric_names[i], prf[0] * 100, prf[1] * 100, prf[2] * 100))
        else:
            print('{}:\tP:\t{:.1f}\tR:\t{:.1f}\tF:\t{:.1f}'.format(metric_names[i], prf[0] * 100, prf[1] * 100, prf[2] * 100))
    return prf_s


def gen_idx_event_dict(records):
    data_dict = {}
    for line in records:
        idx = line['id']
        events = line['events']
        data_dict[idx] = events
    return data_dict

--------------------------------------------------------------------------------
/utils/params.py:
--------------------------------------------------------------------------------
import random
import argparse
import torch
import os
import numpy as np


def seed_everything(seed=7):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def str2bool(str):
    return True if str.lower() == 'true' else False


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--data_path", type=str, default='datasets/FewFC', help="Path of the dataset.")
    parser.add_argument("--test_path", type=str, default='datasets/FewFC/data/test.json', help="Path of the test set.")

    parser.add_argument("--do_train", default=True, type=str2bool)
    parser.add_argument("--do_eval", default=True, type=str2bool)
    parser.add_argument("--do_test", default=True, type=str2bool)

    parser.add_argument("--output_result_path", type=str, default='models_save/results.json')
    parser.add_argument("--output_model_path", default="./models_save/model.bin", type=str, help="Path of the output model.")

    parser.add_argument("--model_name_or_path", default="bert-base-chinese", type=str, help="Name or path of the pre-trained model.")
    parser.add_argument("--cache_dir", default="./plm", type=str, help="Where to store the downloaded pre-trained models.")
    parser.add_argument("--do_lower_case", action="store_true", help="")
    parser.add_argument("--seq_length", default=400, type=int, help="Sequence length.")

    # Training options.
    parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--lr_bert", type=float, default=2e-5, help="Learning rate for BERT.")
    parser.add_argument("--lr_task", type=float, default=1e-4, help="Learning rate for task layers.")
    parser.add_argument("--warmup", type=float, default=0.1, help="Fraction of training steps used for linear learning-rate warmup.")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size.")
    parser.add_argument("--epochs_num", type=int, default=20, help="Number of epochs.")
    parser.add_argument("--report_steps", type=int, default=5, help="Print the training losses every this many steps.")

    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay value")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout on BERT")
    parser.add_argument("--decoder_dropout", type=float, default=0.3, help="Dropout on decoders")

    # Model options.
58 | parser.add_argument("--w1", type=float, default=1.0) 59 | parser.add_argument("--w2", type=float, default=1.0) 60 | parser.add_argument("--w3", type=float, default=1.0) 61 | parser.add_argument("--pow_0", type=int, default=1) 62 | parser.add_argument("--pow_1", type=int, default=1) 63 | parser.add_argument("--pow_2", type=int, default=1) 64 | 65 | parser.add_argument("--rp_size", type=int, default=64) 66 | parser.add_argument("--decoder_num_head", type=int, default=1) 67 | 68 | parser.add_argument("--threshold_0", type=float, default=0.5) 69 | parser.add_argument("--threshold_1", type=float, default=0.5) 70 | parser.add_argument("--threshold_2", type=float, default=0.5) 71 | parser.add_argument("--threshold_3", type=float, default=0.5) 72 | parser.add_argument("--threshold_4", type=float, default=0.5) 73 | 74 | parser.add_argument("--step", type=str, choices=["dev", "test"]) 75 | 76 | args = parser.parse_args() 77 | 78 | seed_everything(args.seed) 79 | return args 80 | -------------------------------------------------------------------------------- /utils/predict_with_oracle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def extract_specific_item_with_oracle(model, d_t, token, seg, mask, rp, tm, args_num, threshold_0, threshold_1, threshold_2, threshold_3, threshold_4, ty_args_id): 6 | assert token.size(0) == 1 7 | data_type = d_t.item() 8 | text_emb = model.plm(token, seg, mask) 9 | 10 | # predict event type 11 | p_type, type_emb = model.predict_type(text_emb, mask) 12 | type_pred = np.array(p_type > threshold_0, dtype=int) 13 | type_rep = type_emb[d_t, :] 14 | 15 | # predict event trigger 16 | p_s, p_e, text_rep_type = model.predict_trigger(type_rep, text_emb, mask) 17 | trigger_s = np.where(p_s > threshold_1)[0] 18 | trigger_e = np.where(p_e > threshold_2)[0] 19 | trigger_spans = [] 20 | for i in trigger_s: 21 | es = trigger_e[trigger_e >= i] 22 | if len(es) > 0: 23 | 
e = es[0]
24 |             trigger_spans.append((i, e))
25 |
26 |     # predict event argument
27 |     p_s, p_e, type_soft_constrain = model.predict_args(text_rep_type, rp, tm, mask, type_rep)
28 |     p_s = np.transpose(p_s)
29 |     p_e = np.transpose(p_e)
30 |     args_spans = {i: [] for i in range(args_num)}
31 |     for i in ty_args_id[data_type]:
32 |         args_s = np.where(p_s[i] > threshold_3)[0]
33 |         args_e = np.where(p_e[i] > threshold_4)[0]
34 |         for j in args_s:
35 |             es = args_e[args_e >= j]
36 |             if len(es) > 0:
37 |                 e = es[0]
38 |                 args_spans[i].append((j, e))
39 |     return type_pred, trigger_spans, args_spans
40 |
41 |
42 | def predict_one(model, args, typ_truth, token, seg, mask, r_p, t_m, tri_truth, args_truth, ty_args_id, typ_oracle, tri_oracle):
43 |     type_pred, trigger_pred, args_pred = extract_specific_item_with_oracle(model, typ_oracle, token, seg, mask, r_p, t_m, args.args_num, args.threshold_0, args.threshold_1, args.threshold_2, args.threshold_3, args.threshold_4, ty_args_id)
44 |     type_oracle = typ_oracle.item()
45 |     type_truth = typ_truth.view(args.type_num).cpu().numpy().astype(int)
46 |     trigger_truth, args_truth = tri_truth[0], args_truth[0]
47 |
48 |     # tuples used for evaluation, formatted as:
49 |     trigger_pred_tuples = []  # (type, tri_sta, tri_end), 3-tuple
50 |     trigger_truth_tuples = []
51 |     args_pred_tuples = []  # (type, arg_sta, arg_end, arg_role), 4-tuple; the trigger span is fixed by the oracle
52 |     args_truth_tuples = []
53 |
54 |     for trigger_pred_one in trigger_pred:
55 |         typ = type_oracle
56 |         sta = trigger_pred_one[0]
57 |         end = trigger_pred_one[1]
58 |         trigger_pred_tuples.append((typ, sta, end))
59 |
60 |     for trigger_truth_one in trigger_truth:
61 |         typ = type_oracle
62 |         sta = trigger_truth_one[0]
63 |         end = trigger_truth_one[1]
64 |         trigger_truth_tuples.append((typ, sta, end))
65 |
66 |     args_candidates = ty_args_id[type_oracle]  # type constraint
67 |     for i in args_candidates:
68 |         typ = type_oracle
69 |         tri_sta = tri_oracle[0]
70 |         tri_end = tri_oracle[1]
71 |         arg_role = i
72 |         for args_pred_one in args_pred[i]:
73 |             arg_sta = args_pred_one[0]
74 |             arg_end = args_pred_one[1]
75 |             args_pred_tuples.append((typ, arg_sta, arg_end, arg_role))
76 |
77 |         for args_truth_one in args_truth[i]:
78 |             arg_sta = args_truth_one[0]
79 |             arg_end = args_truth_one[1]
80 |             args_truth_tuples.append((typ, arg_sta, arg_end, arg_role))
81 |
82 |     return type_pred, type_truth, trigger_pred_tuples, trigger_truth_tuples, args_pred_tuples, args_truth_tuples
83 |
--------------------------------------------------------------------------------
/utils/predict_without_oracle.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.data_loader import get_relative_pos, get_trigger_mask
4 |
5 | TRI_LEN = 5
6 | ARG_LEN_DICT = {
7 |     'collateral': 14,
8 |     'proportion': 37,
9 |     'obj-org': 34,
10 |     'number': 18,
11 |     'date': 27,
12 |     'sub-org': 35,
13 |     'target-company': 59,
14 |     'sub': 38,
15 |     'obj': 36,
16 |     'share-org': 19,
17 |     'money': 28,
18 |     'title': 8,
19 |     'sub-per': 15,
20 |     'obj-per': 18,
21 |     'share-per': 20,
22 |     'institution': 22,
23 |     'way': 8,
24 |     'amount': 19
25 | }
26 |
27 |
28 | def extract_all_items_without_oracle(model, device, idx, content: str, token, seg, mask, seq_len, threshold_0, threshold_1, threshold_2, threshold_3, threshold_4, id_type: dict, id_args: dict, ty_args_id: dict):
29 |     assert token.size(0) == 1
30 |     content = content[0]
31 |     result = {'id': idx, 'content': content}
32 |     text_emb = model.plm(token, seg, mask)
33 |
34 |     args_id = {id_args[k]: k for k in id_args}
35 |     args_len_dict = {args_id[k]: ARG_LEN_DICT[k] for k in ARG_LEN_DICT}
36 |
37 |     p_type, type_emb = model.predict_type(text_emb, mask)
38 |     type_pred = np.array(p_type > threshold_0, dtype=bool)
39 |     type_pred = [i for i, t in enumerate(type_pred) if t]
40 |     events_pred = []
41 |
42 |     for type_pred_one in type_pred:
43 |         type_rep = type_emb[type_pred_one, :]
44 |         type_rep = type_rep.unsqueeze(0)
45 |         p_s, p_e, text_rep_type = model.predict_trigger(type_rep, text_emb, mask)
46 |         trigger_s = np.where(p_s > threshold_1)[0]
47 |         trigger_e = np.where(p_e > threshold_2)[0]
48 |         trigger_spans = []
49 |
50 |         for i in trigger_s:
51 |             es = trigger_e[trigger_e >= i]
52 |             if len(es) > 0:
53 |                 e = es[0]
54 |                 if e - i + 1 <= TRI_LEN:
55 |                     trigger_spans.append((i, e))
56 |
57 |         for k, span in enumerate(trigger_spans):
58 |             rp = get_relative_pos(span[0], span[1], seq_len)
59 |             rp = [p + seq_len for p in rp]
60 |             tm = get_trigger_mask(span[0], span[1], seq_len)
61 |             rp = torch.LongTensor(rp).to(device)
62 |             tm = torch.LongTensor(tm).to(device)
63 |             rp = rp.unsqueeze(0)
64 |             tm = tm.unsqueeze(0)
65 |
66 |             p_s, p_e, type_soft_constrain = model.predict_args(text_rep_type, rp, tm, mask, type_rep)
67 |
68 |             p_s = np.transpose(p_s)
69 |             p_e = np.transpose(p_e)
70 |
71 |             type_name = id_type[type_pred_one]
72 |             pred_event_one = {'type': type_name}
73 |             pred_trigger = {'span': [int(span[0]) - 1, int(span[1]) + 1 - 1], 'word': content[int(span[0]) - 1:int(span[1]) + 1 - 1]}  # remove [CLS] token
74 |             pred_event_one['trigger'] = pred_trigger
75 |             pred_args = {}
76 |
77 |             args_candidates = ty_args_id[type_pred_one]
78 |             for i in args_candidates:
79 |                 pred_args[id_args[i]] = []
80 |                 args_s = np.where(p_s[i] > threshold_3)[0]
81 |                 args_e = np.where(p_e[i] > threshold_4)[0]
82 |                 for j in args_s:
83 |                     es = args_e[args_e >= j]
84 |                     if len(es) > 0:
85 |                         e = es[0]
86 |                         if e - j + 1 <= args_len_dict[i]:
87 |                             pred_arg = {'span': [int(j) - 1, int(e) + 1 - 1], 'word': content[int(j) - 1:int(e) + 1 - 1]}  # remove [CLS] token
88 |                             pred_args[id_args[i]].append(pred_arg)
89 |
90 |             pred_event_one['args'] = pred_args
91 |             events_pred.append(pred_event_one)
92 |     result['events'] = events_pred
93 |     return result
94 |
--------------------------------------------------------------------------------
/utils/utils_io_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def read_json(fn):
5 |     with open(fn, 'r', encoding='utf-8') as f:
6 |         data = json.load(f)
7 |     return data
8 |
9 |
10 | def read_jsonl(fn):
11 |     with open(fn, 'r', encoding='utf-8') as f:
12 |         lines = f.readlines()
13 |     data = []
14 |     for line in lines:
15 |         data.append(json.loads(line))
16 |     return data
17 |
18 |
19 | def write_json(data, fn):
20 |     with open(fn, 'w', encoding='utf-8') as f:
21 |         json.dump(data, f, ensure_ascii=False)
22 |
23 |
24 | def write_jsonl(data, fn):
25 |     with open(fn, 'w', encoding='utf-8') as f:
26 |         for line in data:
27 |             line = json.dumps(line, ensure_ascii=False)
28 |             f.write(line + '\n')
--------------------------------------------------------------------------------
/utils/utils_io_model.py:
--------------------------------------------------------------------------------
1 | # -*- encoding:utf -*-
2 | import torch
3 |
4 |
5 | def save_model(model, model_path):
6 |     if hasattr(model, "module"):
7 |         torch.save(model.module.state_dict(), model_path)
8 |     else:
9 |         torch.save(model.state_dict(), model_path)
10 |
11 |
12 | def load_model(model, model_path, strict=False):
13 |     if hasattr(model, "module"):
14 |         model.module.load_state_dict(torch.load(model_path, map_location='cpu'), strict=strict)
15 |     else:
16 |         model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=strict)
17 |     return model
18 |
--------------------------------------------------------------------------------
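Both prediction scripts decode trigger and argument spans with the same heuristic. Every start position whose probability exceeds its threshold is paired with the nearest end position at or after it that also exceeds its threshold. `extract_all_items_without_oracle` additionally drops spans longer than `TRI_LEN` (triggers) or the per-role cap in `ARG_LEN_DICT` (arguments). Below is a minimal pure-Python sketch of that pairing. It assumes 1-D lists of per-token probabilities in place of the NumPy arrays used above, and `decode_spans` is a hypothetical helper that is not part of the repository.

```python
from typing import List, Optional, Sequence, Tuple


def decode_spans(
    p_start: Sequence[float],
    p_end: Sequence[float],
    th_start: float,
    th_end: float,
    max_len: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: pair each start above th_start with the nearest
    end >= start above th_end. Spans are inclusive (start, end) token indices;
    spans longer than max_len tokens are dropped when max_len is given."""
    # Ascending indices, like np.where(p > threshold)[0]
    starts = [i for i, p in enumerate(p_start) if p > th_start]
    ends = [i for i, p in enumerate(p_end) if p > th_end]
    spans = []
    for s in starts:
        # Nearest end at or after the start, like trigger_e[trigger_e >= i][0]
        candidates = [e for e in ends if e >= s]
        if not candidates:
            continue
        e = candidates[0]
        # Optional length cap, like the TRI_LEN / ARG_LEN_DICT checks
        if max_len is not None and e - s + 1 > max_len:
            continue
        spans.append((s, e))
    return spans


p_s = [0.1, 0.9, 0.2, 0.8, 0.1, 0.1]
p_e = [0.1, 0.1, 0.7, 0.1, 0.1, 0.9]
print(decode_spans(p_s, p_e, 0.5, 0.5))             # [(1, 2), (3, 5)]
print(decode_spans(p_s, p_e, 0.5, 0.5, max_len=2))  # [(1, 2)]
```

Because several starts can resolve to the same end, this rule can emit nested spans that share an end position. The length caps are the only guard against implausibly long spans.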
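In `extract_all_items_without_oracle`, token offsets become character offsets through `int(j) - 1` and `int(e) + 1 - 1`. The `- 1` skips the leading [CLS] token, and the `+ 1` turns the inclusive end index into Python's exclusive slice end. The sketch below shows that mapping. It assumes the tokenizer emits exactly one token per character, which the `content[...]` slicing in the original code already relies on. `token_span_to_char_span` is a hypothetical name.

```python
from typing import Dict, Tuple


def token_span_to_char_span(content: str, span: Tuple[int, int]) -> Dict[str, object]:
    """Hypothetical helper: map an inclusive token span (start, end) to a
    half-open character span over content. Assumes one token per character
    plus a leading [CLS] token at position 0, so token i covers character i - 1."""
    start, end = span
    char_start = int(start) - 1  # drop the [CLS] offset
    char_end = int(end) + 1 - 1  # inclusive -> exclusive end, then drop the [CLS] offset
    return {'span': [char_start, char_end], 'word': content[char_start:char_end]}


print(token_span_to_char_span("甲公司收购乙公司", (4, 5)))  # {'span': [3, 5], 'word': '收购'}
```

If a tokenizer ever splits one character into several tokens, or merges several characters into one, this arithmetic stops being valid. An explicit offset mapping would then be needed.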
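`predict_one` reduces each sentence to comparable tuples: `(type, sta, end)` for triggers and `(type, arg_sta, arg_end, arg_role)` for arguments. A common way to score such tuples is exact-match micro precision, recall and F1, with match counts summed over sentences. The sketch below is a hypothetical scorer under that assumption and is not necessarily what `utils/metric.py` computes. `count_one` and `micro_prf` are illustrative names.

```python
from typing import Hashable, Iterable, Sequence, Tuple

Counts = Tuple[int, int, int]  # (true positives, #predicted, #gold)


def count_one(pred: Iterable[Hashable], gold: Iterable[Hashable]) -> Counts:
    """Exact-match counts for one sentence; duplicate tuples are collapsed."""
    pred_set, gold_set = set(pred), set(gold)
    return len(pred_set & gold_set), len(pred_set), len(gold_set)


def micro_prf(per_sentence: Sequence[Counts]) -> Tuple[float, float, float]:
    """Corpus-level micro precision, recall and F1 from summed counts."""
    tp = sum(c[0] for c in per_sentence)
    n_pred = sum(c[1] for c in per_sentence)
    n_gold = sum(c[2] for c in per_sentence)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Two sentences of argument tuples: (type, arg_sta, arg_end, arg_role)
c1 = count_one([(0, 3, 4, 1), (0, 7, 8, 2)], [(0, 3, 4, 1)])
c2 = count_one([(1, 2, 3, 5)], [(1, 2, 3, 5), (1, 6, 6, 4)])
print(micro_prf([c1, c2]))  # precision, recall and F1 are all 2/3 here
```

Counts are taken per sentence before summing. Pooling all tuples into one set would wrongly merge identical offsets that come from different sentences.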