├── CRF.py ├── README.md ├── bert_pretrain └── readme.txt ├── const.py ├── models ├── __pycache__ │ ├── bert.cpython-36.pyc │ ├── bert_RCNN.cpython-36.pyc │ └── bert_RNN.cpython-36.pyc └── bert.py ├── nanhai_data ├── data │ ├── nanhai_data.json │ └── readme.txt └── test_result │ └── readme.txt ├── run.py ├── train_eval.py └── utils.py /CRF.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/9/14 9:51 3 | # @File : CRF.py 4 | # @Last Modify Time : 2018/9/14 9:51 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : CRF.py 9 | FUNCTION : None 10 | REFERENCE : https://github.com/jiesutd/NCRFpp/blob/master/model/crf.py 11 | """ 12 | import torch 13 | from torch.autograd.variable import Variable 14 | import torch.nn as nn 15 | 16 | 17 | def log_sum_exp(vec, m_size): 18 | """ 19 | Args: 20 | vec: size=(batch_size, vanishing_dim, hidden_dim) 21 | m_size: hidden_dim 22 | 23 | Returns: 24 | size=(batch_size, hidden_dim) 25 | """ 26 | _, idx = torch.max(vec, 1) # B * 1 * M 27 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M 28 | return max_score.view(-1, m_size) + torch.log(torch.sum( 29 | torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) 30 | 31 | 32 | class CRF(nn.Module): 33 | """ 34 | CRF 35 | """ 36 | def __init__(self, **kwargs): 37 | """ 38 | kwargs: 39 | target_size: int, target size 40 | device: str, device 41 | """ 42 | super(CRF, self).__init__() 43 | for k in kwargs: 44 | self.__setattr__(k, kwargs[k]) 45 | device = self.device 46 | 47 | # init transitions 48 | self.START_TAG, self.STOP_TAG = -2, -1 49 | init_transitions = torch.zeros(self.target_size + 2, self.target_size + 2, device=device) 50 | init_transitions[:, self.START_TAG] = -10000.0 51 | init_transitions[self.STOP_TAG, :] = -10000.0 52 | self.transitions = nn.Parameter(init_transitions) 53 | 54 | def _forward_alg(self, feats, mask): 55 | """ 56 | Do the forward algorithm to compute the partition function (batched). 
57 | 58 | Args: 59 | feats: size=(batch_size, seq_len, self.target_size+2) 60 | mask: size=(batch_size, seq_len) 61 | 62 | Returns: 63 | xxx 64 | """ 65 | batch_size = feats.size(0) 66 | seq_len = feats.size(1) 67 | tag_size = feats.size(2) 68 | mask = mask.transpose(1, 0).contiguous() 69 | ins_num = seq_len * batch_size 70 | """ be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) """ 71 | feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) 72 | """ need to consider start """ 73 | scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 74 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 75 | # build iter 76 | seq_iter = enumerate(scores) 77 | _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size 78 | """ only need start from start_tag """ 79 | partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size 80 | 81 | """ 82 | add start score (from start to all tag, duplicate to batch_size) 83 | partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) 84 | iter over last scores 85 | """ 86 | for idx, cur_values in seq_iter: 87 | """ 88 | previous to_target is current from_target 89 | partition: previous results log(exp(from_target)), #(batch_size * from_target) 90 | cur_values: bat_size * from_target * to_target 91 | """ 92 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 93 | cur_partition = log_sum_exp(cur_values, tag_size) 94 | 95 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) 96 | 97 | """ effective updated partition part, only keep the partition value of mask value = 1 """ 98 | masked_cur_partition = cur_partition.masked_select(mask_idx) 99 | """ let mask_idx broadcastable, to disable warning """ 100 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 101 | 102 | """ replace the partition where the maskvalue=1, other partition value keeps the same """ 103 | partition.masked_scatter_(mask_idx, masked_cur_partition) 104 | """ 105 | until the last state, add transition score for all partition (and do log_sum_exp) 106 | then select the value in STOP_TAG 107 | """ 108 | cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 109 | cur_partition = log_sum_exp(cur_values, tag_size) 110 | final_partition = cur_partition[:, self.STOP_TAG] 111 | return final_partition.sum(), scores 112 | 113 | def _viterbi_decode(self, feats, mask): 114 | """ 115 | input: 116 | feats: (batch, seq_len, self.tag_size+2) 117 | mask: (batch, seq_len) 118 | output: 119 | decode_idx: (batch, seq_len) decoded sequence 120 | path_score: (batch, 1) corresponding score for each sequence (to be implementated) 121 | """ 122 | # print(feats.size()) 123 | batch_size = feats.size(0) 124 | seq_len = feats.size(1) 125 | tag_size = feats.size(2) 126 | # assert(tag_size == self.tagset_size+2) 127 | """ calculate sentence length for each sentence """ 128 | length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long() 129 | """ mask to (seq_len, batch_size) """ 130 | mask = mask.transpose(1, 0).contiguous() 131 | ins_num = seq_len * batch_size 132 | """ be careful the view shape, it is .view(ins_num, 1, tag_size) 
but not .view(ins_num, tag_size, 1) """ 133 | feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 134 | """ need to consider start """ 135 | scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 136 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 137 | 138 | # build iter 139 | seq_iter = enumerate(scores) 140 | # record the position of best score 141 | back_points = list() 142 | partition_history = list() 143 | ## reverse mask (bug for mask = 1- mask, use this as alternative choice) 144 | # mask = 1 + (-1)*mask 145 | mask = (1 - mask.long()).byte() 146 | _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size 147 | """ only need start from start_tag """ 148 | partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size 149 | partition_history.append(partition) 150 | # iter over last scores 151 | for idx, cur_values in seq_iter: 152 | """ 153 | previous to_target is current from_target 154 | partition: previous results log(exp(from_target)), #(batch_size * from_target) 155 | cur_values: batch_size * from_target * to_target 156 | """ 157 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 158 | """ forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG """ 159 | partition, cur_bp = torch.max(cur_values, 1) 160 | partition_history.append(partition) 161 | """ 162 | cur_bp: (batch_size, tag_size) max source score position in current tag 163 | set padded label as 0, which will be filtered in post processing 164 | """ 165 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 166 | back_points.append(cur_bp) 167 | """ add score to final STOP_TAG """ 168 | partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous() ## (batch_size, seq_len. 
tag_size) 169 | """ get the last position for each setences, and select the last partitions using gather() """ 170 | last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1 171 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1) 172 | """ calculate the score from last partition to end state (and then select the STOP_TAG from it) """ 173 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) 174 | _, last_bp = torch.max(last_values, 1) 175 | # self.device = torch.device('cpu') 176 | pad_zero = torch.zeros(batch_size, tag_size, device=self.device, requires_grad=True).long() 177 | back_points.append(pad_zero) 178 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 179 | 180 | """ elect end ids in STOP_TAG """ 181 | pointer = last_bp[:, self.STOP_TAG] 182 | insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size) 183 | back_points = back_points.transpose(1,0).contiguous() 184 | """move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values """ 185 | back_points.scatter_(1, last_position, insert_last) 186 | back_points = back_points.transpose(1,0).contiguous() 187 | """ decode from the end, padded position ids are 0, which will be filtered if following evaluation """ 188 | # decode_idx = Variable(torch.LongTensor(seq_len, batch_size)) 189 | decode_idx = torch.empty(seq_len, batch_size, device=self.device, requires_grad=True).long() 190 | decode_idx[-1] = pointer.detach() 191 | for idx in range(len(back_points)-2, -1, -1): 192 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 193 | decode_idx[idx] = pointer.detach().view(batch_size) 194 | path_score = None 195 | decode_idx = decode_idx.transpose(1, 0) 196 | return path_score, decode_idx 197 | 198 | def forward(self, feats, mask): 199 | """ 200 | :param feats: 201 | :param mask: 202 | :return: 203 | """ 204 | path_score, best_path = self._viterbi_decode(feats, mask) 205 | return path_score, best_path 206 | 207 | def _score_sentence(self, scores, mask, tags): 208 | """ 209 | Args: 210 | scores: size=(seq_len, batch_size, tag_size, tag_size) 211 | mask: size=(batch_size, seq_len) 212 | tags: size=(batch_size, seq_len) 213 | 214 | Returns: 215 | score: 216 | """ 217 | # print(scores.size()) 218 | batch_size = scores.size(1) 219 | seq_len = scores.size(0) 220 | tag_size = scores.size(-1) 221 | tags = tags.view(batch_size, seq_len) 222 | """ convert tag value into a new format, recorded label bigram information to index """ 223 | # new_tags = Variable(torch.LongTensor(batch_size, seq_len)) 224 | new_tags = torch.empty(batch_size, seq_len, device=self.device, requires_grad=True).long() 225 | for idx in range(seq_len): 226 | if idx == 0: 227 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0] 228 | else: 229 | new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx] 230 | 231 | """ transition for label to STOP_TAG """ 232 | end_transition = self.transitions[:, self.STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) 233 | """ length for batch, last word position = length - 1 """ 234 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() 235 | """ index the label id of last word """ 236 | end_ids = torch.gather(tags, 1, length_mask-1) 237 | 238 | """ index the transition score for end_id to STOP_TAG """ 239 | 
end_energy = torch.gather(end_transition, 1, end_ids) 240 | 241 | """ convert tag as (seq_len, batch_size, 1) """ 242 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1) 243 | """ need to convert tag ids to index the flattened tag_size*tag_size score matrix """ 244 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) 245 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0)) 246 | 247 | """ 248 | add all score together 249 | gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum() 250 | """ 251 | gold_score = tg_energy.sum() + end_energy.sum() 252 | 253 | return gold_score 254 | 255 | def neg_log_likelihood_loss(self, feats, mask, tags): 256 | """ 257 | Args: 258 | feats: size=(batch_size, seq_len, tag_size) 259 | mask: size=(batch_size, seq_len) 260 | tags: size=(batch_size, seq_len) 261 | """ 262 | batch_size = feats.size(0) 263 | forward_score, scores = self._forward_alg(feats, mask) 264 | gold_score = self._score_sentence(scores, mask, tags) 265 | return forward_score - gold_score 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chinese-event-extraction-pytorch 2 | A simple PyTorch implementation of Chinese event extraction. The code is rough and still has room for improvement; hopefully it is useful. Questions and comments are welcome. 3 | 4 | # Environment 5 | python 3.6 6 | 7 | torch==1.0.1 8 | 9 | pytorch-pretrained-bert==0.6.2 10 | 11 | # Dataset 12 | The data was crawled from news websites and annotated by hand, so the full dataset cannot be released (it is small, and the annotation quality also needs improvement). The nanhai_data folder contains 10 JSON samples that illustrate the data format; you can substitute your own dataset as long as it follows the same format. Contributors: Lei Li; Panpan Jin; Kaiwen Wei; Jianwei Lv; Xiaoyu Li. 13 | 14 | # Results 15 | bert_RNN: 16 | Task|P|R|F1 17 | --|--|--|-- 18 | Trigger identification|0.689|0.752|0.719 19 | Trigger classification|0.591|0.644|0.616 20 | Argument identification|0.547|0.702|0.615 21 | Argument classification|0.446|0.572|0.510 22 | 23 | # Usage 24 | Run run.py 25 | 26 | Triggers and entities are identified and classified jointly via sequence labeling (BIO tags); the predicted trigger features and entity features are then concatenated for argument role classification. 27 | 28 | # Pretrained language model 29 | Put the BERT model in the bert_pretrain directory; it consists of three files: 30 | 31 | pytorch_model.bin 32 | 33 | bert_config.json 34 | 35 | vocab.txt 36 | 37 | Download links for the pretrained model: 38 | 39 | bert_Chinese: model https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz 40 | 41 | vocab https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt 42 | 43 | from [here](https://github.com/huggingface/pytorch-transformers) 44 | 45 | # Reference repositories 46 | 47 | Project 1: https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch 48 | 49 | Project 2: https://github.com/nlpcl-lab/bert-event-extraction 50 | -------------------------------------------------------------------------------- /bert_pretrain/readme.txt: -------------------------------------------------------------------------------- 1 | This folder stores the downloaded PyTorch version of the pretrained BERT model. -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | TRIGGERS = ['死亡', 4 | '闯入', 5 | '部署', 6 | '停驻', 7 | '演习', 8 | '配备', 9 | '冲突', 10 | '威胁', 11 | '侦查', 12 | '交涉', 13 | '合作', 14 | '谈判', 15 | '访问', 16 | '制裁', 17 | '购买', 18 | '贸易'] 19 | 20 | 21 | ENTITIES=[ 22 | '人物', 23 | '机构', 24 | '国家', 25 | '时间', 26 | '地点', 27 | '数量', 28 | '装备' 29 | ] 30 | 31 | ARGUMENTS=['Arg-Place', 'Arg-Subject', 'Arg-Time', 'Arg-Object', 'Arg-Number'] 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/__pycache__/bert.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jianwei-Lv/chinese-event-extraction-pytorch/33ae0c5353e1f8d963d9625d9eb5dcdcb4174639/models/__pycache__/bert.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/bert_RCNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jianwei-Lv/chinese-event-extraction-pytorch/33ae0c5353e1f8d963d9625d9eb5dcdcb4174639/models/__pycache__/bert_RCNN.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/bert_RNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jianwei-Lv/chinese-event-extraction-pytorch/33ae0c5353e1f8d963d9625d9eb5dcdcb4174639/models/__pycache__/bert_RNN.cpython-36.pyc -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import torch.nn as nn 4 | from pytorch_pretrained_bert import BertModel, BertTokenizer 5 | # from pytorch_pretrained import BertModel, BertTokenizer 6 | from utils import all_triggers_entities, trigger_entities2idx, idx2trigger_entities,find_triggers,all_arguments, argument2idx, idx2argument 7 | from CRF import CRF 8 | 9 | 10 | class Config(object): 11 | 12 | """配置参数""" 13 | def __init__(self, dataset): 14 | self.model_name = 'bert' 15 | self.train_path = dataset + '/data/nanhai_data.json' # 训练集 16 | self.device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu') # 设备 17 | 18 | self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 19 | self.num_epochs = 50 # epoch数 20 | self.batch_size =32 # mini-batch大小 21 | self.pad_size = 128 # 每句话处理成的长度(短填长切) 22 | self.learning_rate = 5e-5 # 学习率 23 | self.bert_path = './bert_pretrain' 24 | self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) 25 | self.hidden_size = 768 26 | self.rnn_hidden = 768 27 | self.num_layers = 1 28 | self.dropout = 0.1 29 | 30 | 31 | class Model(nn.Module): 32 | 33 | def __init__(self, config): 34 | super(Model, self).__init__() 35 | self.bert = BertModel.from_pretrained(config.bert_path) 36 | for param in self.bert.parameters(): 37 | param.requires_grad = True 38 | self.fc=nn.Sequential(nn.Linear(config.hidden_size, 256), 39 | nn.Dropout(0.5), 40 | nn.Linear(256, len(all_triggers_entities)+2)) 41 | 42 | self.fc_argument = nn.Sequential(nn.Linear(config.hidden_size*2, 256), 43 | nn.Dropout(0.5), 44 | nn.Linear(256, len(all_arguments))) 45 | self.device=config.device 46 | 47 | self.lstm = nn.LSTM(config.hidden_size, config.rnn_hidden//2, config.num_layers, 48 | bidirectional=True, batch_first=True) 49 | kwargs = dict({'target_size': len(all_triggers_entities), 'device': self.device}) 50 | self.tri_CRF1 = CRF(**kwargs) 51 | 52 | def forward(self, x,label,train=True,condidate_entity=None): 53 | context = x[0] # 输入的句子 54 | mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] 55 | arguments_2d=x[-1] 56 | 57 | triggers_y_2d = label 58 | encoder_out, pooled = self.bert(context, attention_mask=torch.LongTensor(mask).to(self.device), output_all_encoded_layers=False) 59 | encoder_out, _ = self.lstm(encoder_out) 60 | out=self.fc(encoder_out) 61 | 62 | trigger_loss = self.tri_CRF1.neg_log_likelihood_loss(feats=out, mask=torch.ByteTensor(mask).to(self.device), tags=triggers_y_2d) 63 | _, 
trigger_entities_hat_2d = self.tri_CRF1.forward(feats=out, mask=torch.ByteTensor(mask).to(self.device)) 64 | 65 | # trigger_entities_hat_2d = out.argmax(-1) 66 | batch_size = encoder_out.shape[0] 67 | argument_hidden,argument_keys = [],[] 68 | for i in range(batch_size): 69 | 70 | predicted_triggers, predicted_entities = find_triggers([idx2trigger_entities[trigger] for trigger in trigger_entities_hat_2d[i].tolist()]) 71 | 72 | golden_entity_tensors = {} 73 | for j in range(len(predicted_entities)): 74 | e_start, e_end, e_type_str = predicted_entities[j] 75 | golden_entity_tensors[predicted_entities[j]] = encoder_out[i, e_start:e_end, ].mean(dim=0) 76 | 77 | 78 | for predicted_trigger in predicted_triggers: 79 | t_start, t_end, t_type_str = predicted_trigger 80 | 81 | event_tensor = encoder_out[i, t_start:t_end, ].mean(dim=0) 82 | for j in range(len(predicted_entities)): 83 | e_start, e_end, e_type_str = predicted_entities[j] 84 | entity_tensor = golden_entity_tensors[predicted_entities[j]] 85 | 86 | argument_hidden.append(torch.cat([entity_tensor,event_tensor])) 87 | argument_keys.append((i, t_start, t_end, t_type_str, e_start, e_end, e_type_str)) 88 | 89 | if len(argument_keys) > 0: 90 | argument_hidden = torch.stack(argument_hidden) 91 | argument_hidden_logits = self.fc_argument(argument_hidden) 92 | 93 | argument_hidden_hat_1d = argument_hidden_logits.argmax(-1) 94 | 95 | arguments_y_1d = [] 96 | for i, t_start, t_end, t_type_str, e_start, e_end, e_type_str in argument_keys: 97 | a_label = argument2idx['NONE'] 98 | if (t_start, t_end, t_type_str) in arguments_2d[i]['events']: 99 | for (a_start, a_end, a_type_idx) in arguments_2d[i]['events'][(t_start, t_end, t_type_str)]: 100 | if e_start == a_start and e_end == a_end: 101 | a_label = a_type_idx 102 | break 103 | arguments_y_1d.append(a_label) 104 | 105 | arguments_y_1d = torch.LongTensor(arguments_y_1d).to(self.device) 106 | 107 | batch_size = len(arguments_2d) 108 | argument_hat_2d = [{'events': {}} for _ in range(batch_size)] 109 | for (i, st, ed, event_type_str, e_st, e_ed, entity_type), a_label in zip(argument_keys,argument_hidden_hat_1d.cpu().numpy()): 110 | if a_label == argument2idx['NONE']: 111 | continue 112 | if (st, ed, event_type_str) not in argument_hat_2d[i]['events']: 113 | argument_hat_2d[i]['events'][(st, ed, event_type_str)] = [] 114 | argument_hat_2d[i]['events'][(st, ed, event_type_str)].append((e_st, e_ed, a_label)) 115 | 116 | return trigger_loss,trigger_entities_hat_2d,triggers_y_2d,argument_hidden_logits,arguments_y_1d, argument_hidden_hat_1d, argument_hat_2d,argument_keys 117 | 118 | return trigger_loss,trigger_entities_hat_2d,triggers_y_2d,None,None,None,None,argument_keys 119 | -------------------------------------------------------------------------------- /nanhai_data/data/nanhai_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "sentence": "据悉,菲律宾订购了8架W-3直升机。", 4 | "entity": [ 5 | { 6 | "text": "直升机", 7 | "start": 14, 8 | "end": 17, 9 | "role": "Arg-Object" 10 | }, 11 | { 12 | "text": "菲律宾", 13 | "start": 3, 14 | "end": 6, 15 | "role": "Arg-Subject" 16 | }, 17 | { 18 | "text": "8架", 19 | "start": 9, 20 | "end": 11, 21 | "role": "Arg-Number" 22 | } 23 | ], 24 | "trigger": [ 25 | { 26 | "text": "订购", 27 | "start": 6, 28 | "end": 8, 29 | "event_type": "购买" 30 | } 31 | ], 32 | "arguments": [ 33 | { 34 | "text": "直升机", 35 | "start": 14, 36 | "end": 17, 37 | "entity_type": "装备", 38 | "role": "Arg-Object" 39 | }, 40 | { 41 | "text": "菲律宾", 42 | "start": 3, 
43 | "end": 6, 44 | "entity_type": "国家", 45 | "role": "Arg-Subject" 46 | }, 47 | { 48 | "text": "8架", 49 | "start": 9, 50 | "end": 11, 51 | "entity_type": "数量", 52 | "role": "Arg-Number" 53 | } 54 | ] 55 | }, 56 | { 57 | "sentence": "环球网国际军情中心2012年2月13日消息:近日波兰一家航空网站公布了菲律宾订购的W-3“索科尔”(Sokol)通用直升机的照片。", 58 | "entity": [ 59 | { 60 | "text": "W-3“索科尔”(Sokol)通用直升机", 61 | "start": 41, 62 | "end": 61, 63 | "role": "Arg-Object" 64 | }, 65 | { 66 | "text": "菲律宾", 67 | "start": 35, 68 | "end": 38, 69 | "role": "Arg-Subject" 70 | } 71 | ], 72 | "trigger": [ 73 | { 74 | "text": "订购", 75 | "start": 38, 76 | "end": 40, 77 | "event_type": "购买" 78 | } 79 | ], 80 | "arguments": [ 81 | { 82 | "text": "W-3“索科尔”(Sokol)通用直升机", 83 | "start": 41, 84 | "end": 61, 85 | "entity_type": "装备", 86 | "role": "Arg-Object" 87 | }, 88 | { 89 | "text": "菲律宾", 90 | "start": 35, 91 | "end": 38, 92 | "entity_type": "国家", 93 | "role": "Arg-Subject" 94 | } 95 | ] 96 | }, 97 | { 98 | "sentence": "新华网马尼拉5月10日专电:据《马尼拉今日标准报》报道,在被菲律宾占领的南沙群岛中的仁爱礁附近海域发现一艘中国海军驱逐舰和两艘民用船只之后,菲律宾海军9日向该海域派出3艘舰只。", 99 | "entity": [ 100 | { 101 | "text": "菲律宾海军", 102 | "start": 70, 103 | "end": 75, 104 | "role": "Arg-Subject" 105 | }, 106 | { 107 | "text": "仁爱礁附近海域", 108 | "start": 42, 109 | "end": 49, 110 | "role": "Arg-Place" 111 | }, 112 | { 113 | "text": "舰只", 114 | "start": 85, 115 | "end": 87, 116 | "role": "Arg-Object" 117 | }, 118 | { 119 | "text": "9日", 120 | "start": 75, 121 | "end": 77, 122 | "role": "Arg-Time" 123 | } 124 | ], 125 | "trigger": [ 126 | { 127 | "text": "派出", 128 | "start": 81, 129 | "end": 83, 130 | "event_type": "部署" 131 | } 132 | ], 133 | "arguments": [ 134 | { 135 | "text": "菲律宾海军", 136 | "start": 70, 137 | "end": 75, 138 | "entity_type": "机构", 139 | "role": "Arg-Subject" 140 | }, 141 | { 142 | "text": "仁爱礁附近海域", 143 | "start": 42, 144 | "end": 49, 145 | "entity_type": "地点", 146 | "role": "Arg-Place" 147 | }, 148 | { 149 | "text": "舰只", 150 | "start": 85, 151 | "end": 87, 152 | "entity_type": "装备", 153 | "role": "Arg-Object" 154 | }, 155 | { 156 | "text": "9日", 157 | "start": 75, 158 | "end": 77, 159 | "entity_type": "时间", 160 | "role": "Arg-Time" 161 | } 162 | ] 163 | }, 164 | { 165 | "sentence": "二战中日本侵占南海诸岛,战后中国政府收回,并将其归入广东省管辖。", 166 | "entity": [ 167 | { 168 | "text": "日本", 169 | "start": 3, 170 | "end": 5, 171 | "role": "Arg-Subject" 172 | }, 173 | { 174 | "text": "南海诸岛", 175 | "start": 7, 176 | "end": 11, 177 | "role": "Arg-Place" 178 | } 179 | ], 180 | "trigger": [ 181 | { 182 | "text": "侵占", 183 | "start": 5, 184 | "end": 7, 185 | "event_type": "停驻" 186 | } 187 | ], 188 | "arguments": [ 189 | { 190 | "text": "日本", 191 | "start": 3, 192 | "end": 5, 193 | "entity_type": "国家", 194 | "role": "Arg-Subject" 195 | }, 196 | { 197 | "text": "南海诸岛", 198 | "start": 7, 199 | "end": 11, 200 | "entity_type": "地点", 201 | "role": "Arg-Place" 202 | } 203 | ] 204 | }, 205 | { 206 | "sentence": "中国台湾网5月10日消息台当局“农委会渔业署副署长”蔡日耀9日证实,当天上午10点左右,台湾渔船“广大兴28号”在屏东县鹅銮鼻东南方约180海里处遭菲律宾军舰射击,一名船员死亡,船只也丧失动力。", 207 | "entity": [ 208 | { 209 | "text": "菲律宾军舰", 210 | "start": 74, 211 | "end": 79, 212 | "role": "Arg-Subject" 213 | }, 214 | { 215 | "text": "台湾渔船“广大兴28号”", 216 | "start": 44, 217 | "end": 56, 218 | "role": "Arg-Object" 219 | }, 220 | { 221 | "text": "屏东县鹅銮鼻东南方约180海里", 222 | "start": 57, 223 | "end": 72, 224 | "role": "Arg-Place" 225 | }, 226 | { 227 | "text": "9日", 228 | "start": 29, 229 | "end": 31, 230 | "role": "Arg-Time" 231 | } 232 | ], 233 | "trigger": [ 234 | { 235 | "text": "射击", 236 | "start": 79, 237 | "end": 81, 238 | "event_type": "冲突" 
239 | } 240 | ], 241 | "arguments": [ 242 | { 243 | "text": "菲律宾军舰", 244 | "start": 74, 245 | "end": 79, 246 | "entity_type": "装备", 247 | "role": "Arg-Subject" 248 | }, 249 | { 250 | "text": "台湾渔船“广大兴28号”", 251 | "start": 44, 252 | "end": 56, 253 | "entity_type": "装备", 254 | "role": "Arg-Object" 255 | }, 256 | { 257 | "text": "屏东县鹅銮鼻东南方约180海里", 258 | "start": 57, 259 | "end": 72, 260 | "entity_type": "地点", 261 | "role": "Arg-Place" 262 | }, 263 | { 264 | "text": "9日", 265 | "start": 29, 266 | "end": 31, 267 | "entity_type": "时间", 268 | "role": "Arg-Time" 269 | } 270 | ] 271 | }, 272 | { 273 | "sentence": "中国台湾网5月10日消息台当局“农委会渔业署副署长”蔡日耀9日证实,当天上午10点左右,台湾渔船“广大兴28号”在屏东县鹅銮鼻东南方约180海里处遭菲律宾军舰射击,一名船员死亡,船只也丧失动力。", 274 | "entity": [ 275 | { 276 | "text": "船员", 277 | "start": 84, 278 | "end": 86, 279 | "role": "Arg-Subject" 280 | } 281 | ], 282 | "trigger": [ 283 | { 284 | "text": "死亡", 285 | "start": 86, 286 | "end": 88, 287 | "event_type": "死亡" 288 | } 289 | ], 290 | "arguments": [ 291 | { 292 | "text": "船员", 293 | "start": 84, 294 | "end": 86, 295 | "entity_type": "人物", 296 | "role": "Arg-Subject" 297 | } 298 | ] 299 | }, 300 | { 301 | "sentence": "中新网5月10日电:据台湾《中国时报》报道,台湾渔民遭菲方船只杀害,台“海巡署”初步调查后指出,“广大兴28号”受攻击的位置,是在鹅銮鼻东南方166里处,属于台湾护渔“南界线”之外。", 302 | "entity": [ 303 | { 304 | "text": "台湾渔民", 305 | "start": 22, 306 | "end": 26, 307 | "role": "Arg-Subject" 308 | } 309 | ], 310 | "trigger": [ 311 | { 312 | "text": "杀害", 313 | "start": 31, 314 | "end": 33, 315 | "event_type": "死亡" 316 | } 317 | ], 318 | "arguments": [ 319 | { 320 | "text": "台湾渔民", 321 | "start": 22, 322 | "end": 26, 323 | "entity_type": "人物", 324 | "role": "Arg-Subject" 325 | } 326 | ] 327 | }, 328 | { 329 | "sentence": "环球网综合报道:屏东琉球籍渔船“广大兴28号”9日在台湾海域遭疑似菲律宾海军机枪扫射,船员洪石城背部因伤重不治。", 330 | "entity": [ 331 | { 332 | "text": "菲律宾海军", 333 | "start": 33, 334 | "end": 38, 335 | "role": "Arg-Subject" 336 | }, 337 | { 338 | "text": "屏东琉球籍渔船“广大兴28号", 339 | "start": 8, 340 | "end": 22, 341 | "role": "Arg-Object" 342 | }, 343 | { 344 | "text": "台湾海域", 345 | "start": 26, 346 | "end": 30, 347 | "role": "Arg-Place" 348 | }, 349 | { 350 | "text": "9日", 351 | "start": 23, 352 | "end": 25, 353 | "role": "Arg-Time" 354 | } 355 | ], 356 | "trigger": [ 357 | { 358 | "text": "扫射", 359 | "start": 40, 360 | "end": 42, 361 | "event_type": "冲突" 362 | } 363 | ], 364 | "arguments": [ 365 | { 366 | "text": "菲律宾海军", 367 | "start": 33, 368 | "end": 38, 369 | "entity_type": "机构", 370 | "role": "Arg-Subject" 371 | }, 372 | { 373 | "text": "屏东琉球籍渔船“广大兴28号", 374 | "start": 8, 375 | "end": 22, 376 | "entity_type": "装备", 377 | "role": "Arg-Object" 378 | }, 379 | { 380 | "text": "台湾海域", 381 | "start": 26, 382 | "end": 30, 383 | "entity_type": "地点", 384 | "role": "Arg-Place" 385 | }, 386 | { 387 | "text": "9日", 388 | "start": 23, 389 | "end": 25, 390 | "entity_type": "时间", 391 | "role": "Arg-Time" 392 | } 393 | ] 394 | }, 395 | { 396 | "sentence": "新华网台北电:台湾“农委会”渔业署副署长蔡日耀9日证实,当日上午10时左右,台湾渔船“广大兴28号”在屏东县鹅銮鼻东南方约180海里处遭菲律宾军舰射击,一名船员死亡,船只也丧失动力。", 397 | "entity": [ 398 | { 399 | "text": "菲律宾军舰", 400 | "start": 68, 401 | "end": 73, 402 | "role": "Arg-Subject" 403 | }, 404 | { 405 | "text": "台湾渔船“广大兴28号”", 406 | "start": 38, 407 | "end": 50, 408 | "role": "Arg-Object" 409 | }, 410 | { 411 | "text": "屏东县鹅銮鼻东南方约180海里处", 412 | "start": 51, 413 | "end": 67, 414 | "role": "Arg-Place" 415 | }, 416 | { 417 | "text": "当日上午", 418 | "start": 28, 419 | "end": 32, 420 | "role": "Arg-Time" 421 | } 422 | ], 423 | "trigger": [ 424 | { 425 | "text": "射击", 426 | "start": 73, 427 | "end": 75, 428 | 
"event_type": "冲突" 429 | } 430 | ], 431 | "arguments": [ 432 | { 433 | "text": "菲律宾军舰", 434 | "start": 68, 435 | "end": 73, 436 | "entity_type": "机构", 437 | "role": "Arg-Subject" 438 | }, 439 | { 440 | "text": "台湾渔船“广大兴28号”", 441 | "start": 38, 442 | "end": 50, 443 | "entity_type": "装备", 444 | "role": "Arg-Object" 445 | }, 446 | { 447 | "text": "屏东县鹅銮鼻东南方约180海里处", 448 | "start": 51, 449 | "end": 67, 450 | "entity_type": "地点", 451 | "role": "Arg-Place" 452 | }, 453 | { 454 | "text": "当日上午", 455 | "start": 28, 456 | "end": 32, 457 | "entity_type": "时间", 458 | "role": "Arg-Time" 459 | } 460 | ] 461 | }, 462 | { 463 | "sentence": "新华网台北电:台湾“农委会”渔业署副署长蔡日耀9日证实,当日上午10时左右,台湾渔船“广大兴28号”在屏东县鹅銮鼻东南方约180海里处遭菲律宾军舰射击,一名船员死亡,船只也丧失动力。", 464 | "entity": [ 465 | { 466 | "text": "船员", 467 | "start": 78, 468 | "end": 80, 469 | "role": "Arg-Subject" 470 | } 471 | ], 472 | "trigger": [ 473 | { 474 | "text": "死亡", 475 | "start": 80, 476 | "end": 82, 477 | "event_type": "死亡" 478 | } 479 | ], 480 | "arguments": [ 481 | { 482 | "text": "船员", 483 | "start": 78, 484 | "end": 80, 485 | "entity_type": "人物", 486 | "role": "Arg-Subject" 487 | } 488 | ] 489 | }, 490 | ] -------------------------------------------------------------------------------- /nanhai_data/data/readme.txt: -------------------------------------------------------------------------------- 1 | 数据是我们从新闻网站上下载的并进行人工标注的,不方便全部给出,所里这里仅给出10条样例数据,以供参考。 -------------------------------------------------------------------------------- /nanhai_data/test_result/readme.txt: -------------------------------------------------------------------------------- 1 | 此文件夹用于存放测试结果 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import time 3 | import torch 4 | import numpy as np 5 | from train_eval import train, init_network 6 | from importlib import import_module 7 | import argparse 8 | from utils import build_dataset, build_iterator, get_time_dif 9 | 10 | parser = argparse.ArgumentParser(description='Chinese Event Extraction') 11 | parser.add_argument('--model', type=str, default='bert', help='choose a model: Bert') 12 | args = parser.parse_args() 13 | 14 | 15 | if __name__ == '__main__': 16 | dataset = 'nanhai_data' # 数据集 17 | 18 | model_name = args.model # bert 19 | x = import_module('models.' 
+ model_name) 20 | 21 | config = x.Config(dataset) 22 | np.random.seed(11) 23 | torch.manual_seed(11) 24 | torch.cuda.manual_seed_all(11) 25 | torch.backends.cudnn.deterministic = True # 保证每次结果一样 26 | 27 | start_time = time.time() 28 | print("Loading data...") 29 | train_data = build_dataset(config) 30 | 31 | test_data = train_data[1000:] 32 | train_data = train_data[0:1000] 33 | train_iter = build_iterator(train_data, config) 34 | 35 | test_iter = build_iterator(test_data, config) 36 | time_dif = get_time_dif(start_time) 37 | print("Time usage:", time_dif) 38 | 39 | # train 40 | model = x.Model(config).to(config.device) 41 | train(config, model, train_iter, test_iter) 42 | -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import BCEWithLogitsLoss 7 | # from sklearn import metrics 8 | import time,os 9 | from utils import get_time_dif,calc_metric 10 | from pytorch_pretrained_bert.optimization import BertAdam 11 | from utils import all_triggers_entities, trigger_entities2idx, idx2trigger_entities,find_triggers,all_arguments, argument2idx, idx2argument 12 | 13 | 14 | # 权重初始化,默认xavier 15 | def init_network(model, method='xavier', exclude='embedding', seed=123): 16 | for name, w in model.named_parameters(): 17 | if exclude not in name: 18 | if len(w.size()) < 2: 19 | continue 20 | if 'weight' in name: 21 | if method == 'xavier': 22 | nn.init.xavier_normal_(w) 23 | elif method == 'kaiming': 24 | nn.init.kaiming_normal_(w) 25 | else: 26 | nn.init.normal_(w) 27 | elif 'bias' in name: 28 | nn.init.constant_(w, 0) 29 | else: 30 | pass 31 | 32 | 33 | def eval(model, iterator, fname): 34 | model.eval() 35 | 36 | words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], [] 37 | with torch.no_grad(): 38 | # for i, batch in enumerate(iterator): 39 | for i, (test, labels) in enumerate(iterator): 40 | trigger_logits, trigger_entities_hat_2d, triggers_y_2d, argument_hidden_logits, arguments_y_1d, argument_hidden_hat_1d, argument_hat_2d, argument_keys = model(test, labels) 41 | 42 | 43 | words_all.extend(test[3]) 44 | triggers_all.extend(test[4]) 45 | triggers_hat_all.extend(trigger_entities_hat_2d.cpu().numpy().tolist()) 46 | arguments_2d=test[-1] 47 | arguments_all.extend(arguments_2d) 48 | if len(argument_keys) > 0: 49 | arguments_hat_all.extend(argument_hat_2d) 50 | else: 51 | batch_size = len(arguments_2d) 52 | argument_hat_2d = [{'events': {}} for _ in range(batch_size)] 53 | arguments_hat_all.extend(argument_hat_2d) 54 | 55 | triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], [] 56 | with open('temp', 'w',encoding='utf-8') as fout: 57 | for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(zip(words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all)): 58 | triggers_hat = triggers_hat[:len(words)] 59 | triggers_hat = [idx2trigger_entities[hat] for hat in triggers_hat] 60 | 61 | # [(ith sentence, t_start, t_end, t_type_str)] 62 | triggers_true_,entities_true=find_triggers(triggers[:len(words)]) 63 | triggers_pred_, entities_pred = find_triggers(triggers_hat) 64 | triggers_true.extend([(i, *item) for item in triggers_true_]) 65 | triggers_pred.extend([(i, *item) for item in triggers_pred_]) 66 | 67 | # [(ith sentence, t_start, t_end, 
t_type_str, a_start, a_end, a_type_idx)] 68 | for trigger in arguments['events']: 69 | t_start, t_end, t_type_str = trigger 70 | for argument in arguments['events'][trigger]: 71 | a_start, a_end, a_type_idx = argument 72 | arguments_true.append(( t_type_str, a_start, a_end, a_type_idx)) 73 | 74 | for trigger in arguments_hat['events']: 75 | t_start, t_end, t_type_str = trigger 76 | if t_start>=len(words) or t_end>=len(words): 77 | continue 78 | for argument in arguments_hat['events'][trigger]: 79 | a_start, a_end, a_type_idx = argument 80 | if a_start >= len(words) or a_end >= len(words): 81 | continue 82 | arguments_pred.append((t_type_str, a_start, a_end, a_type_idx)) 83 | 84 | for w, t, t_h in zip(words, triggers, triggers_hat): 85 | fout.write('{}\t{}\t{}\n'.format(w, t, t_h)) 86 | fout.write('#arguments#{}\n'.format(arguments['events'])) 87 | fout.write('#arguments_hat#{}\n'.format(arguments_hat['events'])) 88 | fout.write("\n") 89 | 90 | # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred])) 91 | 92 | print('[trigger classification]') 93 | trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred) 94 | print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1)) 95 | 96 | print('[argument classification]') 97 | argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred) 98 | print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r, argument_f1)) 99 | print('[trigger identification]') 100 | triggers_true = [(item[0], item[1], item[2]) for item in triggers_true] 101 | triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred] 102 | trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred) 103 | print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_)) 104 | 105 | print('[argument identification]') 106 | arguments_true = [(item[0], item[1], item[2]) for item in arguments_true] 107 | arguments_pred = [(item[0], item[1], item[2]) for item in arguments_pred] 108 | argument_p_, argument_r_, argument_f1_ = calc_metric(arguments_true, arguments_pred) 109 | print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_, argument_f1_)) 110 | 111 | metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p, trigger_r, trigger_f1) 112 | metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p, argument_r, argument_f1) 113 | metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p_, trigger_r_, trigger_f1_) 114 | metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p_, argument_r_, argument_f1_) 115 | final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1) 116 | with open(final, 'w',encoding='utf-8') as fout: 117 | result = open("temp", "r",encoding='utf-8').read() 118 | fout.write("{}\n".format(result)) 119 | fout.write(metric) 120 | os.remove("temp") 121 | return metric,trigger_f1,argument_f1 122 | 123 | 124 | def train(config, model, train_iter, test_iter): 125 | criterion = nn.CrossEntropyLoss(ignore_index=0) 126 | model.train() 127 | # param_optimizer = list(model.named_parameters()) 128 | # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 129 | # optimizer_grouped_parameters = [ 130 | # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 131 | # {'params': [p for n, p in param_optimizer if any(nd in n for 
nd in no_decay)], 'weight_decay': 0.0}] 132 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 133 | # optimizer = BertAdam(optimizer_grouped_parameters, 134 | # lr=config.learning_rate, 135 | # warmup=0.05, 136 | # t_total=len(train_iter) * config.num_epochs) 137 | trigger_F1=0 138 | argument_F1=0 139 | 140 | for epoch in range(config.num_epochs): 141 | model.train() 142 | print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs)) 143 | for i, (trains, labels) in enumerate(train_iter): 144 | model.zero_grad() 145 | trigger_loss, trigger_entities_hat_2d, triggers_y_2d,argument_hidden_logits, arguments_y_1d, argument_hidden_hat_1d, argument_hat_2d,argument_keys= model(trains,labels) 146 | 147 | 148 | # trigger_logits = trigger_logits.view(-1, trigger_logits.shape[-1]) 149 | # trigger_loss = criterion(trigger_logits, triggers_y_2d.view(-1)) 150 | if len(argument_keys)>0 : 151 | argument_loss = criterion(argument_hidden_logits, arguments_y_1d) 152 | 153 | loss = trigger_loss + argument_loss 154 | else: 155 | loss=trigger_loss 156 | loss.backward() 157 | 158 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 159 | 160 | optimizer.step() 161 | # if i % 100 == 0: # monitoring 162 | print("step: {}, loss: {}".format(i, loss.item())) 163 | 164 | 165 | print(f"=========eval test at epoch={epoch}=========") 166 | metric_test, trigger_f1, argument_f1 = eval(model, test_iter, 'nanhai_data/test_result/'+str(epoch) + '_test') 167 | 168 | 169 | if trigger_F1 < trigger_f1: 170 | trigger_F1 = trigger_f1 171 | torch.save(model, "latest_model_2.pt") 172 | if argument_F1 < argument_f1: 173 | argument_F1 = argument_f1 174 | torch.save(model, "argument_latest_model_2.pt") 175 | print('best trigger F1:') 176 | print(trigger_F1) 177 | print('best argument F1:') 178 | print(argument_F1) 179 | 180 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from datetime import timedelta 6 | import json 7 | from const import TRIGGERS,ENTITIES,ARGUMENTS 8 | PAD, CLS ,NONE= '[PAD]', '[CLS]' ,'NONE' # padding token; [CLS] is BERT's aggregate sentence token 9 | 10 | def build_vocab(labels_trigger,labels_entities, BIO_tagging=True): 11 | all_labels = [NONE] 12 | for label in labels_trigger: 13 | if BIO_tagging: 14 | all_labels.append('B-T-{}'.format(label)) 15 | all_labels.append('I-T-{}'.format(label)) 16 | else: 17 | all_labels.append(label) 18 | for label in labels_entities: 19 | if BIO_tagging: 20 | all_labels.append('B-E-{}'.format(label)) 21 | all_labels.append('I-E-{}'.format(label)) 22 | else: 23 | all_labels.append(label) 24 | label2idx = {tag: idx for idx, tag in enumerate(all_labels)} 25 | idx2label = {idx: tag for idx, tag in enumerate(all_labels)} 26 | 27 | return all_labels, label2idx, idx2label 28 | 29 | all_triggers_entities, trigger_entities2idx, idx2trigger_entities = build_vocab(TRIGGERS,ENTITIES) 30 | # all_entities, entity2idx, idx2entity = build_vocab(ENTITIES) 31 | # all_postags, postag2idx, idx2postag = build_vocab(POSTAGS, BIO_tagging=False) 32 | all_arguments, argument2idx, idx2argument = build_vocab(ARGUMENTS,[], BIO_tagging=False) 33 | 34 | def build_dataset(config): 35 | 36 | def load_dataset(path, pad_size=128): 37 | cut_off=pad_size 38 | contents = [] 39 | 40 | with open(path, 'r', encoding='UTF-8') as f: 41 | data = json.load(f) 42 | for item in data: 43 | 44 | words=[ item['sentence'][i] for i in
range(len(item['sentence']))] 45 | token=[] 46 | for w in words: 47 | t = config.tokenizer.tokenize(w) 48 | token.extend(t) 49 | 50 | 51 | token = [CLS] + token 52 | seq_len = len(token) 53 | mask = [] 54 | token_ids = config.tokenizer.convert_tokens_to_ids(token) 55 | if pad_size: 56 | if len(token) < pad_size: 57 | mask = [1] * len(token_ids) + [0] * (pad_size - len(token)) 58 | token_ids += ([0] * (pad_size - len(token))) 59 | else: 60 | mask = [1] * pad_size 61 | token_ids = token_ids[:pad_size] 62 | seq_len = pad_size 63 | triggers_entities=[NONE for _ in range(len(token))][:cut_off] 64 | arguments = { 65 | 'candidates': [ 66 | # ex. (5, 6, "entity_type_str"), ... 67 | ], 68 | 'events': { 69 | # ex. (1, 3, "trigger_type_str"): [(5, 6, "argument_role_idx"), ...] 70 | }, 71 | } 72 | try: 73 | for entity_mention in item['arguments']: 74 | start = entity_mention['start'] 75 | if start >= cut_off: 76 | continue 77 | end = min(entity_mention["end"], cut_off) 78 | arguments['candidates'].append((start+1, end+1, entity_mention['entity_type'])) 79 | 80 | for i in range(start, end): 81 | entity_type = entity_mention['entity_type'] 82 | if i == start: 83 | entity_type = 'B-E-{}'.format(entity_type) 84 | else: 85 | entity_type = 'I-E-{}'.format(entity_type) 86 | 87 | triggers_entities[i+1] = entity_type 88 | 89 | 90 | for event_mention in item['trigger']: 91 | if event_mention['start'] >= cut_off: 92 | continue 93 | for i in range(event_mention['start'],min(event_mention['end'], cut_off)): 94 | trigger_type = event_mention['event_type'] 95 | if i == event_mention['start']: 96 | 97 | triggers_entities[i+1]= 'B-T-{}'.format(trigger_type) 98 | else: 99 | 100 | triggers_entities[i+1] = 'I-T-{}'.format(trigger_type) 101 | 102 | event_key = (event_mention['start']+1, min(event_mention['end'], cut_off)+1,event_mention['event_type']) 103 | arguments['events'][event_key] = [] 104 | for argument in item['arguments']: 105 | if argument['start'] >= cut_off: 106 | continue 107 | role = argument['role'] 108 | 109 | arguments['events'][event_key].append( 110 | (argument['start']+1, min(argument['end'], cut_off)+1, argument2idx[role])) 111 | 112 | triggers_entities_ids=[trigger_entities2idx[i] for i in triggers_entities] 113 | if pad_size: 114 | if len(triggers_entities_ids) < pad_size: 115 | 116 | triggers_entities_ids += ([0] * (pad_size - len(triggers_entities_ids))) 117 | else: 118 | 119 | triggers_entities_ids = triggers_entities_ids[:pad_size] 120 | 121 | 122 | contents.append((token_ids,triggers_entities_ids,seq_len,mask,token,triggers_entities,arguments)) 123 | except: 124 | 125 | continue 126 | 127 | return contents 128 | train = load_dataset(config.train_path, config.pad_size) 129 | 130 | return train 131 | 132 | 133 | class DatasetIterater(object): 134 | def __init__(self, batches, batch_size, device): 135 | self.batch_size = batch_size 136 | self.batches = batches 137 | self.n_batches = len(batches) // batch_size 138 | self.residue = False # 记录batch数量是否为整数 139 | if len(batches) % self.n_batches != 0: 140 | self.residue = True 141 | self.index = 0 142 | self.device = device 143 | 144 | def _to_tensor(self, datas): 145 | x = torch.LongTensor([_[0] for _ in datas]).to(self.device) 146 | y = torch.LongTensor([_[1] for _ in datas]).to(self.device) 147 | 148 | # pad前的长度(超过pad_size的设为pad_size) 149 | seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device) 150 | # mask = torch.LongTensor([_[3] for _ in datas]).to(self.device) 151 | mask = [_[3] for _ in datas] 152 | words=[_[4] for _ in datas] 153 | 
trigger_entities = [_[5] for _ in datas] 154 | arguments=[_[-1] for _ in datas] 155 | 156 | 157 | return (x, seq_len, mask,words,trigger_entities,arguments), y 158 | 159 | def __next__(self): 160 | if self.residue and self.index == self.n_batches: 161 | batches = self.batches[self.index * self.batch_size: len(self.batches)] 162 | self.index += 1 163 | batches = self._to_tensor(batches) 164 | return batches 165 | 166 | elif self.index >= self.n_batches: 167 | self.index = 0 168 | raise StopIteration 169 | else: 170 | batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] 171 | self.index += 1 172 | batches = self._to_tensor(batches) 173 | return batches 174 | 175 | def __iter__(self): 176 | return self 177 | 178 | def __len__(self): 179 | if self.residue: 180 | return self.n_batches + 1 181 | else: 182 | return self.n_batches 183 | 184 | 185 | def build_iterator(dataset, config): 186 | iter = DatasetIterater(dataset, config.batch_size, config.device) 187 | return iter 188 | 189 | 190 | def get_time_dif(start_time): 191 | """获取已使用时间""" 192 | end_time = time.time() 193 | time_dif = end_time - start_time 194 | return timedelta(seconds=int(round(time_dif))) 195 | 196 | 197 | def find_triggers(labels): 198 | """ 199 | :param labels: ['B-Conflict:Attack', 'I-Conflict:Attack', 'O', 'B-Life:Marry'] 200 | :return: [(0, 2, 'Conflict:Attack'), (3, 4, 'Life:Marry')] 201 | """ 202 | result_trigger = [] 203 | result_entities=[] 204 | labels = [label.split('-') for label in labels] 205 | 206 | for i in range(len(labels)): 207 | if labels[i][0] == 'B': 208 | if labels[i][1]=='T': 209 | result_trigger.append([i, i + 1, labels[i][2]]) 210 | elif labels[i][1]=='E': 211 | result_entities.append([i, i + 1, labels[i][2]]) 212 | 213 | for item in result_trigger: 214 | j = item[1] 215 | while j < len(labels): 216 | if labels[j][0] == 'I' and labels[j][1]=='T': 217 | j = j + 1 218 | item[1] = j 219 | else: 220 | break 221 | for item in result_entities: 222 | j = item[1] 223 | while j < len(labels): 224 | if labels[j][0] == 'I' and labels[j][1]=='E': 225 | j = j + 1 226 | item[1] = j 227 | else: 228 | break 229 | 230 | return [tuple(item) for item in result_trigger],[tuple(item) for item in result_entities] 231 | 232 | 233 | def calc_metric(y_true, y_pred): 234 | """ 235 | :param y_true: [(tuple), ...] 236 | :param y_pred: [(tuple), ...] 237 | :return: 238 | """ 239 | num_proposed = len(y_pred) 240 | num_gold = len(y_true) 241 | 242 | y_true_set = set(y_true) 243 | num_correct = 0 244 | for item in y_pred: 245 | if item in y_true_set: 246 | num_correct += 1 247 | 248 | print('proposed: {}\tcorrect: {}\tgold: {}'.format(num_proposed, num_correct, num_gold)) 249 | 250 | if num_proposed != 0: 251 | precision = num_correct / num_proposed 252 | else: 253 | precision = 1.0 254 | 255 | if num_gold != 0: 256 | recall = num_correct / num_gold 257 | else: 258 | recall = 1.0 259 | 260 | if precision + recall != 0: 261 | f1 = 2 * precision * recall / (precision + recall) 262 | else: 263 | f1 = 0 264 | 265 | return precision, recall, f1 266 | --------------------------------------------------------------------------------
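
A minimal usage sketch (editor's illustration, not a file in the repository) of the two pure helpers defined in utils.py above. It assumes the repository root is on the import path and that utils.py's own dependencies (torch, tqdm, const.py) are available; the label strings follow the B-T-*/I-T-* (trigger) and B-E-*/I-E-* (entity) scheme produced by build_vocab, and the Chinese type names come from const.py.

from utils import find_triggers, calc_metric

# BIO labels for a 6-token sentence: one "冲突" (conflict) trigger span and one "国家" (country) entity span.
labels = ['NONE', 'B-T-冲突', 'I-T-冲突', 'NONE', 'B-E-国家', 'I-E-国家']
triggers, entities = find_triggers(labels)
print(triggers)  # [(1, 3, '冲突')]  -- half-open [start, end) spans plus the bare type string
print(entities)  # [(4, 6, '国家')]

# calc_metric compares two lists of tuples; eval() in train_eval.py drops the final type
# field from each tuple to turn classification scores into identification scores.
gold = [(0, 1, 3, '冲突'), (0, 4, 6, '部署')]
pred = [(0, 1, 3, '冲突')]
print(calc_metric(gold, pred))  # prints the proposed/correct/gold counts, returns (1.0, 0.5, 0.666...)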