├── model_save
│   └── model.pt
├── ccl-cfn
│   ├── result
│   │   └── resultdir
│   └── datadir
├── run.sh
├── 汉语框架语义解析评测报名表.docx
├── requirements.txt
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── gcn_model.py
│   ├── L2.py
│   ├── layers.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── optimization.py
│   ├── file_utils.py
│   └── tokenization.py
├── evaluate_F1.py
├── evaluate_F1_rc.py
├── model.py
├── predict.py
├── run_fi.py
├── arguments.py
├── dataset.py
├── README.md
├── run_rc.py
├── run_ai.py
└── until.py

/model_save/model.pt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ccl-cfn/result/resultdir:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ccl-cfn/datadir:
--------------------------------------------------------------------------------
1 | After registering for the shared task, download the dataset from Alibaba Tianchi (阿里天池) and place it here.
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python run_fi.py
3 | python run_ai.py
4 | python run_rc.py
5 | 
6 | 
7 | 
8 | 
9 | 
--------------------------------------------------------------------------------
/汉语框架语义解析评测报名表.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SXUNLP/Chinese-Frame-Semantic-Parsing/HEAD/汉语框架语义解析评测报名表.docx
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0.dev0
2 | allennlp==2.10.0
3 | apex==0.9.10dev
4 | boto3==1.24.89
5 | botocore==1.27.89
6 | evaluate==0.4.0
7 | numpy==1.23.3
8 | requests==2.28.1
9 | tensorflow==2.12.0
10 | torch==1.11.0
11 | tqdm==4.64.1
12 | transformers==4.20.1
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.4.0"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
4 |                        BertForMaskedLM, BertForNextSentencePrediction,
5 |                        BertForSequenceClassification, BertForMultipleChoice,
6 |                        BertForTokenClassification, BertForQuestionAnswering)
7 | from .optimization import BertAdam
8 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
9 | 
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 |     import sys
4 |     try:
5 |         from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
6 |     except ModuleNotFoundError:
7 |         print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
8 |               "In that case, it requires TensorFlow to be installed. 
Please see " 9 | "https://www.tensorflow.org/install/ for installation instructions.") 10 | raise 11 | 12 | if len(sys.argv) != 5: 13 | # pylint: disable=line-too-long 14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 15 | else: 16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 17 | TF_CONFIG = sys.argv.pop() 18 | TF_CHECKPOINT = sys.argv.pop() 19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /evaluate_F1.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | dev_file = './ccl-cfn/cfn-dev.json' 4 | predictions_file = './ccl-cfn/result/task2_dev.json' 5 | 6 | with open(dev_file, 'r') as f: 7 | dev_data = json.load(f) 8 | 9 | with open(predictions_file, 'r') as f: 10 | predictions_data = json.load(f) 11 | 12 | # assert len(dev_data) == len(predictions_data) 13 | 14 | j = 0 15 | total_TP = 0 16 | total_FP = 0 17 | total_FN = 0 18 | preds = {} 19 | labels = {} 20 | for i, span in enumerate(predictions_data): 21 | preds.setdefault(span[0], set()) 22 | preds[span[0]].add((span[1], span[2])) 23 | for i, data in enumerate(dev_data): 24 | labels.setdefault(data['task_id'], set()) 25 | for span in data['cfn_spans']: 26 | labels[data['task_id']].add((span[0], span[1])) 27 | 28 | for taskid in labels: 29 | TP = 0 30 | FP = 0 31 | FN = 0 32 | if taskid not in preds: 33 | FN += len(labels[taskid]) 34 | total_FN += FN 35 | continue 36 | 37 | for pred in preds[taskid]: 38 | if pred in labels[taskid]: 39 | TP += 1 40 | else: 41 | FP += 1 42 | for label in labels[taskid]: 43 | if label not in preds[taskid]: 44 | FN += 1 45 | total_TP += TP 46 | total_FP += FP 47 | total_FN += FN 48 | 49 | 50 | 51 | print(total_TP, total_FP, total_FN) 52 | precision = total_TP / (total_TP + total_FP) 53 | recall = total_TP / (total_TP + total_FN) 54 | F1 = 2 * precision * recall / (precision + recall) 55 | print(f'precision: {precision}, recall: {recall}, F1: {F1}') -------------------------------------------------------------------------------- /pytorch_pretrained_bert/gcn_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | from torch.nn.modules.module import Module 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from .layers import BiGCN_layer 10 | 11 | 12 | 13 | class BiGCN(nn.Module): 14 | def __init__(self,num_feature,num_hidden,num_class,p,lambda_1,lambda_2,dropout,bias=True,beta=True,A2='cos_A2',n_iter=2,Type='mean'): 15 | super(BiGCN,self).__init__() 16 | 17 | self.p = p 18 | self.lambda_1 = lambda_1 19 | self.lambda_2 = lambda_2 20 | self.Type = Type 21 | self.num_feature = num_feature 22 | self.num_hidden = num_hidden 23 | self.gc1 = BiGCN_layer(num_feature,num_hidden,p,lambda_1,lambda_2,bias,beta,A2,n_iter,Type) 24 | self.gc2 = BiGCN_layer(num_hidden,num_class,p,lambda_1,lambda_2,bias,beta,A2,n_iter,Type) 25 | self.dropout = dropout 26 | self.reg_params = list(self.gc1.parameters()) 27 | self.non_reg_params = list(self.gc2.parameters()) 28 | 29 | def forward(self, x, L): 30 | A = [] 31 | Z1 = torch.zeros(x.shape[0], self.num_feature).cuda() 32 | X = F.relu(self.gc1(x, x, Z1, L)[0]) 33 | A.append(self.gc1(x, x, Z1, 
L)[1])
34 |         X = F.dropout(X, self.dropout, training=self.training)
35 |         Z2 = torch.zeros(x.shape[0], self.num_hidden).cuda()
36 |         X, _ = self.gc2(X, X, Z2, L)
37 |         return X
38 | 
39 | 
--------------------------------------------------------------------------------
/evaluate_F1_rc.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | dev_file = './ccl-cfn/cfn-dev.json'
4 | predictions_file = './ccl-cfn/result/task3_dev.json'
5 | 
6 | with open(dev_file, 'r') as f:
7 |     dev_data = json.load(f)
8 | 
9 | with open(predictions_file, 'r') as f:
10 |     predictions_data = json.load(f)
11 | 
12 | # assert len(dev_data) == len(predictions_data)
13 | 
14 | j = 0
15 | total_TP = 0
16 | total_FP = 0
17 | total_FN = 0
18 | preds = {}
19 | labels = {}
20 | for i, span in enumerate(predictions_data):
21 |     preds.setdefault(span[0], set())
22 |     preds[span[0]].add((span[1], span[2], span[3]))
23 | for i, data in enumerate(dev_data):
24 |     labels.setdefault(data['task_id'], set())
25 |     for span in data['cfn_spans']:
26 |         labels[data['task_id']].add((span[0], span[1], span[2]))
27 | 
28 | for taskid in labels:
29 |     TP = 0
30 |     FP = 0
31 |     FN = 0
32 |     if taskid not in preds:
33 |         FN += len(labels[taskid])
34 |         total_FN += FN
35 |         continue
36 | 
37 |     for pred in preds[taskid]:
38 |         if pred in labels[taskid]:
39 |             TP += 1
40 |         else:
41 |             FP += 1
42 |     for label in labels[taskid]:
43 |         if label not in preds[taskid]:
44 |             FN += 1
45 |     total_TP += TP
46 |     total_FP += FP
47 |     total_FN += FN
48 |     # break
49 | 
50 | 
51 | 
52 | print(total_TP, total_FP, total_FN)
53 | precision = total_TP / (total_TP + total_FP)
54 | recall = total_TP / (total_TP + total_FN)
55 | F1 = 2 * precision * recall / (precision + recall)
56 | print(f'precision: {precision}, recall: {recall}, F1: {F1}')
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/L2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | 
7 | 
8 | 
9 | 
10 | def cos_A2(input,beta):
11 |     feature_g = input.t()
12 |     norm2 = torch.norm(feature_g,p=2,dim=1).view(-1, 1)
13 |     cos = beta*torch.div(torch.mm(feature_g, feature_g.t()), torch.mm(norm2, norm2.t()) + 1e-7)
14 |     I = torch.eye(cos.shape[0]).cuda()
15 |     cos = F.softmax(cos,dim=1)
16 |     cos = cos - (torch.triu(cos)-torch.triu(cos,diagonal=1))
17 |     mean = torch.mean(cos)
18 |     cos[cos<mean]=0
19 |     cos[cos>=mean]=1
20 |     cos = cos + I
21 |     rowsum = torch.sum(cos,dim=1)**(-0.5)
22 |     D2 = torch.diag(rowsum)
23 |     A2 = torch.mm(D2,cos)
24 |     A2 = torch.mm(A2,D2)
25 |     return A2
26 | 
27 | 
28 | def cosM_A2(input,beta,A):
29 |     feature_g = input.t()
30 |     norm2 = torch.norm(feature_g,p=2,dim=1).view(-1, 1)
31 |     cos = beta*torch.div(torch.mm(feature_g, feature_g.t()), torch.mm(norm2, norm2.t()) + 1e-7)
32 |     I = torch.eye(cos.shape[0]).cuda()
33 |     e = beta*torch.ones(cos.shape).cuda()
34 |     A = torch.sigmoid(A)
35 |     A = torch.triu(A)+torch.triu(A,diagonal=1).t()
36 |     cos = torch.mm(0.5*cos+0.5*e,A)
37 | 
38 |     cos = cos - (torch.triu(cos)-torch.triu(cos,diagonal=1))
39 |     mean = torch.mean(cos)
40 |     cos[cos<mean]=0
41 |     cos[cos>=mean]=1
42 |     cos = cos + I
43 |     rowsum = torch.sum(cos,dim=1)**(-0.5)
44 |     D2 = torch.diag(rowsum)
45 |     A2 = torch.mm(D2,cos)
46 |     A2 = torch.mm(A2,D2)
47 |     return A2
48 | 
49 | def learn_A2 (A):
50 |     A = torch.sigmoid(A)
51 |     A = torch.triu(A,diagonal=1)
52 |     I = torch.eye(A.shape[0]).cuda()
53 |     A2 = A + A.t() +I
54 | 
55 | 
rowsum = torch.sum(A2,dim=1)**(-0.5) 56 | D2 = torch.diag(rowsum) 57 | A2 = torch.mm(D2,A2) 58 | A2 = torch.mm(A2,D2) 59 | return A2 60 | 61 | 62 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.modules.module import Module 5 | from pytorch_pretrained_bert.L2 import learn_A2, cos_A2, cosM_A2 6 | 7 | 8 | class ADMM_Y(Module): 9 | def __init__(self,L,A2,p,lambda_1,lambda_2,n_iter=2,Type='mean'): 10 | super(ADMM_Y,self).__init__() 11 | self.p = p 12 | self.L = L 13 | self.A2 = A2 14 | self.lambda_1 = lambda_1 15 | self.lambda_2 = lambda_2 16 | self.n_iter = n_iter 17 | self.Type = Type 18 | 19 | self.I_2 = torch.eye(A2.shape[0]).cuda() 20 | self.L_2 = self.I_2-self.lambda_2*self.A2 21 | def ADMM_Y1(self,F,Y2,Z): 22 | T = F + self.p*Y2 + Z 23 | Y1 = 1./(1+self.p)*torch.mm(self.L,T) 24 | return Y1 25 | 26 | def ADMM_Y2(self,F,Y1,Z): 27 | T = F + self.p*Y1 - Z 28 | Y2 = 1./(1+self.p)*torch.mm(T,self.L_2) 29 | return Y2 30 | 31 | def ADMM_Z(self,Y1,Y2,Z): 32 | Z = Z + self.p*(Y2-Y1) 33 | return Z 34 | 35 | def forward(self,F,Y2,Z): 36 | for i in range(self.n_iter): 37 | Y1 = self.ADMM_Y1(F,Y2,Z) 38 | Y2 = self.ADMM_Y2(F,Y1,Z) 39 | Z = self.ADMM_Z(Y1,Y2,Z) 40 | if self.Type == 'y2': 41 | Y = Y2 42 | elif self.Type == 'y1': 43 | Y = Y1 44 | elif self.Type == 'mean': 45 | Y = 1/2*(Y1+Y2) 46 | return Y 47 | #''' 48 | class BiGCN_layer(Module): 49 | def __init__(self,ind,outd,p,lambda_1,lambda_2,bias=True,beta=True,A2='cos_A2',n_iter=2,Type='mean'): 50 | super(BiGCN_layer,self).__init__() 51 | self.ind = ind #input dimension 52 | self.outd = outd #output dimension 53 | self.p = p 54 | self.A2 = A2 55 | self.n_iter = n_iter 56 | self.lambda_1 = lambda_1 57 | self.lambda_2 = lambda_2 58 | self.Type = Type 59 | 60 | if beta: 61 | self.beta = Parameter(torch.Tensor(1).uniform_(0, 1)) 62 | else: 63 | self.beta = 1 64 | self.weight1 = Parameter(torch.FloatTensor(ind,outd)) 65 | self.A = Parameter(torch.FloatTensor(ind,ind)) 66 | if bias: 67 | self.bias1 = Parameter(torch.FloatTensor(outd)) 68 | else: 69 | self.register_parameter('bias1',None) 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | stdv = 1. / math.sqrt(self.weight1.size(1)) 74 | self.weight1.data.uniform_(-stdv,stdv) 75 | self.A.data = torch.empty(self.A.shape).random_(2) 76 | if self.bias1 is not None: 77 | self.bias1.data.uniform_(-stdv,stdv) 78 | 79 | 80 | def forward(self,Y1,Y2,Z,L): 81 | if self.A2 == 'learn_A2': 82 | A2 = learn_A2(self.A) 83 | elif self.A2 == 'cos_A2': 84 | A2 = cos_A2(Y2,self.beta) 85 | elif self.A2 == 'cosM_A2': 86 | A2 = cosM_A2(Y2,self.beta,self.A) 87 | else: 88 | raise Exception("No such A2:",self.A2) 89 | admm = ADMM_Y(L,A2,self.p,self.lambda_1,self.lambda_2,self.n_iter,self.Type) 90 | Y = admm(Y1,Y2,Z) 91 | Y = torch.mm(Y,self.weight1) 92 | if self.bias1 is not None: 93 | return Y + self.bias1,A2 94 | else: 95 | return Y,A2 96 | 97 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from .modeling import BertConfig, BertForPreTraining 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | config_path = os.path.abspath(bert_config_file) 32 | tf_path = os.path.abspath(tf_checkpoint_path) 33 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) 34 | # Load weights from TF model 35 | init_vars = tf.train.list_variables(tf_path) 36 | names = [] 37 | arrays = [] 38 | for name, shape in init_vars: 39 | print("Loading TF weight {} with shape {}".format(name, shape)) 40 | array = tf.train.load_variable(tf_path, name) 41 | names.append(name) 42 | arrays.append(array) 43 | 44 | # Initialise PyTorch model 45 | config = BertConfig.from_json_file(bert_config_file) 46 | print("Building PyTorch model from configuration: {}".format(str(config))) 47 | model = BertForPreTraining(config) 48 | 49 | for name, array in zip(names, arrays): 50 | name = name.split('/') 51 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 52 | # which are not required for using pretrained model 53 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 54 | print("Skipping {}".format("/".join(name))) 55 | continue 56 | pointer = model 57 | for m_name in name: 58 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 59 | l = re.split(r'_(\d+)', m_name) 60 | else: 61 | l = [m_name] 62 | if l[0] == 'kernel' or l[0] == 'gamma': 63 | pointer = getattr(pointer, 'weight') 64 | elif l[0] == 'output_bias' or l[0] == 'beta': 65 | pointer = getattr(pointer, 'bias') 66 | elif l[0] == 'output_weights': 67 | pointer = getattr(pointer, 'weight') 68 | else: 69 | pointer = getattr(pointer, l[0]) 70 | if len(l) >= 2: 71 | num = int(l[1]) 72 | pointer = pointer[num] 73 | if m_name[-11:] == '_embeddings': 74 | pointer = getattr(pointer, 'weight') 75 | elif m_name == 'kernel': 76 | array = np.transpose(array) 77 | try: 78 | assert pointer.shape == array.shape 79 | except AssertionError as e: 80 | e.args += (pointer.shape, array.shape) 81 | raise 82 | print("Initialize PyTorch weight {}".format(name)) 83 | pointer.data = torch.from_numpy(array) 84 | 85 | # Save pytorch-model 86 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 87 | torch.save(model.state_dict(), pytorch_dump_path) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | ## Required parameters 93 | parser.add_argument("--tf_checkpoint_path", 94 | default = None, 95 | type = str, 96 | required = True, 97 | help = "Path the TensorFlow checkpoint path.") 98 | parser.add_argument("--bert_config_file", 99 | 
default = None, 100 | type = str, 101 | required = True, 102 | help = "The config json file corresponding to the pre-trained BERT model. \n" 103 | "This specifies the model architecture.") 104 | parser.add_argument("--pytorch_dump_path", 105 | default = None, 106 | type = str, 107 | required = True, 108 | help = "Path to the output PyTorch model.") 109 | args = parser.parse_args() 110 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 111 | args.bert_config_file, 112 | args.pytorch_dump_path) 113 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import torch.nn as nn 4 | from transformers import BertPreTrainedModel, BertModel, BertConfig 5 | from typing import Optional, Tuple 6 | from transformers.modeling_outputs import ModelOutput 7 | import allennlp.modules.span_extractors.max_pooling_span_extractor as max_pooling_span_extractor 8 | from allennlp.nn.util import get_mask_from_sequence_lengths, masked_log_softmax 9 | 10 | @dataclass 11 | class FrameSRLModelOutput(ModelOutput): 12 | """ 13 | 14 | Args: 15 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 16 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 17 | start_logits (`torch.FloatTensor` of shape `(batch_size, FE_num, sequence_length)`): 18 | Span-start scores (before SoftMax). 19 | end_logits (`torch.FloatTensor` of shape `(batch_size, FE_num, sequence_length)`): 20 | Span-end scores (before SoftMax). 21 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 22 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 23 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 24 | 25 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 26 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 27 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 28 | sequence_length)`. 29 | 30 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 31 | heads. 
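        logits (`torch.FloatTensor` of shape `(batch_size, num_labels, span_num)`):
            Span classification scores (before SoftMax). Note that this is the field the
            dataclass below actually defines and that `BertForFrameSRL.forward` populates;
            the `start_logits`/`end_logits` entries above are inherited from a
            span-extraction style docstring and are not separate attributes here.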
32 | """ 33 | 34 | loss: Optional[torch.FloatTensor] = None 35 | logits: torch.FloatTensor = None 36 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 37 | attentions: Optional[Tuple[torch.FloatTensor]] = None 38 | 39 | class BertForFrameSRL(BertPreTrainedModel): 40 | def __init__(self, config: BertConfig): 41 | super().__init__(config) 42 | self.config = config 43 | self.bert = BertModel(config, add_pooling_layer=False) 44 | # self.bert = BertModel(config) 45 | # self.start_pointer = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 46 | # self.end_pointer = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 47 | self.ffn = nn.Linear(config.hidden_size, config.hidden_size) 48 | self.activation = nn.ReLU() 49 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 50 | self.mlp = nn.Sequential(self.ffn, self.activation, self.classifier) 51 | self.span_extractor = max_pooling_span_extractor.MaxPoolingSpanExtractor(config.hidden_size) 52 | self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100) 53 | # self.loss_fct_nll = nn.NLLLoss(ignore_index=-1) 54 | self.post_init() 55 | 56 | def forward( 57 | self, 58 | input_ids: Optional[torch.Tensor] = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | token_type_ids: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.Tensor] = None, 62 | head_mask: Optional[torch.Tensor] = None, 63 | inputs_embeds: Optional[torch.Tensor] = None, 64 | span_token_idx: Optional[torch.Tensor] = None, 65 | labels: Optional[torch.Tensor] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | return_dict: Optional[bool] = None, 69 | ): 70 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 71 | 72 | outputs = self.bert( 73 | input_ids, 74 | attention_mask=attention_mask, 75 | token_type_ids=token_type_ids, 76 | position_ids=position_ids, 77 | head_mask=head_mask, 78 | inputs_embeds=inputs_embeds, 79 | output_attentions=output_attentions, 80 | output_hidden_states=output_hidden_states, 81 | return_dict=return_dict, 82 | ) 83 | 84 | sequence_output = outputs[0] 85 | 86 | loss = None 87 | 88 | 89 | # span_token_idx (B, span_num, 2) -> span_rep (B, span_num, H) allennlp maxpoolingspanextractor 90 | # logits (B, num_labels, span_num) 91 | 92 | span_rep = self.span_extractor(sequence_output, span_token_idx) 93 | logits = self.mlp(span_rep).permute(0, 2, 1) 94 | 95 | if labels is not None: 96 | loss = self.loss_fct(logits, labels) 97 | 98 | if not return_dict: 99 | output = (logits,) + outputs[2:] 100 | return ((loss,) + output) if loss is not None else output 101 | 102 | return FrameSRLModelOutput( 103 | loss=loss, 104 | logits=logits, 105 | hidden_states=outputs.hidden_states, 106 | attentions=outputs.attentions, 107 | ) 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from allennlp.nn.util import get_mask_from_sequence_lengths 2 | import numpy as np 3 | import torch 4 | from typing import List, Str, Dict 5 | import json 6 | 7 | def post_process_function_greedy(start_logits: torch.Tensor, 8 | end_logits: torch.Tensor, 9 | context_length: torch.Tensor,): 10 | max_len = int(start_logits.shape[-2]) # (B, L, F) 11 | fe_num = int(start_logits.shape[-1]) 12 | 13 | # predict start positions 14 | context_length_mask = 
get_mask_from_sequence_lengths(context_length.squeeze(), max_len) 15 | start_logits_masked = start_logits.data.masked_fill_(~context_length_mask.unsqueeze(-1), -float('inf')) 16 | start_pred = torch.argmax(start_logits_masked, dim=-2) 17 | start_mask = get_mask_from_sequence_lengths(start_pred.flatten(), max_len).reshape(-1, fe_num, max_len).permute(0, 2, 1) 18 | end_mask = start_mask ^ (context_length_mask.repeat([1, fe_num]).reshape(-1, fe_num, max_len).permute(0, 2, 1)) # end >= start 19 | 20 | # predict end positions 21 | end_logits_masked = end_logits.data.masked_fill_(~end_mask, -float('inf')) 22 | end_pred_ = torch.argmax(end_logits_masked, dim=-2) 23 | neg_mask = (start_pred == 0) # if start == 0, set end to 0 24 | end_pred = neg_mask * torch.zeros_like(end_pred_) + (~neg_mask) * end_pred_ 25 | 26 | return start_pred.cpu().numpy().tolist(), end_pred.cpu().numpy().tolist() 27 | 28 | def post_process_function_with_max_len(start_logits: torch.Tensor, 29 | end_logits: torch.Tensor, 30 | context_length: torch.Tensor, 31 | max_len: int): 32 | # naive 33 | start_pred = [] 34 | end_pred = [] 35 | bsz = int(start_logits.shape[0]) 36 | fe_num = int(start_logits.shape[-1]) 37 | for i in range(bsz): 38 | cl = int(context_length[i]) 39 | start_pred_tensor = torch.LongTensor([0]).to(start_logits.device) 40 | end_pred_tensor = torch.LongTensor([0]).to(end_logits.device) 41 | best_score = start_logits[i][0] + end_logits[i][0] 42 | for start in range(1, cl): 43 | for end in range(start, min(cl, start+max_len)): 44 | score = start_logits[i][start] + end_logits[i][end] 45 | mask = score > best_score 46 | start_pred_tensor = mask * start + (~mask) * start_pred_tensor 47 | end_pred_tensor = mask * end + (~mask) * end_pred_tensor 48 | best_score = mask * score + (~mask) * best_score 49 | 50 | start_pred.append(start_pred_tensor.cpu().numpy().tolist()) 51 | end_pred.append(end_pred_tensor.cpu().numpy().tolist()) 52 | 53 | return start_pred, end_pred 54 | 55 | def save_predictions(start_pred: List, 56 | end_pred: List, 57 | FE_num: torch.Tensor, 58 | word_ids: torch.Tensor, 59 | task_id: torch.Tensor): 60 | predictions = [] 61 | start_pred_lst = start_pred 62 | end_pred_lst = end_pred 63 | bsz = FE_num.shape[0] 64 | 65 | for i in range(bsz): 66 | fe_num = int(FE_num[i][0]) 67 | start_pred_word_lst = [int(word_ids[i][int(tok)]) for tok in start_pred_lst[i][:fe_num]] 68 | end_pred_word_lst = [int(word_ids[i][int(tok)]) for tok in end_pred_lst[i][:fe_num]] 69 | tid = int(task_id[i][0]) 70 | predictions.append({"task_id": tid, "cfn_spans": list(zip(start_pred_word_lst, end_pred_word_lst))}) 71 | 72 | return predictions 73 | 74 | 75 | def calculate_F1_metric(predictions: List, 76 | eval_data_path: Str, 77 | frame_data: Dict): 78 | predictions_with_fename = [] 79 | with open(eval_data_path, 'r') as f: 80 | eval_data = json.load(f) 81 | assert len(predictions) == len(eval_data) 82 | for p, e in zip(predictions, eval_data): 83 | assert p["task_id"] == e["task_id"] 84 | for span in p["cfn_spans"]: 85 | assert span[0] <= span[1] 86 | start_pred_lst = start_pred 87 | end_pred_lst = end_pred 88 | 89 | 90 | bsz = gt_FE_word_idx.shape[0] 91 | fesz = gt_FE_word_idx.shape[-1] 92 | TP = 0 93 | FP = 0 94 | FN = 0 95 | 96 | # print(end_pred_lst) 97 | for i in range(bsz): 98 | fe_num = int(FE_num[i][0]) 99 | start_pred_word_lst = [int(word_ids[i][int(tok)]) for tok in start_pred_lst[i][:fe_num]] 100 | end_pred_word_lst = [int(word_ids[i][int(tok)]) for tok in end_pred_lst[i][:fe_num]] 101 | 102 | tp = 0 103 | fn = 0 104 | fp 
= 0 105 | 106 | for j in range(fesz): 107 | if int(gt_FE_word_idx[i][j]) == -1: 108 | break 109 | fe_idx = int(gt_FE_word_idx[i][j]) 110 | fe_st = int(gt_start_positions[i][j]) 111 | fe_ed = int(gt_end_positions[i][j]) 112 | 113 | if start_pred_word_lst[fe_idx] == fe_st and end_pred_word_lst[fe_idx] == fe_ed: 114 | tp += float(FE_core_pts[i][fe_idx]) 115 | else: 116 | fn += float(FE_core_pts[i][fe_idx]) 117 | 118 | tp_fp = 0 119 | 120 | for j, x in enumerate(start_pred_word_lst): 121 | if x != -1: 122 | tp_fp += float(FE_core_pts[i][j]) 123 | 124 | fp = tp_fp - tp 125 | 126 | TP += tp 127 | FN += fn 128 | FP += fp 129 | 130 | return TP, FP, FN 131 | 132 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adams b1. Default: 0.9 54 | b2: Adams b2. Default: 0.999 55 | e: Adams epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
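                # As a sketch, the decoupled update performed below is:
                #     update <- m_t / (sqrt(v_t) + eps) + weight_decay * p
                #     p      <- p - lr_scheduled * update
                # (`update` was formed from the Adam moments above; `lr_scheduled`
                # is computed just below from the warmup schedule.)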
143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /run_fi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import random 5 | from torch import nn 6 | from tqdm import tqdm 7 | from torch.utils.data import TensorDataset, DataLoader 8 | import torch.nn.functional as F 9 | from until import Processor, convert_examples_to_features, FocalLoss 10 | from pytorch_pretrained_bert.modeling import BertForTokenClassification2 11 | from pytorch_pretrained_bert import BertTokenizer 12 | from pytorch_pretrained_bert import BertAdam 13 | import numpy as np 14 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 15 | train_batch_size = 32 16 | dev_batch_size = 64 17 | test_batch_size = 64 18 | learning_rate = 0.0001 19 | 20 | epochs = 45 21 | random.seed(2021) 22 | np.random.seed(2021) 23 | torch.manual_seed(2021) 24 | # device = 'cpu' 25 | device = 'cuda:0' 26 | do_train = True 27 | do_dev = True 28 | processor = Processor() 29 | train_data_dir = './ccl-cfn/cfn-train.json' 30 | dev_data_dir = './ccl-cfn/cfn-dev.json' 31 | test_data_dir = './ccl-cfn/cfn-test.json' 32 | save_dir = 'model_save/' 33 | CONFIG_NAME = "config.json" 34 | WEIGHTS_NAME = "pytorch_model_new_test.bin" 35 | frame_data = json.load(open("./ccl-cfn/frame_info.json", encoding="utf8")) 36 | idx2f = [ x['frame_name'] for x in frame_data ] 37 | f2idx = { x['frame_name']:i for i, x in enumerate(frame_data) } 38 | num_labels = len(f2idx) 39 | model_name = "./bert_wwm" 40 | model = BertForTokenClassification2.from_pretrained(model_name, cache_dir=None, num_labels=num_labels) 41 | model.to(device) 42 | param_optimizer = list(model.named_parameters()) 43 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 44 | optimizer_grouped_parameters = [ 45 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 46 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 47 | ] 48 | 49 | 50 | tokenizer = BertTokenizer.from_pretrained(model_name) 51 | focal_loss = FocalLoss() 52 | criteon = nn.CrossEntropyLoss().to(device) 53 | # train_input_ids = torch.LongTensor() 54 | train_data_processor = processor.get_train_examples(train_data_dir, f2idx) 55 | train_feature = convert_examples_to_features(train_data_processor, 256, tokenizer) 56 | train_task_id = torch.tensor([f.task_id for f in train_feature], dtype=torch.long) 57 | train_input_ids = torch.tensor([f.input_ids for f in train_feature], dtype=torch.long) 58 | train_input_mask = torch.tensor([f.input_mask for f in train_feature], dtype=torch.long) 59 | train_segment_ids = torch.tensor([f.segment_ids for f in train_feature], dtype=torch.long) 60 | # train_mask_array = torch.tensor([f.mask_array for f in train_feature], 
dtype=torch.int) 61 | train_mask_array = torch.tensor([f.mask_array for f in train_feature], dtype=torch.bool) 62 | 63 | train_label = torch.tensor([f.label_id for f in train_feature], dtype=torch.long) 64 | train_data = TensorDataset(train_input_ids, train_input_mask, train_segment_ids, train_mask_array, train_label) 65 | train_loader = DataLoader(train_data, shuffle=False, batch_size=train_batch_size) 66 | 67 | 68 | num_train_optimization_steps = int( 69 | len(train_data_processor) / train_batch_size / 1) * epochs 70 | optimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate, warmup=0.1, t_total=num_train_optimization_steps) 71 | scheduler_args={f'gamma': .75**(1 / 5090)} 72 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,**scheduler_args) 73 | dev_data_processor = processor.get_dev_examples(dev_data_dir, f2idx) 74 | dev_feature = convert_examples_to_features(dev_data_processor, 256, tokenizer) 75 | dev_task_id = torch.tensor([f.task_id for f in dev_feature], dtype=torch.long) 76 | dev_input_ids = torch.tensor([f.input_ids for f in dev_feature], dtype=torch.long) 77 | dev_input_mask = torch.tensor([f.input_mask for f in dev_feature], dtype=torch.long) 78 | dev_segment_ids = torch.tensor([f.segment_ids for f in dev_feature], dtype=torch.long) 79 | # dev_mask_array = torch.tensor([f.mask_array for f in dev_feature], dtype=torch.int) 80 | dev_mask_array = torch.tensor([f.mask_array for f in dev_feature], dtype=torch.bool) 81 | 82 | dev_label = torch.tensor([f.label_id for f in dev_feature], dtype=torch.long) 83 | dev_data = TensorDataset(dev_input_ids, dev_input_mask, dev_segment_ids, dev_mask_array, dev_label) 84 | dev_loader = DataLoader(dev_data, shuffle=False, batch_size=dev_batch_size) 85 | 86 | 87 | def do_test(): 88 | test_data_processor = processor.get_test_examples(test_data_dir, f2idx) 89 | test_feature = convert_examples_to_features(test_data_processor, 256, tokenizer) 90 | test_task_id = torch.tensor([f.task_id for f in test_feature], dtype=torch.long) 91 | test_input_ids = torch.tensor([f.input_ids for f in test_feature], dtype=torch.long) 92 | test_input_mask = torch.tensor([f.input_mask for f in test_feature], dtype=torch.long) 93 | test_segment_ids = torch.tensor([f.segment_ids for f in test_feature], dtype=torch.long) 94 | # dev_mask_array = torch.tensor([f.mask_array for f in dev_feature], dtype=torch.int) 95 | test_mask_array = torch.tensor([f.mask_array for f in test_feature], dtype=torch.bool) 96 | 97 | test_label = torch.tensor([f.label_id for f in test_feature], dtype=torch.long) 98 | test_data = TensorDataset(test_task_id, test_input_ids, test_input_mask, test_segment_ids, test_mask_array, test_label) 99 | test_loader = DataLoader(test_data, shuffle=False, batch_size=test_batch_size) 100 | 101 | res_json = [] 102 | model.eval() 103 | for step, batch in enumerate(tqdm(test_loader, desc="Iteration")): 104 | batch = tuple(t.to(device) for t in batch) 105 | test_task_id, input_ids, input_mask, segment_ids, mask_array, label = batch 106 | with torch.no_grad(): 107 | logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, 108 | mask_array=mask_array) 109 | 110 | logits_ = F.softmax(logits, dim=-1) 111 | pred = torch.cat([test_task_id, logits_.argmax(-1).unsqueeze(-1)], -1) 112 | res_json.extend(pred.tolist()) 113 | 114 | json.dump([ [x[0], idx2f[x[1]]] for x in res_json], open("./ccl-cfn/result/task1_test.json", "w",encoding="utf8"), ensure_ascii=False) 115 | 116 | 117 | best_acc = 0.0 118 | dev_loss 
= [] 119 | for epoch in range(epochs): 120 | print("epoch:", epoch) 121 | if do_train: 122 | model.train() 123 | for step, batch in enumerate(tqdm(train_loader, desc="Iteration")): 124 | batch = tuple(t.to(device) for t in batch) 125 | input_ids, input_mask, segment_ids, mask_array, label = batch 126 | logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, 127 | mask_array=mask_array) 128 | 129 | loss = criteon(logits.view(-1, num_labels), label.view(-1)) 130 | # loss = focal_loss(logits, label) 131 | logits_ = F.softmax(logits, dim=-1) 132 | logits_ = logits_.detach().cpu().numpy() 133 | outputs = np.argmax(logits_, axis=1) 134 | batch_acc = np.sum(outputs == label.detach().cpu().numpy()) 135 | loss.backward() 136 | optimizer.step() 137 | scheduler.step() 138 | optimizer.zero_grad() 139 | if step % 100 == 0: 140 | print('Train Epoch : {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc:{}'.format( 141 | epoch, step * len(batch), len(train_loader), 142 | 100. * step / len(train_loader), loss.item(), batch_acc/train_batch_size)) 143 | 144 | if do_dev: 145 | model.eval() 146 | eval_loss, eval_accuracy = 0, 0 147 | for step, batch in enumerate(tqdm(dev_loader, desc="Iteration")): 148 | batch = tuple(t.to(device) for t in batch) 149 | input_ids, input_mask, segment_ids, mask_array, label = batch 150 | with torch.no_grad(): 151 | logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None, 152 | mask_array=mask_array) 153 | 154 | loss = criteon(logits, label) 155 | logits_ = F.softmax(logits, dim=-1) 156 | outputs = np.argmax(logits_.detach().cpu().numpy(), axis=1) 157 | batch_acc = np.sum(outputs == label.detach().cpu().numpy()) 158 | eval_loss += loss 159 | eval_accuracy += batch_acc 160 | eval_acc = eval_accuracy / len(dev_loader) 161 | eval_loss = eval_loss / len(dev_loader) 162 | print("eval_acc:{}".format(eval_acc/dev_batch_size)) 163 | print("eval_loss:{}".format(eval_loss)) 164 | if eval_acc > best_acc: 165 | best_acc = eval_acc 166 | print("best_acc:{}".format(best_acc/dev_batch_size)) 167 | torch.save(model, save_dir+WEIGHTS_NAME+str(epoch)) 168 | do_test() 169 | dev_loss.append(eval_loss) 170 | print("best_acc:{}".format(best_acc/dev_batch_size)) 171 | print(dev_loss) 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import SchedulerType 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description="Finetune a transformers model on Frame Semantic Role Labeling") 6 | parser.add_argument( 7 | "--train_file", type=str, default='./ccl-cfn/cfn-train.json', help="A csv or a json file containing the training data." 8 | ) 9 | parser.add_argument( 10 | "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data." 11 | ) 12 | parser.add_argument("--do_predict", default=True,action="store_true", help="To do prediction on the question answering model") 13 | parser.add_argument( 14 | "--validation_file", type=str, default='./ccl-cfn/cfn-dev.json', help="A csv or a json file containing the validation data." 15 | ) 16 | parser.add_argument( 17 | "--test_file", type=str, default='./ccl-cfn/cfn-test.json', help="A csv or a json file containing the Prediction data." 
18 | ) 19 | parser.add_argument( 20 | "--frame_data", type=str, default='./ccl-cfn/frame_info.json', help="A csv or a json file containing the frame data." 21 | ) 22 | parser.add_argument( 23 | "--task1_res", type=str, default='./ccl-cfn/result/task1_test.json', help="A csv or a json file containing the result of task1." 24 | ) 25 | parser.add_argument( 26 | "--task2_res", type=str, default='./ccl-cfn/result/task2_test.json', help="A csv or a json file containing the result of task2." 27 | ) 28 | parser.add_argument( 29 | "--max_seq_length", 30 | type=int, 31 | default=512, 32 | help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 33 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed.", 34 | ) 35 | parser.add_argument( 36 | "--pad_to_max_length", 37 | action="store_true", 38 | help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.", 39 | ) 40 | parser.add_argument( 41 | "--model_name_or_path", 42 | type=str, 43 | default="./bert-base-chinese", 44 | help="Path to pretrained model or model identifier from huggingface.co/models.", 45 | # required=True, 46 | ) 47 | parser.add_argument( 48 | "--config_name", 49 | type=str, 50 | default=None, 51 | help="Pretrained config name or path if not the same as model_name", 52 | ) 53 | parser.add_argument( 54 | "--tokenizer_name", 55 | type=str, 56 | default=None, 57 | help="Pretrained tokenizer name or path if not the same as model_name", 58 | ) 59 | parser.add_argument( 60 | "--use_slow_tokenizer", 61 | action="store_true", 62 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 63 | ) 64 | parser.add_argument( 65 | "--per_device_train_batch_size", 66 | type=int, 67 | default=16, 68 | help="Batch size (per device) for the training dataloader.", 69 | ) 70 | parser.add_argument( 71 | "--per_device_eval_batch_size", 72 | type=int, 73 | default=16, 74 | help="Batch size (per device) for the evaluation dataloader.", 75 | ) 76 | parser.add_argument( 77 | "--learning_rate", 78 | type=float, 79 | default=5e-5, 80 | help="Initial learning rate (after the potential warmup period) to use.", 81 | ) 82 | parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") 83 | parser.add_argument("--num_train_epochs", type=int, default=20, help="Total number of training epochs to perform.") 84 | parser.add_argument( 85 | "--max_train_steps", 86 | type=int, 87 | default=None, 88 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 89 | ) 90 | parser.add_argument( 91 | "--gradient_accumulation_steps", 92 | type=int, 93 | default=1, 94 | help="Number of updates steps to accumulate before performing a backward/update pass.", 95 | ) 96 | parser.add_argument( 97 | "--lr_scheduler_type", 98 | type=SchedulerType, 99 | default="linear", 100 | help="The scheduler type to use.", 101 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 102 | ) 103 | parser.add_argument( 104 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 
105 | ) 106 | parser.add_argument("--output_dir", type=str, default='../ckpt/', help="Where to store the final model.") 107 | parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") 108 | parser.add_argument( 109 | "--max_argument_length", 110 | type=int, 111 | default=30, 112 | help="The maximum length of an answer that can be generated. This is needed because the start " 113 | "and end predictions are not conditioned on one another.", 114 | ) 115 | parser.add_argument( 116 | "--max_train_samples", 117 | type=int, 118 | default=None, 119 | help="For debugging purposes or quicker training, truncate the number of training examples to this " 120 | "value if set.", 121 | ) 122 | parser.add_argument( 123 | "--max_eval_samples", 124 | type=int, 125 | default=None, 126 | help="For debugging purposes or quicker training, truncate the number of evaluation examples to this " 127 | "value if set.", 128 | ) 129 | parser.add_argument( 130 | "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" 131 | ) 132 | parser.add_argument( 133 | "--max_predict_samples", 134 | type=int, 135 | default=None, 136 | help="For debugging purposes or quicker training, truncate the number of prediction examples to this", 137 | ) 138 | parser.add_argument( 139 | "--checkpointing_steps", 140 | type=str, 141 | default='epoch', 142 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 143 | ) 144 | parser.add_argument( 145 | "--resume_from_checkpoint", 146 | type=str, 147 | default=None, 148 | help="If the training should continue from a checkpoint folder.", 149 | ) 150 | parser.add_argument( 151 | "--with_tracking", 152 | action="store_true", 153 | help="Whether to load in all available experiment trackers from the environment and use them for logging.", 154 | ) 155 | parser.add_argument( 156 | "--FE_pooling", 157 | type=str, 158 | default='max', 159 | help="max or avg, how we do pooling over tokens of an FE.", 160 | ) 161 | parser.add_argument( 162 | "--log_every_step", 163 | type=int, 164 | default=None, 165 | help="How many steps do we log loss." 166 | ) 167 | parser.add_argument( 168 | "--post_process", 169 | type=str, 170 | default='greedy' 171 | ) 172 | parser.add_argument("--save_best", action="store_true", help="Whether to save model with best performance on dev dataset.") 173 | parser.add_argument("--loss_on_context", action="store_true", help="Whether to compute loss only on context.") 174 | parser.add_argument('--target', action="store_true", help="Whether to use target as label.") 175 | parser.add_argument( 176 | "--train_file2", 177 | type=str, 178 | default='../data/train_instance_dic_prompt.npy', 179 | ) 180 | parser.add_argument( 181 | "--num_train_epochs1", 182 | type=int, 183 | default=-1 184 | ) 185 | parser.add_argument( 186 | '--model_type', 187 | type=str, 188 | default='srl' 189 | ) 190 | args = parser.parse_args() 191 | 192 | # Sanity checks 193 | if ( 194 | args.train_file is None 195 | and args.validation_file is None 196 | and args.test_file is None 197 | ): 198 | raise ValueError("Need either a dataset name or a training/validation/test file.") 199 | # else: 200 | # if args.train_file is not None: 201 | # extension = args.train_file.split(".")[-1] 202 | # assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 
203 | # if args.validation_file is not None: 204 | # extension = args.validation_file.split(".")[-1] 205 | # assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 206 | # if args.test_file is not None: 207 | # extension = args.test_file.split(".")[-1] 208 | # assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." 209 | 210 | 211 | return args -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 
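    Example (the URL and path here are placeholders, not real resources):
        cached_path("https://example.com/vocab.txt")   # downloaded once, then served from the local cache
        cached_path("/tmp/vocab.txt")                  # existing local file: the same path is returned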
80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 
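    # The resulting cache key comes from url_to_filename() above: sha256(url), plus
    # ".sha256(etag)" when an ETag is available, so a file that changes upstream gets a
    # fresh cache entry instead of silently reusing a stale one.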
178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 
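    Each line is stripped of trailing whitespace and duplicates collapse into the set,
    so the result is convenient for fast membership checks (e.g. a vocabulary list).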
228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader, Subset 4 | import numpy as np 5 | from transformers import BertTokenizerFast, PreTrainedTokenizerBase, DataCollatorWithPadding 6 | from tqdm.auto import tqdm 7 | from typing import Optional, Union 8 | import json 9 | 10 | class FrameAIDataset(Dataset): 11 | def __init__(self, data_file, tokenizer): 12 | super(FrameAIDataset, self).__init__() 13 | # print('load data...') 14 | data_instance_dic = {} 15 | with open(data_file, 'r') as f: 16 | data_instance_dic = json.load(f) 17 | # data_instance_dic = np.load(data_file, allow_pickle=True).item() 18 | self.data = [] 19 | # self.model_name_or_path = model_name_or_path 20 | # print('load tokenizer...') 21 | self.tokenizer = tokenizer 22 | # self.tokenizer.add_tokens(['', '', '', '', '', '']) 23 | # bar = tqdm(range(len(data_instance_dic))) 24 | cnt = 0 25 | for item in data_instance_dic: 26 | # for k, v in item.items(): 27 | # print(k, v) 28 | self.tokenize_instance(item) 29 | # exit(0) 30 | 31 | 32 | def align_labels_with_tokens(self, labels, word_ids): 33 | new_labels = [] 34 | current_word = -1 35 | for word_id in word_ids: 36 | if word_id != current_word: 37 | # Start of a new word! 
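                # Special tokens ([CLS]/[SEP]/padding) carry word_id -1 (None from the tokenizer,
                # remapped by the caller), so they receive -100 and are ignored by the loss.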
38 | current_word = word_id 39 | label = -100 if word_id == -1 else labels[word_id] 40 | new_labels.append(label) 41 | elif word_id == -1: 42 | new_labels.append(-100) 43 | else: 44 | # Same word as previous token 45 | label = labels[word_id] 46 | # If the label is B-XXX we change it to I-XXX 47 | if label % 2 == 1: 48 | label += 1 49 | new_labels.append(label) 50 | 51 | return new_labels 52 | 53 | def tokenize_instance(self, dic): 54 | # for k, v in dic.items(): 55 | # print(k, v) 56 | data_dic = {} 57 | task_id = dic['sentence_id'] 58 | data_dic['task_id'] = [task_id] 59 | context = list(dic['text']) 60 | context[dic['target']['start']] = ' ' + context[dic['target']['start']] 61 | context[dic['target']["end"]] = context[dic['target']["end"]] + ' ' 62 | labels = [0]*len(context) 63 | for span in dic['cfn_spans']: 64 | labels[span['start']] = 1 65 | for i in range(span['start'] + 1, span['end']+1): 66 | labels[i] = 2 67 | # print(query) 68 | encodings = self.tokenizer(context,is_split_into_words=True, return_length=True) 69 | data_dic['input_ids'] = encodings['input_ids'] 70 | data_dic['attention_mask'] = encodings['attention_mask'] 71 | data_dic['token_type_ids'] = encodings['token_type_ids'] 72 | data_dic['length'] = encodings['length'] 73 | # data_dic['context_length'] = [data_dic['length'][0] - sum(data_dic['token_type_ids']) - 1] 74 | data_dic['word_ids'] = [x if x is not None else -1 for x in encodings.word_ids()] 75 | data_dic['labels'] = self.align_labels_with_tokens(labels, data_dic['word_ids']) 76 | while len(data_dic['word_ids']) < 512: 77 | data_dic['word_ids'].append(-1) 78 | # assert len(data_dic['labels']) == len(data_dic['input_ids']) 79 | # assert len(data_dic['word_ids']) == len(data_dic['input_ids']) 80 | # FE_token_idx_start = [encodings.word_to_tokens(x, sequence_index=1).start for x in dic['frame_data']['FE_word_idx']] 81 | # FE_token_idx_end = [encodings.word_to_tokens(x, sequence_index=1).end - 1 for x in dic['frame_data']['FE_word_idx']] 82 | # FE_token_idx_start = [encodings.word_to_tokens(x, sequence_index=1).start - 1 for x in dic['FE_word_idx']] 83 | # FE_token_idx_end = [encodings.word_to_tokens(x, sequence_index=1).end for x in dic['FE_word_idx']] 84 | # FE_token_idx = [[s, t] for s, t in zip(FE_token_idx_start, FE_token_idx_end)] 85 | # data_dic['FE_num'] = [len(FE_token_idx)] 86 | # data_dic['FE_token_idx'] = FE_token_idx 87 | # data_dic['start_positions'] = [encodings.word_to_tokens(x-1, sequence_index=0).start if x > 0 else 0 for x in start_positions] 88 | # data_dic['end_positions'] = [encodings.word_to_tokens(x-1, sequence_index=0).end - 1 if x > 0 else 0 for x in end_positions] 89 | # data_dic['gt_FE_word_idx'] = dic['gt_FE_word_idx'] 90 | # data_dic['gt_start_positions'] = dic['gt_start_positions'] 91 | # data_dic['gt_end_positions'] = dic['gt_end_positions'] 92 | # data_dic['FE_core_pts'] = dic['FE_core_pts'] 93 | self.data.append(data_dic) 94 | # for k, v in data_dic.items(): 95 | # print(k, v) 96 | 97 | def __len__(self): 98 | return len(self.data) 99 | 100 | def __getitem__(self, index): 101 | return self.data[index] 102 | 103 | def subset(self, indices): 104 | return Subset(self, indices=indices) 105 | 106 | class FrameRCDataset(Dataset): 107 | def __init__(self, data_file, tokenizer, fe2id, task1_res=None, task2_res=None): 108 | super(FrameRCDataset, self).__init__() 109 | # print('load data...') 110 | data_instance_dic = {} 111 | with open(data_file, 'r') as f: 112 | data_instance_dic = json.load(f) 113 | # data_instance_dic = 
np.load(data_file, allow_pickle=True).item() 114 | self.data = [] 115 | if task1_res is not None: 116 | tid2frame = {} 117 | with open(task1_res, 'r') as f: 118 | task1_data = json.load(f) 119 | for item in task1_data: 120 | tid2frame[item[0]] = item[1] 121 | for item in data_instance_dic: 122 | item['pred_frame'] = tid2frame[item['sentence_id']] 123 | if task2_res is not None: 124 | tid2spans = {} 125 | tid2spansets = {} 126 | with open(task2_res, 'r') as f: 127 | task2_data = json.load(f) 128 | for item in task2_data: 129 | tid2spans.setdefault(item[0], []) 130 | tid2spansets.setdefault(item[0], set()) 131 | if (item[1], item[2]) not in tid2spansets[item[0]]: 132 | tid2spans[item[0]].append({ "start":item[1],"end":item[2] }) 133 | tid2spansets[item[0]].add((item[1], item[2])) 134 | for item in data_instance_dic: 135 | if item['sentence_id'] in tid2spans: 136 | item['pred_spans'] = tid2spans[item['sentence_id']] 137 | else: 138 | item['pred_spans'] = [] 139 | 140 | for i, item in enumerate(data_instance_dic): 141 | data_instance_dic[i]['labels'] = [] 142 | for span in item['cfn_spans']: 143 | data_instance_dic[i]['labels'].append(fe2id[span['fe_name']]) 144 | # self.model_name_or_path = model_name_or_path 145 | # print('load tokenizer...') 146 | self.tokenizer = tokenizer 147 | # self.tokenizer.add_tokens(['', '', '', '', '', '']) 148 | # bar = tqdm(range(len(data_instance_dic))) 149 | cnt = 0 150 | for item in data_instance_dic: 151 | # for k, v in item.items(): 152 | # print(k, v) 153 | self.tokenize_instance(item) 154 | # exit(0) 155 | 156 | 157 | def tokenize_instance(self, dic): 158 | # for k, v in dic.items(): 159 | # print(k, v) 160 | data_dic = {} 161 | task_id = dic['sentence_id'] 162 | data_dic['task_id'] = [task_id] 163 | context = list(dic['text']) 164 | context[dic['target']['start']] = ' ' + context[dic['target']['start']] 165 | context[dic['target']['end']] = context[dic['target']['end']] + ' ' 166 | frame_key = 'pred_frame' if 'pred_frame' in dic else 'frame' 167 | context[dic['target']['start']] = context[dic['target']['start']] + ' ' + dic[frame_key] + ' ' 168 | span_key = 'pred_spans' if 'pred_spans' in dic else 'cfn_spans' 169 | start_positions = [] 170 | end_positions = [] 171 | for span in dic[span_key]: 172 | context[span['start']] = ' ' + context[span['start']] 173 | start_positions.append(span['start']) 174 | context[span['end']] = context[span['end']] + ' ' 175 | end_positions.append(span['end']) 176 | # print(query) 177 | encodings = self.tokenizer(context,is_split_into_words=True, return_length=True) 178 | data_dic['input_ids'] = encodings['input_ids'] 179 | data_dic['attention_mask'] = encodings['attention_mask'] 180 | data_dic['token_type_ids'] = encodings['token_type_ids'] 181 | data_dic['length'] = encodings['length'] 182 | # data_dic['context_length'] = [data_dic['length'][0] - sum(data_dic['token_type_ids']) - 1] 183 | word_ids = [x if x is not None else -1 for x in encodings.word_ids()] 184 | data_dic['labels'] = dic['labels'] 185 | while len(data_dic['labels']) < 16: 186 | data_dic['labels'].append(-100) 187 | # assert len(data_dic['labels']) == len(data_dic['input_ids']) 188 | # assert len(data_dic['word_ids']) == len(data_dic['input_ids']) 189 | # FE_token_idx_start = [encodings.word_to_tokens(x, sequence_index=1).start for x in dic['frame_data']['FE_word_idx']] 190 | # FE_token_idx_end = [encodings.word_to_tokens(x, sequence_index=1).end - 1 for x in dic['frame_data']['FE_word_idx']] 191 | # FE_token_idx_start = [encodings.word_to_tokens(x, 
sequence_index=1).start - 1 for x in dic['FE_word_idx']] 192 | # FE_token_idx_end = [encodings.word_to_tokens(x, sequence_index=1).end for x in dic['FE_word_idx']] 193 | # FE_token_idx = [[s, t] for s, t in zip(FE_token_idx_start, FE_token_idx_end)] 194 | # data_dic['FE_num'] = [len(start_positions)] 195 | # data_dic['FE_token_idx'] = FE_token_idx 196 | token_start_positions = [encodings.word_to_tokens(x, sequence_index=0).start for x in start_positions] 197 | token_end_positions = [encodings.word_to_tokens(x, sequence_index=0).end - 1 for x in end_positions] 198 | while len(token_start_positions) < 16: 199 | token_start_positions.append(0) 200 | while len(token_end_positions) < 16: 201 | token_end_positions.append(0) 202 | span_token_idx = [[s, t] for s, t in zip(token_start_positions, token_end_positions)] 203 | data_dic['span_token_idx'] = span_token_idx 204 | 205 | # data_dic['gt_FE_word_idx'] = dic['gt_FE_word_idx'] 206 | # data_dic['gt_start_positions'] = dic['gt_start_positions'] 207 | # data_dic['gt_end_positions'] = dic['gt_end_positions'] 208 | # data_dic['FE_core_pts'] = dic['FE_core_pts'] 209 | self.data.append(data_dic) 210 | # for k, v in data_dic.items(): 211 | # print(k, v) 212 | 213 | def __len__(self): 214 | return len(self.data) 215 | 216 | def __getitem__(self, index): 217 | return self.data[index] 218 | 219 | def subset(self, indices): 220 | return Subset(self, indices=indices) 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | # print(frame_data[1]) 226 | # exit(0) 227 | tokenizer = BertTokenizerFast.from_pretrained('hfl/chinese-bert-wwm-ext') 228 | tokenizer.add_tokens(['', '', '', '', '', '']) 229 | # tokenizer.add_tokens(['', '']) 230 | with open('./ccl-cfn/frame_data.json', 'r') as f: 231 | data = json.load(f) 232 | fe2id = {} 233 | cnt = 0 234 | for frame in data: 235 | for fe in frame['fes']: 236 | name = fe['fe_abbr'] 237 | if name not in fe2id: 238 | fe2id[name] = cnt 239 | cnt += 1 240 | d = FrameRCDataset('./ccl-cfn/cfn-dev.json', tokenizer, fe2id, './ccl-cfn/result/task1_dev.json', './ccl-cfn/result/task2_dev.json') 241 | dd = Subset(d, range(8)) 242 | dc = DataCollatorWithPadding(tokenizer) 243 | dl = DataLoader(dd, 4, shuffle=False, collate_fn=dc) 244 | # print('hi') 245 | # print(len(dl)) 246 | for b in dl: 247 | for k, v in b.items(): 248 | print(k, v) 249 | # break 250 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 更新日志 2 | 3 | * 评测主页迁移,当前页面为数据集页面,请参赛选手到算法比赛页面参与评测,赛制规则和赛程安排以算法比赛页面为准,已在数据集页面提交过评测结果的参赛选手需要在算法比赛页面中重新提交,算法比赛主页链接: 4 | [CCL2023-Eval 汉语框架语义解析评测_算法大赛](https://tianchi.aliyun.com/competition/entrance/532083/introduction) 5 | 6 | * 报名方式更新:分为以下两个步骤,更多报名细节见算法比赛首页: [CCL2023-Eval 汉语框架语义解析评测](https://tianchi.aliyun.com/competition/entrance/532083/introduction)。 7 | 8 | 1. 4月1日阿里天池平台([https://tianchi.aliyun.com/](https://tianchi.aliyun.com/))将开放本次比赛的报名组队、登录比赛官网([CCL2023-Eval 汉语框架语义解析评测_算法大赛](https://tianchi.aliyun.com/competition/entrance/532083/introduction)),完成个人信息注册,即可报名参赛;选手可以单人参赛,也可以组队参赛。组队参赛的每个团队2-3人,每位选手只能加入一支队伍;选手需确保报名信息准确有效,组委会有权取消不符合条件队伍的参赛资格及奖励;选手报名、组队变更等操作截止时间为5月27日23:59:59;各队伍(包括队长及全体队伍成员)需要在5月27日23:59:59前完成实名认证(认证入口:天池官网-右上角个人中心-认证-支付宝实名认证),未完成认证的参赛团队将无法进行后续的比赛; 9 | 10 | 2. 
向赛题举办方发送电子邮件进行报名,以获取数据解压密码。邮件标题为:“CCL2023-汉语框架语义解析评测-参赛单位”,例如:“CCL2023-汉语框架语义解析评测-复旦大学”;附件内容为队伍的参赛报名表,报名表[点此下载](https://github.com/SXUNLP/Chinese-Frame-Semantic-Parsing/blob/main/%E6%B1%89%E8%AF%AD%E6%A1%86%E6%9E%B6%E8%AF%AD%E4%B9%89%E8%A7%A3%E6%9E%90%E8%AF%84%E6%B5%8B%E6%8A%A5%E5%90%8D%E8%A1%A8.docx),同时报名表应更名为“参赛队名+参赛队长信息+参赛单位名称”。请参加评测的队伍发送邮件至 [202122407024@email.sxu.edu.cn](mailto:202122407024@email.sxu.edu.cn),报名成功后赛题数据解压密码会通过邮件发送给参赛选手,选手在天池平台下载数据即可 11 | 12 | **注意:报名截止前未发送报名邮件者不参与后续的评选。** 13 | * 报名时间由 4月1日-4月31日 更新为 4月1日-5月27日。 14 | 15 | 16 | # 总体概述 17 | 18 | * CFN 1.0数据集是由山西大学以汉语真实语料为依据构建的框架语义资源,数据由框架知识及标注例句组成,包含了近700个语义框架及20000条标注例句。CFN 1.0数据集遵循[CC BY-NC 4.0协议](https://creativecommons.org/licenses/by-nc/4.0/)。 19 | * 框架语义解析(Frame Semantic Parsing,FSP)是自然语言处理领域中的一项重要任务,其目标是从句子中提取框架语义结构[1],实现对句子中涉及到的事件或情境的深层理解。FSP在阅读理解[2-3]、文本摘要[4-5]、关系抽取[6]等下游任务有着重要意义。 20 | 21 | 22 | # 任务介绍 23 | 24 | 汉语框架语义解析(Chinese FSP,CFSP)是基于汉语框架网(Chinese FrameNet, CFN)的语义解析任务,本次标注数据格式如下: 25 | 26 | 1. 标注数据的字段信息如下: 27 | + sentence_id:例句id 28 | + cfn_spans:框架元素标注信息 29 | + frame:例句所激活的框架名称 30 | + target:目标词的相关信息 31 | + start:目标词在句中的起始位置 32 | + end:目标词在句中的结束位置 33 | + pos:目标词的词性 34 | + text:标注例句 35 | + word:例句的分词结果及其词性信息 36 | 37 | 数据样例: 38 | ```json 39 | [{ 40 | "sentence_id": 2611, 41 | "cfn_spans": [ 42 | { "start": 0, "end": 2, "fe_abbr": "ent_1", "fe_name": "实体1" }, 43 | { "start": 4, "end": 17, "fe_abbr": "ent_2", "fe_name": "实体2" } 44 | ], 45 | "frame": "等同", 46 | "target": { "start": 3, "end": 3, "pos": "v" }, 47 | "text": "餐饮业是天津市在海外投资的重点之一。", 48 | "word": [ 49 | { "start": 0, "end": 2, "pos": "n" }, 50 | { "start": 3, "end": 3, "pos": "v" }, 51 | { "start": 4, "end": 6, "pos": "nz" }, 52 | { "start": 7, "end": 7, "pos": "p" }, 53 | { "start": 8, "end": 9, "pos": "n" }, 54 | { "start": 10, "end": 11, "pos": "v" }, 55 | { "start": 12, "end": 12, "pos": "u" }, 56 | { "start": 13, "end": 14, "pos": "n" }, 57 | { "start": 15, "end": 16, "pos": "n" }, 58 | { "start": 17, "end": 17, "pos": "wp" } 59 | ] 60 | }] 61 | ``` 62 | 63 | 2. 框架信息在`frame_info.json`中,框架数据的字段信息如下: 64 | + frame_name:框架名称 65 | + frame_ename:框架英文名称 66 | + frame_def:框架定义 67 | + fes:框架元素信息 68 | + fe_name:框架元素名称 69 | + fe_abbr:框架元素缩写 70 | + fe_ename:框架元素英文名称 71 | + fe_def:框架元素定义 72 | 73 | 数据样例: 74 | ```json 75 | [{ 76 | "frame_name": "等同", 77 | "frame_ename": "Equating", 78 | "frame_def": "表示两个实体具有相等、相同、同等看待等的关系。", 79 | "fes": [ 80 | { "fe_name": "实体集", "fe_abbr": "ents", "fe_ename": "Entities", "fe_def": "具有同等关系的两个或多个实体" }, 81 | { "fe_name": "实体1", "fe_abbr": "ent_1", "fe_ename": "Entity_1", "fe_def": "与实体2具有等同关系的实体" }, 82 | { "fe_name": "实体2", "fe_abbr": "ent_2", "fe_ename": "Entity_2", "fe_def": "与实体1具有等同关系的实体" }, 83 | { "fe_name": "施动者", "fe_abbr": "agt", "fe_ename": "Agent", "fe_def": "判断实体集具有同等关系的人。" }, 84 | { "fe_name": "方式", "fe_abbr": "manr", "fe_ename": "Manner", "fe_def": "修饰用来概括无法归入其他更具体的框架元素的任何语义成分,包括认知的修饰(如很可能,大概,神秘地),辅助描述(安静地,大声地),和与事件相比较的一般描述(同样的方式)。" }, 85 | { "fe_name": "时间", "fe_abbr": "time", "fe_ename": "Time", "fe_def": "实体之间具有等同关系的时间" } 86 | ] 87 | }] 88 | ``` 89 | 90 | 本次评测共分为以下三个子任务: 91 | 92 | * 子任务1: 框架识别(Frame Identification),识别句子中给定目标词激活的框架。 93 | * 子任务2: 论元范围识别(Argument Identification),识别句子中给定目标词所支配论元的边界范围。 94 | * 子任务3: 论元角色识别(Role Identification),预测子任务2所识别论元的语义角色标签。 95 | 96 | 97 | ## 子任务1: 框架识别(Frame Identification) 98 | ### 1. 
任务描述 99 | 100 | 框架识别任务是框架语义学研究中的核心任务,其要求根据给定句子中目标词的上下文语境,为其寻找一个可以激活的框架。框架识别任务是自然语言处理中非常重要的任务之一,它可以帮助计算机更好地理解人类语言,并进一步实现语言处理的自动化和智能化。具体来说,框架识别任务可以帮助计算机识别出句子中的关键信息和语义框架,从而更好地理解句子的含义。这对于自然语言处理中的许多任务都是至关重要的。 101 | 102 | ### 2. 任务说明 103 | 104 | 该任务给定一个包含目标词的句子,需要根据目标词语境识别出激活的框架,并给出识别出的框架名称。 105 | 106 | 1. 输入:句子相关信息(id和文本内容)及目标词。 107 | 2. 输出:句子id及目标词所激活框架的识别结果,数据为json格式,所有例句的识别结果需要放在同一list中,样例如下: 108 | 109 | ```json 110 | [ 111 | [2611, "事件发生场所停业"], 112 | [2612, "等同"], 113 | ... 114 | ] 115 | ``` 116 | 117 | ### 3. 评测指标 118 | 119 |   框架识别采用正确率作为评价指标: 120 | 121 | $$task1\_acc = 正确识别的个数 / 总数$$ 122 | 123 | 124 | ## 子任务2: 论元范围识别(Argument Identification) 125 | 126 | ### 1. 任务描述 127 | 128 | 给定一句汉语句子及目标词,在目标词已知的条件下,从句子中自动识别出目标词所搭配的语义角色的边界。该任务的主要目的是确定句子中目标词所涉及的每个论元在句子中的位置。论元范围识别任务对于框架语义解析任务来说非常重要,因为正确识别谓词和论元的范围可以帮助系统更准确地识别论元的语义角色,并进一步分析句子的语义结构。 129 | 130 | ### 2. 任务说明 131 | 132 | 论元范围识别任务是指,在给定包含目标词的句子中,识别出目标词所支配的语义角色的边界。 133 | 134 | 1. 输入:句子相关信息(id和文本内容)及目标词。 135 | 2. 输出:句子id,及所识别出所有论元角色的范围,每组结果包含例句id:`task_id`, `span`起始位置, `span`结束位置,每句包含的论元数量不定,识别出多个论元需要添加多个元组,所有例句识别出的结果共同放存在一个list中,样例如下: 136 | ```json 137 | [ 138 | [ 2611, 0, 2 ], 139 | [ 2611, 4, 17], 140 | ... 141 | [ 2612, 5, 7], 142 | ... 143 | ] 144 | ``` 145 | 146 | ### 3. 评测指标 147 | 148 | 论元范围识别采用P、R、F1作为评价指标: 149 | 150 | $${\rm{precision}} = \frac{{{\rm{InterSec(gold,pred)}}}}{{{\rm{Len(pred)}}}}$$ 151 | 152 | $${\rm{recall}} = \frac{{{\rm{InterSec(gold,pred)}}}}{{{\rm{Len(gold)}}}}$$ 153 | 154 | $${\rm{task2\\_f1}} = \frac{{{\rm{2\*precision\*recall}}}}{{{\rm{precision}} + {\rm{recall}}}}$$ 155 | 156 | 其中:gold 和 pred 分别表示真实结果与预测结果,InterSec(\*)表示计算二者共有的token数量, Len(\*)表示计算token数量。 157 | 158 | ## 子任务3: 论元角色识别(Role Identification) 159 | 160 | ### 1. 任务描述 161 | 162 | 框架语义解析任务中,论元角色识别任务是非常重要的一部分。该任务旨在确定句子中每个论元对应的框架元素,即每个论元在所属框架中的语义角色。例如,在“我昨天买了一本书”这个句子中,“我”是“商业购买”框架中的“买方”框架元素,“一本书”是“商品”框架元素。论元角色识别任务对于许多自然语言处理任务都是至关重要的,例如信息提取、关系抽取和机器翻译等。它可以帮助计算机更好地理解句子的含义,从而更准确地提取句子中的信息,进而帮助人们更好地理解文本。 163 | 164 | ### 2. 任务说明 165 | 166 | 论元角色识别任务是指,在给定包含目标词的句子中,识别出目标词所支配语义角色的角色名称,该任务需要论元角色的边界信息以及目标词所激活框架的信息(即子任务1和子任务2的结果)。 167 | 框架及其框架元素的所属关系在`frame_info.json`文件中。 168 | 169 | 170 | 1. 输入:句子相关信息(id和文本内容)、目标词、框架信息以及论元角色范围。 171 | 2. 输出:句子id,及论元角色识别的结果,示例中“实体集”和“施动者”是“等同”框架中的框架元素。注意所有例句识别出的结果应共同放存在一个list中,样例如下: 172 | 173 | ```json 174 | [ 175 | [ 2611, 0, 2, "实体集" ], 176 | [ 2611, 4, 17, "施动者" ], 177 | ... 178 | [ 2612, 5, 7, "时间" ], 179 | ... 180 | ] 181 | ``` 182 | 183 | ### 3. 评测指标 184 | 论元角色识别采用P、R、F1作为评价指标: 185 | $${\rm{precision}} = \frac{{{\rm{Count(gold \cap pred)}}}} {{{\rm{Count(pred)}}}}$$ 186 | 187 | $${\rm{recall}} = \frac{{{\rm{Count(gold \cap pred)}}}} {{{\rm{Count(gold)}}}}$$ 188 | 189 | $${\rm{task3\\_f1}} = \frac{{{\rm{2\*precision\*recall}}}}{{{\rm{precision}} + {\rm{recall}}}}$$ 190 | 191 | 其中,gold 和 pred 分别表示真实结果与预测结果, Count(\*) 表示计算集合元素的数量。 192 | 193 | 194 | 195 | # 结果提交 196 | 本次评测结果在阿里天池平台上进行提交和排名。参赛队伍需要在阿里天池平台的“提交结果”界面提交预测结果,提交的压缩包命名为submit.zip,其中包含三个子任务的预测文件。 197 | 198 | + submit.zip 199 | + A_task1_test.json 200 | + A_task2_test.json 201 | + A_task3_test.json 202 | 203 | 1. 三个任务的提交结果需严格命名为A_task1_test.json、A_task2_test.json和A_task3_test.json。 2. 请严格使用`zip submit.zip A_task1_test.json A_task2_test.json A_task3_test.json` 进行压缩,即要求解压后的文件不能存在中间目录。 选⼿可以只提交部分任务的结果,如只提交“框架识别”任务:`zip submit.zip A_task1_test.json`,未预测任务的分数默认为0。 204 | 205 | # 系统排名 206 | 207 | 1. 所有评测任务均采用百分制分数显示,小数点后保留2位。 208 | 2. 系统排名取各项任务得分的加权和(三个子任务权重依次为 0.3,0.3,0.4),即: 209 | ${\rm{task\_score=0.3*task1\_acc+0.3*task2\_f1+0.4*task3\_f1}} $ 210 | 3. 
如果某项任务未提交,默认分数为0,仍参与到系统最终得分的计算。 211 | 212 | # Baseline 213 | Baseline下载地址:[Github](https://github.com/SXUNLP/Chinese-Frame-Semantic-Parsing) 214 | Baseline表现: 215 | |task1_acc|task2_f1|task3_f1|task_score| 216 | |---------|--------|--------|----------| 217 | |65.1|87.55|54.07|67.42| 218 | 219 | 220 | 221 | # 评测数据 222 | 223 | 数据由json格式给出,数据集包含以下内容: 224 | 225 | + CFN-train.json: 训练集标注数据,10000条。 226 | + CFN-dev.json: 验证集标注数据,2000条。 227 | + CFN-test-A.json: A榜测试集,4000条。 228 | + CFN-test-B.json: B榜测试集,4000条。B榜开赛前开放下载。 229 | + frame_info.json: 框架信息。 230 | + result.zip:提交示例。 231 | + A_task1_test.json:task1子任务提交示例。 232 | + A_task2_test.json:task2子任务提交示例。 233 | + A_task3_test.json:task3子任务提交示例。 234 | + README.md: 说明文件。 235 | 236 | 237 | 238 | # 数据集信息 239 | 240 | * 数据集提供方:山西大学智能计算与中文信息处理教育部重点实验室,山西太原 030000。 241 | * 负责人:谭红叶 tanhongye@sxu.edu.cn。 242 | * 联系人:闫智超 202022408073@email.sxu.edu.cn、李俊材 202122407024@email.sxu.edu.cn。 243 | 244 | 245 | # 赛程安排 246 | 247 | 本次大赛分为报名组队、A榜、B榜三个阶段,具体安排和要求如下: 248 | 1. 报名时间:4月1日-5月27日 249 | 2. 训练、验证数据及baseline发布:4月10日 250 | 3. 测试A榜数据发布:4月11日 251 | 4. 测试A榜评测截止:5月29日 17:59:59 252 | 5. 测试B榜数据发布:5月31日 253 | 6. 测试B榜最终测试结果:6月2日 17:59:59 254 | 7. 公布测试结果:6月10日前 255 | 8. 提交中文或英文技术报告:6月20日 256 | 9. 中文或英文技术报告反馈:6月28日 257 | 10. 正式提交中英文评测论文:7月3日 258 | 11. 公布获奖名单:7月7日 259 | 12. 评测报告及颁奖:8月3-5日 260 | 261 | 262 | **注意:报名组队与实名认证(2023年4月1日—5月27日)** 263 | 264 | # 赛事规则 265 | 266 | 1. 由于版权保护问题,CFN数据集只免费提供给用户用于非盈利性科学研究使用,参赛人员不得将数据用于任何商业用途。如果用于商业产品,请联系柴清华老师,联系邮箱 [charles@sxu.edu.cn](mailto:charles@sxu.edu.cn)。 267 | 2. 每名参赛选手只能参加一支队伍,一旦发现某选手以注册多个账号的方式参加多支队伍,将取消相关队伍的参赛资格。 268 | 3. 数据集的具体内容、范围、规模及格式以最终发布的真实数据集为准。验证集不可用于模型训练,针对测试集,参赛人员不允许执行任何人工标注。 269 | 4. 参赛队伍可在参赛期间随时上传测试集的预测结果,阿里天池平台A榜阶段每天可提交3次、B榜阶段每天可提交5次,系统会实时更新当前最新榜单排名情况,严禁参赛团队注册其它账号多次提交。 270 | 5. 允许使用公开的代码、工具、外部数据(从其他渠道获得的标注数据)等,但需要保证参赛结果可以复现。 271 | 6. 参赛队伍可以自行设计和调整模型,但需注意模型参数量最多不超过1.5倍BERT-Large(510M)。 272 | 7. 算法与系统的知识产权归参赛队伍所有。要求最终结果排名前10的队伍提供算法代码与系统报告(包括方法说明、数据处理、参考文献和使用的开源工具、外部数据等信息)。提交完毕将采用随机交叉检查的方法对各个队伍提交的模型进行检验,如果在排行榜上的结果无法复现,将取消获奖资格。 273 | 8. 参赛团队需保证提交作品的合规性,若出现下列或其他重大违规的情况,将取消参赛团队的参赛资格和成绩,获奖团队名单依次递补。重大违规情况如下: 274 | a. 使用小号、串通、剽窃他人代码等涉嫌违规、作弊行为; 275 | b. 团队提交的材料内容不完整,或提交任何虚假信息; 276 | c. 参赛团队无法就作品疑义进行足够信服的解释说明; 277 | 9. 获奖队伍必须注册会议并在线下参加(如遇特殊情况,可申请线上参加)。 278 | 10. 评测单位:山西大学、北京大学、南京大学。 279 | 11. 
评测负责人:谭红叶 tanhongye@sxu.edu.cn;联系人:闫智超 202022408073@email.sxu.edu.cn、李俊材 202122407024@email.sxu.edu.cn。 280 | 281 | 282 | # 报名方式 283 |   本次评测采用电子邮件进行报名,邮件标题为:“CCL2023-汉语框架语义解析评测-参赛单位”,例如:“CCL2023-汉语框架语义解析评测-山西大学”;附件内容为队伍的参赛报名表,报名表[点此下载](https://github.com/SXUNLP/Chinese-Frame-Semantic-Parsing/blob/main/%E6%B1%89%E8%AF%AD%E6%A1%86%E6%9E%B6%E8%AF%AD%E4%B9%89%E8%A7%A3%E6%9E%90%E8%AF%84%E6%B5%8B%E6%8A%A5%E5%90%8D%E8%A1%A8.docx),同时报名表应更名为“参赛队名+参赛队长信息+参赛 单位名称”。请参加评测的队伍发送邮件至202122407024@email.sxu.edu.cn,并同时在阿里天池平台完成报名,完成报名后需加入评测交流群:22240029459 284 | 285 | * 报名截止前未发送报名邮件者不参与后续的评选。 286 | * 大赛技术交流群: 请加钉钉群 22240029459 。 287 | 288 | # 评测网址 289 | 评测首页:[[CCL2023-Eval 汉语框架语义解析评测_算法大赛](https://tianchi.aliyun.com/competition/entrance/532083/introduction)](https://tianchi.aliyun.com/competition/entrance/532083/introduction) 290 | 数据集网址:[https://tianchi.aliyun.com/dataset/149079]( https://tianchi.aliyun.com/dataset/149079) 291 | GitHub:[https://github.com/SXUNLP/Chinese-Frame-Semantic-Parsing](https://github.com/SXUNLP/Chinese-Frame-Semantic-Parsing) 292 | 293 | # 奖项信息 294 | 本次评测将评选出如下奖项。 295 | 由中国中文信息学会计算语言学专委会(CIPS-CL)为获奖队伍提供荣誉证书。 296 | 297 | |奖项|一等奖|二等奖|三等奖| 298 | |----|----|----|----| 299 | |数量|1名|待定| 待定 | 300 | |奖励|荣誉证书|荣誉证书|荣誉证书| 301 | 302 | 303 | # 数据集协议 304 | 305 | 该数据集遵循协议: [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/?spm=5176.12282016.0.0.7a0a1517bGbbHL)。 306 | 307 | 由于版权保护问题,CFN数据集只免费提供给用户用于非盈利性科学研究使用,参赛人员不得将数据用于任何商业用途。如果用于商业产品,请联系柴清华老师,联系邮箱 charles@sxu.edu.cn。 308 | 309 | 310 | # FAQ 311 | 312 | * Q:比赛是否有技术交流群? 313 | * A:请加钉钉群 22240029459 。 314 | * Q:数据集解压密码是什么? 315 | * A:请阅读“如何报名”,发送邮件报名成功后接收解压邮件。 316 | * Q:验证集可否用于模型训练? 317 | * A:不可以。 318 | 319 | 320 | # 参考文献 321 | [1] Daniel Gildea and Daniel Jurafsky. 2002. Automatic labeling of semantic roles. Computational linguistics,28(3):245–288. 322 | [2] Shaoru Guo, Ru Li*, Hongye Tan, Xiaoli Li, Yong Guan. A Frame-based Sentence Representation for Machine Reading Comprehension[C]. Proceedings of the 58th Annual Meeting of the Association for Computational Linguistic (ACL), 2020: 891-896. 323 | [3] Shaoru Guo, Yong Guan, Ru Li*, Xiaoli Li, Hongye Tan. Incorporating Syntax and Frame Semantics in Neural Network for Machine Reading Comprehension[C]. Proceedings of the 28th International Conference on Computational Linguistics (COLING), 2020: 2635-2641. 324 | [4] Yong Guan, Shaoru Guo, Ru Li*, Xiaoli Li, and Hu Zhang. Integrating Semantic Scenario and Word Relations for Abstractive Sentence Summarization[C]. Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP) 2021: 2522-2529. 325 | [5] Yong Guan, Shaoru Guo, Ru Li*, Xiaoli Li, and Hongye Tan, 2021. Frame Semantic-Enhanced Sentence Modeling for Sentence-level Extractive Text Summarization[C]. Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP) 2021: 404-4052. 326 | [6] Hongyan Zhao, Ru Li*, Xiaoli Li, Hongye Tan. CFSRE: Context-aware based on frame-semantics for distantly supervised relation extraction[J]. Knowledge-Based Systems, 2020, 210: 106480. 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert_wwm-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert_wwm-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert_wwm-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert_wwm-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert_wwm-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert_wwm-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert_wwm-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert_wwm-base-uncased': 512, 41 | 'bert_wwm-large-uncased': 512, 42 | 'bert_wwm-base-cased': 512, 43 | 'bert_wwm-large-cased': 512, 44 | 'bert_wwm-base-multilingual-uncased': 512, 45 | 'bert_wwm-base-multilingual-cased': 512, 46 | 'bert_wwm-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 80 | if not os.path.isfile(vocab_file): 81 | raise ValueError( 82 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 84 | self.vocab = load_vocab(vocab_file) 85 | self.ids_to_tokens = collections.OrderedDict( 86 | [(ids, tok) for tok, ids in self.vocab.items()]) 87 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 88 | never_split=never_split) 89 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 90 | self.max_len = max_len if max_len is not None else int(1e12) 91 | 92 | def tokenize(self, text): 93 | split_tokens = [] 94 | for token in self.basic_tokenizer.tokenize(text): 95 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 96 | split_tokens.append(sub_token) 97 | return split_tokens 98 | 99 | def convert_tokens_to_ids(self, tokens): 100 | """Converts a sequence of tokens into ids using the vocab.""" 101 | ids = [] 102 | for token in tokens: 103 | ids.append(self.vocab[token]) 104 | if len(ids) > self.max_len: 105 | raise ValueError( 106 | "Token indices sequence length is longer than the specified maximum " 107 | " sequence length for this BERT model ({} > {}). Running this" 108 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 109 | ) 110 | return ids 111 | 112 | def convert_ids_to_tokens(self, ids): 113 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 114 | tokens = [] 115 | for i in ids: 116 | tokens.append(self.ids_to_tokens[i]) 117 | return tokens 118 | 119 | @classmethod 120 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 121 | """ 122 | Instantiate a PreTrainedBertModel from a pre-trained model file. 123 | Download and cache the pre-trained model file if needed. 124 | """ 125 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 126 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 127 | else: 128 | vocab_file = pretrained_model_name 129 | if os.path.isdir(vocab_file): 130 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 131 | # redirect to the cache, if necessary 132 | try: 133 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 134 | except FileNotFoundError: 135 | logger.error( 136 | "Model name '{}' was not found in model name list ({}). " 137 | "We assumed '{}' was a path or url but couldn't find any file " 138 | "associated to this path or url.".format( 139 | pretrained_model_name, 140 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 141 | vocab_file)) 142 | return None 143 | if resolved_vocab_file == vocab_file: 144 | logger.info("loading vocabulary file {}".format(vocab_file)) 145 | else: 146 | logger.info("loading vocabulary file {} from cache at {}".format( 147 | vocab_file, resolved_vocab_file)) 148 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 149 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 150 | # than the number of positional embeddings 151 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 152 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 153 | # Instantiate tokenizer. 
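        # `pretrained_model_name` is either a shortcut from PRETRAINED_VOCAB_ARCHIVE_MAP
        # (e.g. 'bert_wwm-base-chinese') or a path/URL to a vocab.txt (a directory is also
        # accepted and resolved to <dir>/vocab.txt above).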
154 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 155 | return tokenizer 156 | 157 | 158 | class BasicTokenizer(object): 159 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 160 | 161 | def __init__(self, 162 | do_lower_case=True, 163 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 164 | """Constructs a BasicTokenizer. 165 | 166 | Args: 167 | do_lower_case: Whether to lower case the input. 168 | """ 169 | self.do_lower_case = do_lower_case 170 | self.never_split = never_split 171 | 172 | def tokenize(self, text): 173 | """Tokenizes a piece of text.""" 174 | text = self._clean_text(text) 175 | # This was added on November 1st, 2018 for the multilingual and Chinese 176 | # models. This is also applied to the English models now, but it doesn't 177 | # matter since the English models were not trained on any Chinese data 178 | # and generally don't have any Chinese data in them (there are Chinese 179 | # characters in the vocabulary because Wikipedia does have some Chinese 180 | # words in the English Wikipedia.). 181 | text = self._tokenize_chinese_chars(text) 182 | orig_tokens = whitespace_tokenize(text) 183 | split_tokens = [] 184 | for token in orig_tokens: 185 | if self.do_lower_case and token not in self.never_split: 186 | token = token.lower() 187 | token = self._run_strip_accents(token) 188 | split_tokens.extend(self._run_split_on_punc(token)) 189 | 190 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 191 | return output_tokens 192 | 193 | def _run_strip_accents(self, text): 194 | """Strips accents from a piece of text.""" 195 | text = unicodedata.normalize("NFD", text) 196 | output = [] 197 | for char in text: 198 | cat = unicodedata.category(char) 199 | if cat == "Mn": 200 | continue 201 | output.append(char) 202 | return "".join(output) 203 | 204 | def _run_split_on_punc(self, text): 205 | """Splits punctuation on a piece of text.""" 206 | if text in self.never_split: 207 | return [text] 208 | chars = list(text) 209 | i = 0 210 | start_new_word = True 211 | output = [] 212 | while i < len(chars): 213 | char = chars[i] 214 | if _is_punctuation(char): 215 | output.append([char]) 216 | start_new_word = True 217 | else: 218 | if start_new_word: 219 | output.append([]) 220 | start_new_word = False 221 | output[-1].append(char) 222 | i += 1 223 | 224 | return ["".join(x) for x in output] 225 | 226 | def _tokenize_chinese_chars(self, text): 227 | """Adds whitespace around any CJK character.""" 228 | output = [] 229 | for char in text: 230 | cp = ord(char) 231 | if self._is_chinese_char(cp): 232 | output.append(" ") 233 | output.append(char) 234 | output.append(" ") 235 | else: 236 | output.append(char) 237 | return "".join(output) 238 | 239 | def _is_chinese_char(self, cp): 240 | """Checks whether CP is the codepoint of a CJK character.""" 241 | # This defines a "chinese character" as anything in the CJK Unicode block: 242 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 243 | # 244 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 245 | # despite its name. The modern Korean Hangul alphabet is a different block, 246 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 247 | # space-separated words, so they are not treated specially and handled 248 | # like the all of the other languages. 
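        # The ranges below are: CJK Unified Ideographs (4E00-9FFF), Extension A (3400-4DBF),
        # Extensions B through E, CJK Compatibility Ideographs (F900-FAFF) and the
        # CJK Compatibility Ideographs Supplement (2F800-2FA1F).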
249 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 250 | (cp >= 0x3400 and cp <= 0x4DBF) or # 251 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 252 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 253 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 254 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 255 | (cp >= 0xF900 and cp <= 0xFAFF) or # 256 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 257 | return True 258 | 259 | return False 260 | 261 | def _clean_text(self, text): 262 | """Performs invalid character removal and whitespace cleanup on text.""" 263 | output = [] 264 | for char in text: 265 | cp = ord(char) 266 | if cp == 0 or cp == 0xfffd or _is_control(char): 267 | continue 268 | if _is_whitespace(char): 269 | output.append(" ") 270 | else: 271 | output.append(char) 272 | return "".join(output) 273 | 274 | 275 | class WordpieceTokenizer(object): 276 | """Runs WordPiece tokenization.""" 277 | 278 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 279 | self.vocab = vocab 280 | self.unk_token = unk_token 281 | self.max_input_chars_per_word = max_input_chars_per_word 282 | 283 | def tokenize(self, text): 284 | """Tokenizes a piece of text into its word pieces. 285 | 286 | This uses a greedy longest-match-first algorithm to perform tokenization 287 | using the given vocabulary. 288 | 289 | For example: 290 | input = "unaffable" 291 | output = ["un", "##aff", "##able"] 292 | 293 | Args: 294 | text: A single token or whitespace separated tokens. This should have 295 | already been passed through `BasicTokenizer`. 296 | 297 | Returns: 298 | A list of wordpiece tokens. 299 | """ 300 | 301 | output_tokens = [] 302 | for token in whitespace_tokenize(text): 303 | chars = list(token) 304 | if len(chars) > self.max_input_chars_per_word: 305 | output_tokens.append(self.unk_token) 306 | continue 307 | 308 | is_bad = False 309 | start = 0 310 | sub_tokens = [] 311 | while start < len(chars): 312 | end = len(chars) 313 | cur_substr = None 314 | while start < end: 315 | substr = "".join(chars[start:end]) 316 | if start > 0: 317 | substr = "##" + substr 318 | if substr in self.vocab: 319 | cur_substr = substr 320 | break 321 | end -= 1 322 | if cur_substr is None: 323 | is_bad = True 324 | break 325 | sub_tokens.append(cur_substr) 326 | start = end 327 | 328 | if is_bad: 329 | output_tokens.append(self.unk_token) 330 | else: 331 | output_tokens.extend(sub_tokens) 332 | return output_tokens 333 | 334 | 335 | def _is_whitespace(char): 336 | """Checks whether `chars` is a whitespace character.""" 337 | # \t, \n, and \r are technically contorl characters but we treat them 338 | # as whitespace since they are generally considered as such. 339 | if char == " " or char == "\t" or char == "\n" or char == "\r": 340 | return True 341 | cat = unicodedata.category(char) 342 | if cat == "Zs": 343 | return True 344 | return False 345 | 346 | 347 | def _is_control(char): 348 | """Checks whether `chars` is a control character.""" 349 | # These are technically control characters but we count them as whitespace 350 | # characters. 351 | if char == "\t" or char == "\n" or char == "\r": 352 | return False 353 | cat = unicodedata.category(char) 354 | if cat.startswith("C"): 355 | return True 356 | return False 357 | 358 | 359 | def _is_punctuation(char): 360 | """Checks whether `chars` is a punctuation character.""" 361 | cp = ord(char) 362 | # We treat all non-letter/number ASCII as punctuation. 
363 | # Characters such as "^", "$", and "`" are not in the Unicode 364 | # Punctuation class but we treat them as punctuation anyways, for 365 | # consistency. 366 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 367 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 368 | return True 369 | cat = unicodedata.category(char) 370 | if cat.startswith("P"): 371 | return True 372 | return False 373 | -------------------------------------------------------------------------------- /run_rc.py: -------------------------------------------------------------------------------- 1 | from accelerate import Accelerator 2 | from accelerate.utils import set_seed 3 | from arguments import parse_args 4 | from dataset import FrameRCDataset 5 | import logging 6 | import math 7 | from model import BertForFrameSRL 8 | from transformers import DataCollatorWithPadding 9 | import os 10 | # from predict import post_process_function_greedy, calculate_F1_metric, post_process_function_with_max_len, save_predictions 11 | import torch 12 | from torch.optim import AdamW 13 | from torch.utils.data import DataLoader 14 | # from tqdm import tqdm 15 | import transformers 16 | from transformers import BertConfig, BertTokenizerFast, get_scheduler 17 | import json 18 | import evaluate 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | def Predict(args, accelerator, model, eval_dataset, eval_dataloader, fe2id): 24 | id2fe = {v: k for k, v in fe2id.items()} 25 | # Evaluation 26 | logger.info("***** Running Prediction *****") 27 | logger.info(f" Num examples = {len(eval_dataset)}") 28 | logger.info(f" Batch size = {args.per_device_eval_batch_size}") 29 | 30 | model.eval() 31 | 32 | all_predictions = [] 33 | all_labels = [] 34 | all_spans = [] 35 | for step, batch in enumerate(eval_dataloader): 36 | with torch.no_grad(): 37 | length = batch.pop('length') 38 | # word_ids = batch.pop('word_ids') 39 | task_id = batch.pop('task_id') 40 | labels = batch.pop('labels') 41 | outputs = model(**batch) 42 | logits = outputs.logits 43 | task_id = task_id.cpu().numpy().tolist() 44 | # word_ids = word_ids.cpu().numpy().tolist() 45 | predictions = torch.argmax(logits, dim=-2).cpu().numpy().tolist() # (B, 16) 46 | span_token_idx = batch['span_token_idx'].cpu().numpy().tolist() # (B, 16, 2) 47 | # print(task_id) 48 | # print(span_token_idx) 49 | # print(predictions) 50 | # labels = batch['labels'].cpu().numpy().tolist() 51 | # true_labels = [[label_names[l] for l in label if l != -100] for label in labels] 52 | true_predictions = [ 53 | [id2fe[p] for (p, span) in zip(prediction, span_token) if span[0] != 0] 54 | for prediction, span_token in zip(predictions, span_token_idx) 55 | ] 56 | # all_spans = [] 57 | for tid, pred in zip(task_id, true_predictions): 58 | tid = tid[0] 59 | spans = [] 60 | for p in pred: 61 | spans.append([tid, p]) 62 | all_spans += spans 63 | 64 | # precision = .0 65 | # recall = .0 66 | # F1 = (2 * precision * recall) / (precision + recall + 1e-12) 67 | # all_metrics = metric.compute(predictions=all_predictions, references=all_labels) 68 | # precision = all_metrics['overall_precision'] 69 | # recall = all_metrics['overall_recall'] 70 | # F1 = all_metrics['overall_f1'] 71 | 72 | with open('./ccl-cfn/result/task2_test.json', 'r') as f: 73 | all_spans_no_label = json.load(f) 74 | for s, s_ in zip(all_spans, all_spans_no_label): 75 | try: 76 | assert s[0] == s_[0] 77 | except: 78 | print(s, s_) 79 | exit(0) 80 | s_.append(s[1]) 81 | with open('./ccl-cfn/result/task3_test.json', 'w') as f: 82 | 
json.dump(all_spans_no_label, f, ensure_ascii=False) 83 | 84 | return 85 | 86 | def Evaluate(args, accelerator, model, eval_dataset, eval_dataloader, metric): 87 | # label_names = ['O', 'B', 'I'] 88 | # Evaluation 89 | logger.info("***** Running Evaluation *****") 90 | logger.info(f" Num examples = {len(eval_dataset)}") 91 | logger.info(f" Batch size = {args.per_device_eval_batch_size}") 92 | 93 | model.eval() 94 | 95 | all_predictions = [] 96 | all_labels = [] 97 | for step, batch in enumerate(eval_dataloader): 98 | with torch.no_grad(): 99 | length = batch.pop('length') 100 | # word_ids = batch.pop('word_ids') 101 | task_id = batch.pop('task_id') 102 | outputs = model(**batch) 103 | logits = outputs.logits 104 | predictions = torch.argmax(logits, dim=-2).cpu().numpy().tolist() 105 | labels = batch['labels'].cpu().numpy().tolist() 106 | true_labels = [[l for l in label if l != -100] for label in labels] 107 | true_predictions = [ 108 | [p for (p, l) in zip(prediction, label) if l != -100] 109 | for prediction, label in zip(predictions, labels) 110 | ] 111 | for tp in true_predictions: 112 | all_predictions += tp 113 | # all_predictions += true_predictions 114 | for tl in true_labels: 115 | all_labels += tl 116 | # all_labels += true_labels 117 | 118 | # precision = .0 119 | # recall = .0 120 | # F1 = (2 * precision * recall) / (precision + recall + 1e-12) 121 | all_metrics = metric.compute(predictions=all_predictions, references=all_labels) 122 | acc = all_metrics['accuracy'] 123 | # recall = all_metrics['overall_recall'] 124 | # F1 = all_metrics['overall_f1'] 125 | 126 | return acc 127 | 128 | 129 | 130 | 131 | def train(args, accelerator, model, train_dataset, train_dataloader, optimizer, lr_scheduler, eval_dataset, eval_dataloader, tokenizer): 132 | # Figure out how many steps we should save the Accelerator states 133 | if hasattr(args.checkpointing_steps, "isdigit"): 134 | checkpointing_steps = args.checkpointing_steps 135 | if args.checkpointing_steps.isdigit(): 136 | checkpointing_steps = int(args.checkpointing_steps) 137 | else: 138 | checkpointing_steps = None 139 | 140 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 141 | logger.info("***** Running training *****") 142 | logger.info(f" Num examples = {len(train_dataset)}") 143 | logger.info(f" Num Epochs = {args.num_train_epochs}") 144 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 145 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 146 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 147 | logger.info(f" Total optimization steps = {args.max_train_steps}") 148 | 149 | # Only show the progress bar once on each machine. 
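    # (the tqdm bar below is left commented out; progress is tracked via `completed_steps`
    # and the periodic batch-loss logging gated by `args.log_every_step` instead)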
150 | # progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 151 | completed_steps = 0 152 | starting_epoch = 0 153 | 154 | # Potentially load in the weights and states from a previous save 155 | if args.resume_from_checkpoint: 156 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 157 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 158 | accelerator.load_state(args.resume_from_checkpoint) 159 | path = os.path.basename(args.resume_from_checkpoint) 160 | else: 161 | # Get the most recent checkpoint 162 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 163 | dirs.sort(key=os.path.getctime) 164 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 165 | # Extract `epoch_{i}` or `step_{i}` 166 | training_difference = os.path.splitext(path)[0] 167 | 168 | if "epoch" in training_difference: 169 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 170 | resume_step = None 171 | else: 172 | resume_step = int(training_difference.replace("step_", "")) 173 | starting_epoch = resume_step // len(train_dataloader) 174 | resume_step -= starting_epoch * len(train_dataloader) 175 | 176 | if args.save_best: 177 | best_acc = -1 178 | 179 | metric = evaluate.load('accuracy') 180 | for epoch in range(starting_epoch, args.num_train_epochs): 181 | model.train() 182 | total_loss = 0 183 | for step, batch in enumerate(train_dataloader): 184 | # We need to skip steps until we reach the resumed step 185 | if args.resume_from_checkpoint and epoch == starting_epoch: 186 | if resume_step is not None and step < resume_step: 187 | completed_steps += 1 188 | continue 189 | length = batch.pop('length') 190 | # if not args.loss_on_context: 191 | # context_length = batch.pop('context_length') 192 | # word_ids = batch.pop('word_ids') 193 | # FE_num = batch.pop('FE_num') 194 | task_id = batch.pop('task_id') 195 | # gt_FE_word_idx = batch.pop('gt_FE_word_idx') 196 | # gt_start_positions = batch.pop('gt_start_positions') 197 | # gt_end_positions = batch.pop('gt_end_positions') 198 | # FE_core_pts = batch.pop('FE_core_pts') 199 | try: 200 | outputs = model(**batch) 201 | except: 202 | for k, v in batch.items(): 203 | print(k, v.shape, v) 204 | loss = outputs.loss 205 | # We keep track of the loss at each epoch 206 | total_loss += loss.detach().float() 207 | if args.log_every_step is not None and step % args.log_every_step == 0: 208 | logger.info(f" | batch loss: {loss.detach().float():.6f} step = {step}") 209 | # if args.with_tracking: 210 | # total_loss += loss.detach().float() 211 | # if args.log_every_step is not None and step % args.log_every_step == 0: 212 | # logger.info(f" | batch loss: {loss.detach().float():.6f} step = {step}") 213 | loss = loss / args.gradient_accumulation_steps 214 | accelerator.backward(loss) 215 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 216 | optimizer.step() 217 | lr_scheduler.step() 218 | optimizer.zero_grad() 219 | # progress_bar.update(1) 220 | completed_steps += 1 221 | 222 | 223 | if isinstance(checkpointing_steps, int): 224 | if completed_steps % checkpointing_steps == 0: 225 | output_dir = f"step_{completed_steps }" 226 | if args.output_dir is not None: 227 | output_dir = os.path.join(args.output_dir, output_dir) 228 | accelerator.save_state(output_dir) 229 | 230 | if completed_steps >= args.max_train_steps: 231 | break 232 | 233 | # if args.with_tracking: 234 | # logger.info(f" 
Epoch Loss {total_loss:.6f}") 235 | logger.info(f" Epoch Loss {total_loss:.6f}") 236 | 237 | acc = Evaluate(args, accelerator, model, eval_dataset, eval_dataloader, metric) 238 | logger.info(f" Accuracy {acc:.6f}") 239 | # logger.info(f" TP: {total_TP} FP: {total_FP} FN: {total_FN}") 240 | 241 | 242 | if args.with_tracking: 243 | log = { 244 | "train_loss": total_loss, 245 | "step": completed_steps, 246 | "acc": acc, 247 | } 248 | accelerator.log(log) 249 | 250 | if args.checkpointing_steps == "epoch": 251 | output_dir = f"epoch_{epoch}" 252 | if args.output_dir is not None: 253 | output_dir = os.path.join(args.output_dir, output_dir) 254 | accelerator.save_state(output_dir) 255 | 256 | if args.output_dir is not None: 257 | # print(f'best {best_F1} current {F1}') 258 | if args.save_best: 259 | if best_acc <= acc: 260 | best_Facc = acc 261 | unwrapped_model = accelerator.unwrap_model(model) 262 | unwrapped_model.save_pretrained( 263 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 264 | ) 265 | if accelerator.is_main_process: 266 | tokenizer.save_pretrained(args.output_dir) 267 | else: 268 | unwrapped_model = accelerator.unwrap_model(model) 269 | unwrapped_model.save_pretrained( 270 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 271 | ) 272 | if accelerator.is_main_process: 273 | tokenizer.save_pretrained(args.output_dir) 274 | 275 | 276 | 277 | def main(): 278 | args = parse_args() 279 | 280 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 281 | # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment 282 | accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator() 283 | # Make one log on every process with the configuration for debugging. 284 | logging.basicConfig( 285 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 286 | datefmt="%m/%d/%Y %H:%M:%S", 287 | level=logging.INFO, 288 | ) 289 | logger.info(accelerator.state) 290 | 291 | # Setup logging, we only want one process per machine to log things on the screen. 292 | # accelerator.is_local_main_process is only True for one process per machine. 293 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) 294 | if accelerator.is_local_main_process: 295 | # datasets.utils.logging.set_verbosity_warning() 296 | transformers.utils.logging.set_verbosity_info() 297 | else: 298 | # datasets.utils.logging.set_verbosity_error() 299 | transformers.utils.logging.set_verbosity_error() 300 | 301 | # If passed along, set the training seed now. 
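    # accelerate's set_seed seeds Python's `random`, NumPy and PyTorch (CPU and CUDA),
    # so training runs are reproducible given the same arguments.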
302 | if args.seed is not None: 303 | set_seed(args.seed) 304 | 305 | if accelerator.is_main_process and args.output_dir is not None and not os.path.exists(args.output_dir): 306 | os.mkdir(args.output_dir) 307 | accelerator.wait_for_everyone() 308 | 309 | data_files = {} 310 | if args.train_file is not None: 311 | data_files["train"] = args.train_file 312 | if args.validation_file is not None: 313 | data_files["validation"] = args.validation_file 314 | if args.test_file is not None: 315 | data_files["test"] = args.test_file 316 | 317 | if args.config_name: 318 | config = BertConfig.from_pretrained(args.config_name) 319 | elif args.model_name_or_path: 320 | config = BertConfig.from_pretrained(args.model_name_or_path) 321 | config.update({'FE_pooling':args.FE_pooling}) 322 | 323 | if args.tokenizer_name: 324 | tokenizer = BertTokenizerFast.from_pretrained(args.tokenizer_name, use_fast=True) 325 | elif args.model_name_or_path: 326 | tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path, use_fast=True) 327 | tokenizer.add_tokens(['', '', '', '', '', '']) 328 | with open(args.frame_data, 'r') as f: 329 | data = json.load(f) 330 | fe2id = {} 331 | cnt = 0 332 | for frame in data: 333 | for fe in frame['fes']: 334 | name = fe['fe_name'] 335 | if name not in fe2id: 336 | fe2id[name] = cnt 337 | cnt += 1 338 | config.num_labels = len(fe2id) 339 | 340 | if args.model_name_or_path: 341 | model = BertForFrameSRL.from_pretrained( 342 | args.model_name_or_path, 343 | from_tf=bool(".ckpt" in args.model_name_or_path), 344 | config=config, 345 | ) 346 | else: 347 | logger.info("Training new model from scratch") 348 | model = BertForFrameSRL.from_config(config) 349 | model.resize_token_embeddings(len(tokenizer)) 350 | 351 | # frame_data = {} 352 | # with open('./ccl-cfn/frame_data_def.json', 'r') as f: 353 | # frame_lines = json.load(f) 354 | # for line in frame_lines: 355 | # frame_data[line["frame_name"]] = line 356 | 357 | if "train" not in data_files: 358 | raise ValueError("--do_train requires a train dataset") 359 | with accelerator.main_process_first(): 360 | train_dataset = FrameRCDataset(data_files['train'], tokenizer, fe2id) 361 | if args.max_train_samples is not None: 362 | train_dataset = train_dataset.subset(range(args.max_train_samples)) 363 | 364 | if "validation" not in data_files: 365 | raise ValueError("--do_train requires a train dataset") 366 | with accelerator.main_process_first(): 367 | eval_dataset = FrameRCDataset(data_files['validation'], tokenizer, fe2id) 368 | if args.max_eval_samples is not None: 369 | eval_dataset = eval_dataset.subset(range(args.max_eval_samples)) 370 | 371 | if args.do_predict: 372 | test_dataset = FrameRCDataset(data_files['test'], tokenizer, fe2id, args.task1_res, args.task2_res) 373 | if args.max_predict_samples is not None: 374 | test_dataset = test_dataset.subset(range(args.max_predict_samples)) 375 | 376 | # data_collator = DataCollatorForFrameAI(tokenizer=tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 377 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 378 | 379 | train_dataloader = DataLoader( 380 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 381 | ) 382 | 383 | eval_dataloader = DataLoader( 384 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 385 | ) 386 | 387 | test_dataloader = DataLoader( 388 | test_dataset, collate_fn=data_collator, 
batch_size=args.per_device_eval_batch_size 389 | ) 390 | # Optimizer 391 | # Split weights in two groups, one with weight decay and the other not. 392 | no_decay = ["bias", "LayerNorm.weight"] 393 | optimizer_grouped_parameters = [ 394 | { 395 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 396 | "weight_decay": args.weight_decay, 397 | }, 398 | { 399 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 400 | "weight_decay": 0.0, 401 | }, 402 | ] 403 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 404 | 405 | # Scheduler and math around the number of training steps. 406 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 407 | if args.max_train_steps is None: 408 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 409 | else: 410 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 411 | 412 | lr_scheduler = get_scheduler( 413 | name=args.lr_scheduler_type, 414 | optimizer=optimizer, 415 | num_warmup_steps=args.num_warmup_steps, 416 | num_training_steps=args.max_train_steps, 417 | ) 418 | 419 | # Prepare everything with our `accelerator`. 420 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 421 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 422 | ) 423 | 424 | 425 | # We need to initialize the trackers we use, and also store our configuration 426 | if args.with_tracking: 427 | experiment_config = vars(args) 428 | # TensorBoard cannot log Enums, need the raw value 429 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 430 | accelerator.init_trackers("FrameSRL", experiment_config) 431 | # evaluate(args, accelerator, model, eval_dataset, eval_dataloader) 432 | train(args, accelerator, model, train_dataset, train_dataloader, optimizer, lr_scheduler, eval_dataset, eval_dataloader, tokenizer) 433 | 434 | if args.do_predict: 435 | model = BertForFrameSRL.from_pretrained( 436 | args.output_dir, 437 | from_tf=bool(".ckpt" in args.model_name_or_path), 438 | config=config, 439 | ) 440 | model, test_dataloader = accelerator.prepare(model, test_dataloader) 441 | Predict(args, accelerator, model, test_dataset, test_dataloader, fe2id) 442 | 443 | if __name__ == '__main__': 444 | main() -------------------------------------------------------------------------------- /run_ai.py: -------------------------------------------------------------------------------- 1 | from accelerate import Accelerator 2 | from accelerate.utils import set_seed 3 | from arguments import parse_args 4 | from dataset import FrameAIDataset #, DataCollatorForFrameAI 5 | import logging 6 | import math 7 | # from model import BertForFrameSRL 8 | from transformers import BertForTokenClassification, DataCollatorForTokenClassification 9 | import os 10 | # from predict import post_process_function_greedy, calculate_F1_metric, post_process_function_with_max_len, save_predictions 11 | import torch 12 | from torch.optim import AdamW 13 | from torch.utils.data import DataLoader 14 | # from tqdm import tqdm 15 | import transformers 16 | from transformers import BertConfig, BertTokenizerFast, get_scheduler 17 | import json 18 | import evaluate 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | def Predict(args, accelerator, model, eval_dataset, eval_dataloader): 24 | label_names = ['O', 'B', 'I'] 25 | # Evaluation 26 | 
logger.info("***** Running Prediction *****") 27 | logger.info(f" Num examples = {len(eval_dataset)}") 28 | logger.info(f" Batch size = {args.per_device_eval_batch_size}") 29 | 30 | model.eval() 31 | 32 | all_predictions = [] 33 | all_labels = [] 34 | all_spans = [] 35 | for step, batch in enumerate(eval_dataloader): 36 | with torch.no_grad(): 37 | length = batch.pop('length') 38 | word_ids = batch.pop('word_ids') 39 | task_id = batch.pop('task_id') 40 | outputs = model(**batch) 41 | logits = outputs.logits 42 | task_id = task_id.cpu().numpy().tolist() 43 | word_ids = word_ids.cpu().numpy().tolist() 44 | predictions = torch.argmax(logits, dim=-1).cpu().numpy().tolist() 45 | labels = batch['labels'].cpu().numpy().tolist() 46 | # true_labels = [[label_names[l] for l in label if l != -100] for label in labels] 47 | true_predictions = [ 48 | [label_names[p] for (p, l) in zip(prediction, label) if l != -100] 49 | for prediction, label in zip(predictions, labels) 50 | ] 51 | # all_spans = [] 52 | for tid, word_id, pred in zip(task_id, word_ids, true_predictions): 53 | tid = tid[0] 54 | word_id = word_id[1:len(pred)+1] 55 | tags = ['O'] * (word_id[-1] + 1) 56 | for i, p in enumerate(pred): 57 | if p == 'I': 58 | tags[word_id[i]] = 'I' 59 | if p == 'B': 60 | tags[word_id[i]] = 'B' 61 | # print(tags) 62 | spans = [] 63 | span_start = -1 64 | for i, tag in enumerate(tags): 65 | if tag == 'B' and span_start == -1: 66 | span_start = i 67 | elif tag == 'B' and span_start != -1: 68 | spans.append((tid, span_start, i-1)) 69 | span_start = i 70 | elif tag == 'O' and span_start != -1: 71 | spans.append((tid, span_start, i-1)) 72 | span_start = -1 73 | elif i == len(tags) - 1 and span_start != -1: 74 | spans.append((tid, span_start, i)) 75 | span_start = -1 76 | # print(spans) 77 | # exit(0) 78 | all_spans.extend(spans) 79 | 80 | # precision = .0 81 | # recall = .0 82 | # F1 = (2 * precision * recall) / (precision + recall + 1e-12) 83 | # all_metrics = metric.compute(predictions=all_predictions, references=all_labels) 84 | # precision = all_metrics['overall_precision'] 85 | # recall = all_metrics['overall_recall'] 86 | # F1 = all_metrics['overall_f1'] 87 | 88 | with open('./ccl-cfn/result/task2_test.json', 'w') as f: 89 | json.dump(all_spans, f) 90 | 91 | return 92 | 93 | def Evaluate(args, accelerator, model, eval_dataset, eval_dataloader, metric): 94 | label_names = ['O', 'B', 'I'] 95 | # Evaluation 96 | logger.info("***** Running Evaluation *****") 97 | logger.info(f" Num examples = {len(eval_dataset)}") 98 | logger.info(f" Batch size = {args.per_device_eval_batch_size}") 99 | 100 | model.eval() 101 | 102 | all_predictions = [] 103 | all_labels = [] 104 | for step, batch in enumerate(eval_dataloader): 105 | with torch.no_grad(): 106 | length = batch.pop('length') 107 | word_ids = batch.pop('word_ids') 108 | task_id = batch.pop('task_id') 109 | outputs = model(**batch) 110 | logits = outputs.logits 111 | predictions = torch.argmax(logits, dim=-1).cpu().numpy().tolist() 112 | labels = batch['labels'].cpu().numpy().tolist() 113 | true_labels = [[label_names[l] for l in label if l != -100] for label in labels] 114 | true_predictions = [ 115 | [label_names[p] for (p, l) in zip(prediction, label) if l != -100] 116 | for prediction, label in zip(predictions, labels) 117 | ] 118 | all_predictions.extend(true_predictions) 119 | all_labels.extend(true_labels) 120 | 121 | precision = .0 122 | recall = .0 123 | F1 = (2 * precision * recall) / (precision + recall + 1e-12) 124 | all_metrics = 
metric.compute(predictions=all_predictions, references=all_labels) 125 | precision = all_metrics['overall_precision'] 126 | recall = all_metrics['overall_recall'] 127 | F1 = all_metrics['overall_f1'] 128 | 129 | return precision, recall, F1 130 | 131 | 132 | 133 | 134 | def train(args, accelerator, model, train_dataset, train_dataloader, optimizer, lr_scheduler, eval_dataset, eval_dataloader, tokenizer): 135 | # Figure out how many steps we should save the Accelerator states 136 | if hasattr(args.checkpointing_steps, "isdigit"): 137 | checkpointing_steps = args.checkpointing_steps 138 | if args.checkpointing_steps.isdigit(): 139 | checkpointing_steps = int(args.checkpointing_steps) 140 | else: 141 | checkpointing_steps = None 142 | 143 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 144 | logger.info("***** Running training *****") 145 | logger.info(f" Num examples = {len(train_dataset)}") 146 | logger.info(f" Num Epochs = {args.num_train_epochs}") 147 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 148 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 149 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 150 | logger.info(f" Total optimization steps = {args.max_train_steps}") 151 | 152 | # Only show the progress bar once on each machine. 153 | # progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 154 | completed_steps = 0 155 | starting_epoch = 0 156 | 157 | # Potentially load in the weights and states from a previous save 158 | if args.resume_from_checkpoint: 159 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 160 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 161 | accelerator.load_state(args.resume_from_checkpoint) 162 | path = os.path.basename(args.resume_from_checkpoint) 163 | else: 164 | # Get the most recent checkpoint 165 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 166 | dirs.sort(key=os.path.getctime) 167 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 168 | # Extract `epoch_{i}` or `step_{i}` 169 | training_difference = os.path.splitext(path)[0] 170 | 171 | if "epoch" in training_difference: 172 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 173 | resume_step = None 174 | else: 175 | resume_step = int(training_difference.replace("step_", "")) 176 | starting_epoch = resume_step // len(train_dataloader) 177 | resume_step -= starting_epoch * len(train_dataloader) 178 | 179 | if args.save_best: 180 | best_F1 = -1 181 | 182 | metric = evaluate.load('seqeval') 183 | for epoch in range(starting_epoch, args.num_train_epochs): 184 | model.train() 185 | total_loss = 0 186 | # if args.with_tracking: 187 | # total_loss = 0 188 | # if epoch == args.num_train_epochs1: 189 | # train_dataset = FrameSRLDataset(args.train_file2, tokenizer) 190 | # data_collator = DataCollatorForFrameSRL(tokenizer=tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 191 | 192 | # train_dataloader = DataLoader( 193 | # train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 194 | # ) 195 | 196 | # train_dataloader = accelerator.prepare(train_dataloader) 197 | for step, batch in enumerate(train_dataloader): 198 | # We need to skip steps until we reach the resumed step 
199 | if args.resume_from_checkpoint and epoch == starting_epoch: 200 | if resume_step is not None and step < resume_step: 201 | completed_steps += 1 202 | continue 203 | length = batch.pop('length') 204 | # if not args.loss_on_context: 205 | # context_length = batch.pop('context_length') 206 | word_ids = batch.pop('word_ids') 207 | # FE_num = batch.pop('FE_num') 208 | task_id = batch.pop('task_id') 209 | # gt_FE_word_idx = batch.pop('gt_FE_word_idx') 210 | # gt_start_positions = batch.pop('gt_start_positions') 211 | # gt_end_positions = batch.pop('gt_end_positions') 212 | # FE_core_pts = batch.pop('FE_core_pts') 213 | try: 214 | outputs = model(**batch) 215 | except: 216 | for k, v in batch.items(): 217 | print(k, v.shape, v) 218 | loss = outputs.loss 219 | # We keep track of the loss at each epoch 220 | total_loss += loss.detach().float() 221 | if args.log_every_step is not None and step % args.log_every_step == 0: 222 | logger.info(f" | batch loss: {loss.detach().float():.6f} step = {step}") 223 | # if args.with_tracking: 224 | # total_loss += loss.detach().float() 225 | # if args.log_every_step is not None and step % args.log_every_step == 0: 226 | # logger.info(f" | batch loss: {loss.detach().float():.6f} step = {step}") 227 | loss = loss / args.gradient_accumulation_steps 228 | accelerator.backward(loss) 229 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 230 | optimizer.step() 231 | lr_scheduler.step() 232 | optimizer.zero_grad() 233 | # progress_bar.update(1) 234 | completed_steps += 1 235 | 236 | 237 | if isinstance(checkpointing_steps, int): 238 | if completed_steps % checkpointing_steps == 0: 239 | output_dir = f"step_{completed_steps }" 240 | if args.output_dir is not None: 241 | output_dir = os.path.join(args.output_dir, output_dir) 242 | accelerator.save_state(output_dir) 243 | 244 | if completed_steps >= args.max_train_steps: 245 | break 246 | 247 | # if args.with_tracking: 248 | # logger.info(f" Epoch Loss {total_loss:.6f}") 249 | logger.info(f" Epoch Loss {total_loss:.6f}") 250 | 251 | precision, recall, F1 = Evaluate(args, accelerator, model, eval_dataset, eval_dataloader, metric) 252 | logger.info(f" Precision: {precision:.6f}, Recall: {recall:.6f}, F1: {F1:.6f}") 253 | # logger.info(f" TP: {total_TP} FP: {total_FP} FN: {total_FN}") 254 | 255 | 256 | if args.with_tracking: 257 | log = { 258 | "train_loss": total_loss, 259 | "step": completed_steps, 260 | "F1": F1, 261 | } 262 | accelerator.log(log) 263 | 264 | if args.checkpointing_steps == "epoch": 265 | output_dir = f"epoch_{epoch}" 266 | if args.output_dir is not None: 267 | output_dir = os.path.join(args.output_dir, output_dir) 268 | accelerator.save_state(output_dir) 269 | 270 | if args.output_dir is not None: 271 | # print(f'best {best_F1} current {F1}') 272 | if args.save_best: 273 | if best_F1 <= F1: 274 | best_F1 = F1 275 | unwrapped_model = accelerator.unwrap_model(model) 276 | unwrapped_model.save_pretrained( 277 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 278 | ) 279 | if accelerator.is_main_process: 280 | tokenizer.save_pretrained(args.output_dir) 281 | else: 282 | unwrapped_model = accelerator.unwrap_model(model) 283 | unwrapped_model.save_pretrained( 284 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 285 | ) 286 | if accelerator.is_main_process: 287 | tokenizer.save_pretrained(args.output_dir) 288 | 289 | 290 | 291 | def main(): 292 | args = parse_args() 293 | 294 | # 
Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 295 | # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment 296 | accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator() 297 | # Make one log on every process with the configuration for debugging. 298 | logging.basicConfig( 299 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 300 | datefmt="%m/%d/%Y %H:%M:%S", 301 | level=logging.INFO, 302 | ) 303 | logger.info(accelerator.state) 304 | 305 | # Setup logging, we only want one process per machine to log things on the screen. 306 | # accelerator.is_local_main_process is only True for one process per machine. 307 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) 308 | if accelerator.is_local_main_process: 309 | # datasets.utils.logging.set_verbosity_warning() 310 | transformers.utils.logging.set_verbosity_info() 311 | else: 312 | # datasets.utils.logging.set_verbosity_error() 313 | transformers.utils.logging.set_verbosity_error() 314 | 315 | # If passed along, set the training seed now. 316 | if args.seed is not None: 317 | set_seed(args.seed) 318 | 319 | if accelerator.is_main_process and args.output_dir is not None and not os.path.exists(args.output_dir): 320 | os.mkdir(args.output_dir) 321 | accelerator.wait_for_everyone() 322 | 323 | data_files = {} 324 | if args.train_file is not None: 325 | data_files["train"] = args.train_file 326 | if args.validation_file is not None: 327 | data_files["validation"] = args.validation_file 328 | if args.test_file is not None: 329 | data_files["test"] = args.test_file 330 | 331 | if args.config_name: 332 | config = BertConfig.from_pretrained(args.config_name) 333 | elif args.model_name_or_path: 334 | config = BertConfig.from_pretrained(args.model_name_or_path) 335 | config.update({'FE_pooling':args.FE_pooling}) 336 | 337 | if args.tokenizer_name: 338 | tokenizer = BertTokenizerFast.from_pretrained(args.tokenizer_name, use_fast=True) 339 | elif args.model_name_or_path: 340 | tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path, use_fast=True) 341 | # tokenizer.add_tokens(['', '', '', '', '', '']) 342 | tokenizer.add_tokens(['', '']) 343 | label_names = ['O', 'B', 'I'] 344 | id2label = {i: label for i, label in enumerate(label_names)} 345 | label2id = {v: k for k, v in id2label.items()} 346 | 347 | config.id2label = id2label 348 | config.label2id = label2id 349 | config.num_labels = len(label_names) 350 | 351 | if args.model_name_or_path: 352 | model = BertForTokenClassification.from_pretrained( 353 | args.model_name_or_path, 354 | from_tf=bool(".ckpt" in args.model_name_or_path), 355 | config=config, 356 | ) 357 | else: 358 | logger.info("Training new model from scratch") 359 | model = BertForTokenClassification.from_config(config) 360 | model.resize_token_embeddings(len(tokenizer)) 361 | 362 | # frame_data = {} 363 | # with open('./ccl-cfn/frame_data_def.json', 'r') as f: 364 | # frame_lines = json.load(f) 365 | # for line in frame_lines: 366 | # frame_data[line["frame_name"]] = line 367 | 368 | if "train" not in data_files: 369 | raise ValueError("--do_train requires a train dataset") 370 | with accelerator.main_process_first(): 371 | train_dataset = FrameAIDataset(data_files['train'], tokenizer) 372 | if args.max_train_samples is not None: 373 | train_dataset = 
train_dataset.subset(range(args.max_train_samples))
374 | 
375 |     if "validation" not in data_files:
376 |         raise ValueError("--do_train requires a validation dataset")
377 |     with accelerator.main_process_first():
378 |         eval_dataset = FrameAIDataset(data_files['validation'], tokenizer)
379 |         if args.max_eval_samples is not None:
380 |             eval_dataset = eval_dataset.subset(range(args.max_eval_samples))
381 | 
382 |     if args.do_predict:
383 |         test_dataset = FrameAIDataset(data_files['test'], tokenizer)
384 |         if args.max_predict_samples is not None:
385 |             test_dataset = test_dataset.subset(range(args.max_predict_samples))
386 | 
387 |     # data_collator = DataCollatorForFrameAI(tokenizer=tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
388 |     data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
389 | 
390 |     train_dataloader = DataLoader(
391 |         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
392 |     )
393 | 
394 |     eval_dataloader = DataLoader(
395 |         eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
396 |     )
397 | 
398 |     test_dataloader = DataLoader(
399 |         test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
400 |     )
401 |     # Optimizer
402 |     # Split weights in two groups, one with weight decay and the other not.
403 |     no_decay = ["bias", "LayerNorm.weight"]
404 |     optimizer_grouped_parameters = [
405 |         {
406 |             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
407 |             "weight_decay": args.weight_decay,
408 |         },
409 |         {
410 |             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
411 |             "weight_decay": 0.0,
412 |         },
413 |     ]
414 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
415 | 
416 |     # Scheduler and math around the number of training steps.
417 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
418 |     if args.max_train_steps is None:
419 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
420 |     else:
421 |         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
422 | 
423 |     lr_scheduler = get_scheduler(
424 |         name=args.lr_scheduler_type,
425 |         optimizer=optimizer,
426 |         num_warmup_steps=args.num_warmup_steps,
427 |         num_training_steps=args.max_train_steps,
428 |     )
429 | 
430 |     # Prepare everything with our `accelerator`.
431 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 432 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 433 | ) 434 | 435 | 436 | # We need to initialize the trackers we use, and also store our configuration 437 | if args.with_tracking: 438 | experiment_config = vars(args) 439 | # TensorBoard cannot log Enums, need the raw value 440 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 441 | accelerator.init_trackers("FrameSRL", experiment_config) 442 | # evaluate(args, accelerator, model, eval_dataset, eval_dataloader) 443 | train(args, accelerator, model, train_dataset, train_dataloader, optimizer, lr_scheduler, eval_dataset, eval_dataloader, tokenizer) 444 | 445 | if args.do_predict: 446 | model = BertForTokenClassification.from_pretrained( 447 | args.output_dir, 448 | from_tf=bool(".ckpt" in args.model_name_or_path), 449 | config=config, 450 | ) 451 | model, test_dataloader = accelerator.prepare(model, test_dataloader) 452 | Predict(args, accelerator, model, test_dataset, test_dataloader) 453 | 454 | if __name__ == '__main__': 455 | main() -------------------------------------------------------------------------------- /until.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | 10 | class Example: 11 | def __init__(self, idx, lex_name, sentence, lex_pos, gloss, label): 12 | self.task_id = idx, 13 | self.lex_name = lex_name, 14 | self.sentence = sentence, 15 | self.lex_pos = lex_pos, 16 | self.gloss = gloss, 17 | self.label = label 18 | 19 | 20 | class InputFeatures1(object): 21 | """A single set of features of data.""" 22 | 23 | def __init__(self, input_ids, mask_array, graph_node, input_mask, segment_ids, label_id): 24 | self.input_ids = input_ids 25 | self.input_mask = input_mask 26 | self.segment_ids = segment_ids 27 | self.mask_array = mask_array 28 | self.graph_node = graph_node 29 | self.label_id = label_id 30 | 31 | 32 | class InputFeatures(object): 33 | """A single set of features of data.""" 34 | 35 | def __init__(self, task_id, input_ids, mask_array, input_mask, segment_ids, label_id): 36 | self.task_id = task_id 37 | self.input_ids = input_ids 38 | self.input_mask = input_mask 39 | self.segment_ids = segment_ids 40 | self.mask_array = mask_array 41 | self.label_id = label_id 42 | 43 | 44 | class DataProcessor(object): 45 | """Base class for data converters for sequence classification data sets.""" 46 | 47 | def get_train_examples(self, data_dir): 48 | """Gets a collection of `InputExample`s for the train set.""" 49 | raise NotImplementedError() 50 | 51 | def get_dev_examples(self, data_dir): 52 | """Gets a collection of `InputExample`s for the dev set.""" 53 | raise NotImplementedError() 54 | 55 | def get_test_examples(self, data_dir): 56 | """Gets a collection of `InputExample`s for the test set.""" 57 | raise NotImplementedError() 58 | 59 | @classmethod 60 | def _read_txt(cls, input_file): 61 | """Reads a tab separated value file.""" 62 | with open(input_file, "r") as f: 63 | data = f.readlines() 64 | lines = [] 65 | for line in data: 66 | lines.append(line) 67 | return lines 68 | 69 | 70 | class Processor(DataProcessor): 71 | def get_train_examples(self, data_dir, f2idx): 72 | """See base class.""" 73 | train_data = [] 74 | data = 
json.load(open(data_dir, encoding="utf8")) 75 | train_data = [{ 76 | "task_id": item['sentence_id'], 77 | "frame": item['frame'], 78 | "lex": item['text'][ item["target"]["start"]: item["target"]["end"] + 1 ], 79 | "sentence": item['text'], 80 | "label": f2idx.get(item['frame'], 0), 81 | "gloss": "来了", 82 | } for item in data] 83 | 84 | return self._create_examples(train_data, "train") 85 | 86 | def get_dev_examples(self, data_dir, f2idx): 87 | """See base class.""" 88 | dev_data = [] 89 | data = json.load(open(data_dir, encoding="utf8")) 90 | dev_data = [{ 91 | "task_id": item['sentence_id'], 92 | "frame": item['frame'], 93 | "lex": item['text'][ item["target"]["start"]: item["target"]["end"] + 1 ], 94 | "sentence": item['text'], 95 | "label": f2idx.get(item['frame'], 0), 96 | "gloss": "来了", 97 | } for item in data] 98 | return self._create_examples(dev_data, "dev") 99 | 100 | def get_test_examples(self, data_dir, f2idx): 101 | """See base class.""" 102 | train_data = [] 103 | data = json.load(open(data_dir, encoding="utf8")) 104 | train_data = [{ 105 | "task_id": item['sentence_id'], 106 | "frame": item['frame'], 107 | "lex": item['text'][ item["target"]["start"]: item["target"]["end"] + 1 ], 108 | "sentence": item['text'], 109 | "label": f2idx.get(item['frame'], 0), 110 | "gloss": "来了", 111 | } for item in data] 112 | return self._create_examples(train_data, "test") 113 | 114 | def _create_examples(self, lines, set_type): 115 | """Creates examples for the training and dev sets.""" 116 | examples = [] 117 | # max_sen_length = 0 118 | for (i, line) in enumerate(lines): 119 | """ 120 | line = {"label":1 ,"lex":" ",sentence:" "} 121 | """ 122 | # if set_type == 'train' and i >=1000: break 123 | # if set_type == 'dev' and i>=500: break 124 | guid = "%s-%s" % (set_type, i) 125 | lex_name = line['lex'] 126 | # features = torch.tensor(line['features'], dtype=torch.double) 127 | sentence = line['sentence'] 128 | try: 129 | # gloss = line['gloss'] 130 | # 131 | # lex_pos = sentence.split().index(lex_name) 132 | lex_pos = sentence.index(lex_name) 133 | # lex_pos = 0 134 | except Exception as e1: 135 | print(e1) 136 | print(sentence) 137 | print(lex_name) 138 | continue 139 | label = line['label'] 140 | if not isinstance(label, int): 141 | label = label.replace('\n', '').replace('\t', '').replace(' ', '') 142 | 143 | if i % 1000 == 0: 144 | print(i) 145 | print("guid=", guid) 146 | print("lex_name=", lex_name) 147 | print("sentence=", sentence) 148 | print("lex_pos=", lex_pos) 149 | 150 | examples.append( 151 | Example(idx=line['task_id'], lex_name=lex_name, sentence=sentence, lex_pos=lex_pos, gloss=None, label=label)) 152 | # print("max_length", max_sen_length) 153 | # print(len(lines)) 154 | return examples 155 | 156 | 157 | def convert_examples_to_features(examples, max_seq_length, tokenizer): 158 | """Loads a data file into a list of `InputBatch`s.""" 159 | 160 | features = [] 161 | for (ex_index, example) in enumerate(tqdm(examples)): 162 | if ex_index % 10000 == 0: 163 | print("Writing example %d of %d" % (ex_index, len(examples))) 164 | 165 | orig_tokens = example.sentence 166 | lex_pos = example.lex_pos[0] 167 | # lex_len = len(example.lex_name[0]) 168 | 169 | # orig_tokens_list = list(orig_tokens) 170 | # orig_tokens_list.insert(lex_pos, '[CLS]') 171 | # orig_tokens_list.insert(lex_pos+lex_len+1, '[SEP]') 172 | # orig_tokens = "".join(orig_tokens_list) 173 | if lex_pos > 254: 174 | continue 175 | bert_tokens = [] 176 | # token = tokenizer.tokenize(orig_tokens[0]) 177 | 
bert_tokens.extend(tokenizer.tokenize(orig_tokens[0])) 178 | # lex_pos = 0 179 | # lex_len = 0 180 | # pos = 0 181 | lex = example.lex_name[0] 182 | lex_len = len(lex) 183 | # orig_tokens = orig_tokens[0].split(' ') 184 | # for idx, value in enumerate(orig_tokens): 185 | # token = tokenizer.tokenize(value) 186 | # 187 | # if value in lex: 188 | # if pos > lex_pos == 0: 189 | # lex_pos = pos 190 | # if pos == lex_pos+lex_len: 191 | # lex_len += len(token) 192 | # pos += len(token) 193 | # bert_tokens.extend(token) 194 | if len(bert_tokens) > max_seq_length - 2: 195 | bert_tokens = bert_tokens[:(max_seq_length - 2)] 196 | 197 | # bert_tokens.insert(lex_pos, '[STAR]') 198 | # bert_tokens.insert(lex_pos + lex_len + 1, '[ENDD]') 199 | # print("lex:{} bert_token_lex:{}".format(example.lex_name[0], bert_tokens[lex_pos+1:lex_pos+lex_len+1])) 200 | bert_tokens = ["[CLS]"] + bert_tokens + ["[SEP]"] 201 | segment_ids = [0] * len(bert_tokens) 202 | input_ids = tokenizer.convert_tokens_to_ids(bert_tokens) 203 | input_mask = [1] * len(input_ids) 204 | mask_array = np.zeros(256) 205 | mask_array[lex_pos+1:lex_pos+lex_len+1] = 1 206 | # mask_array = np.array([lex_pos, lex_len]) 207 | padding = [0] * (max_seq_length - len(input_ids)) 208 | input_ids += padding 209 | input_mask += padding 210 | segment_ids += padding 211 | try: 212 | assert len(input_ids) == max_seq_length 213 | assert len(input_mask) == max_seq_length 214 | assert len(segment_ids) == max_seq_length 215 | except: 216 | print(len(input_ids)) 217 | 218 | label_id = example.label 219 | 220 | if ex_index < 5: 221 | print("*** Example ***") 222 | print("task_id: %d" % (example.task_id)) 223 | print("lex:{}".format(example.lex_name[0])) 224 | print("tokens: %s" % " ".join([str(x) for x in bert_tokens])) 225 | print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 226 | print("input_mask: %s" % " ".join([str(x) for x in input_mask])) 227 | print("mask_array: %s" % " ".join([str(x) for x in mask_array])) 228 | print("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 229 | print("label: %s (id = %d)" % (example.label, label_id)) 230 | 231 | features.append( 232 | InputFeatures( 233 | task_id=example.task_id, 234 | input_ids=input_ids, 235 | input_mask=input_mask, 236 | segment_ids=segment_ids, 237 | label_id=label_id, 238 | mask_array=mask_array, 239 | )) 240 | return features 241 | 242 | 243 | class Processor_chinese(DataProcessor): 244 | def get_train_examples(self, data_dir): 245 | """See base class.""" 246 | train_data = [] 247 | with open(data_dir) as f: 248 | data = f.readlines() 249 | for line in data: 250 | train_data.append(eval(line)) 251 | return self._create_examples(train_data, "train") 252 | 253 | def get_dev_examples(self, data_dir): 254 | """See base class.""" 255 | dev_data = [] 256 | with open(data_dir) as f: 257 | data = f.readlines() 258 | for line in data: 259 | dev_data.append(eval(line)) 260 | return self._create_examples(dev_data, "dev") 261 | 262 | def get_test_examples(self, data_dir): 263 | """See base class.""" 264 | train_data = [] 265 | with open(data_dir) as f: 266 | data = f.readlines() 267 | for line in data: 268 | train_data.append(eval(line)) 269 | return self._create_examples(train_data, "test") 270 | 271 | def _create_examples(self, lines, set_type): 272 | """Creates examples for the training and dev sets.""" 273 | examples = [] 274 | # max_sen_length = 0 275 | for (i, line) in enumerate(lines): 276 | """ 277 | line = {"label":1 ,"lex":" ",sentence:" "} 278 | """ 279 | # if set_type == 'train' 
and i >=1000: break 280 | # if set_type == 'dev' and i>=500: break 281 | guid = "%s-%s" % (set_type, i) 282 | lex_name = line['lex'] 283 | # features = torch.tensor(line['features'], dtype=torch.double) 284 | sentence = line['sentence'] 285 | try: 286 | # gloss = line['gloss'] 287 | # 288 | # lex_pos = sentence.split().index(lex_name) 289 | lex_pos = sentence.index(lex_name) 290 | except Exception as e1: 291 | print(e1) 292 | print(sentence) 293 | print(lex_name) 294 | continue 295 | label = line['label'] 296 | if not isinstance(label, int): 297 | label = label.replace('\n', '').replace('\t', '').replace(' ', '') 298 | 299 | if i % 1000 == 0: 300 | print(i) 301 | print("guid=", guid) 302 | print("lex_name=", lex_name) 303 | print("sentence=", sentence) 304 | print("lex_pos=", lex_pos) 305 | 306 | examples.append( 307 | Example(idx=i, lex_name=lex_name, sentence=sentence, lex_pos=lex_pos, gloss=None, label=label)) 308 | # print("max_length", max_sen_length) 309 | # print(len(lines)) 310 | return examples 311 | 312 | 313 | q_f = open('./ccl-cfn/q.json', 'a', encoding='utf8') 314 | 315 | 316 | def convert_examples_to_features_graph(examples, max_seq_length, tokenizer, ltp): 317 | """Loads a data file into a list of `InputBatch`s.""" 318 | 319 | features = [] 320 | for (ex_index, example) in enumerate(tqdm(examples)): 321 | if ex_index % 10000 == 0: 322 | print("Writing example %d of %d" % (ex_index, len(examples))) 323 | 324 | orig_tokens = example.sentence 325 | seg, hidden = ltp.seg([orig_tokens[0]]) 326 | srl = ltp.srl(hidden) 327 | try: 328 | srl_lex_pos = seg[0].index(example.lex_name[0]) 329 | except: 330 | print(example.lex_name[0]) 331 | q_f.write("lex:{},sentence:{}".format(example.lex_name[0], example.sentence[0])) 332 | q_f.write("\n") 333 | continue 334 | graph_node = [] 335 | for detail in srl[0][srl_lex_pos]: 336 | # print(''.join(seg[0][detail[1]:detail[2] + 1])) 337 | # print(orig_tokens.index(''.join(seg[0][detail[1]:detail[2] + 1]))) 338 | graph_node.append(orig_tokens[0].index(''.join(seg[0][detail[1]:detail[2] + 1]))) 339 | graph_node.append(orig_tokens[0].index(''.join(seg[0][detail[1]:detail[2] + 1])) + len( 340 | ''.join(seg[0][detail[1]:detail[2] + 1]))) 341 | 342 | lex_pos = example.lex_pos[0] 343 | lex_len = len(example.lex_name[0]) 344 | # orig_tokens_list = list(orig_tokens) 345 | # orig_tokens_list.insert(lex_pos, '[CLS]') 346 | # orig_tokens_list.insert(lex_pos+lex_len+1, '[SEP]') 347 | # orig_tokens = "".join(orig_tokens_list) 348 | if lex_pos > 254: 349 | continue 350 | bert_tokens = [] 351 | # token = tokenizer.tokenize(orig_tokens[0]) 352 | bert_tokens.extend(tokenizer.tokenize(orig_tokens[0])) 353 | if len(bert_tokens) > max_seq_length - 4: 354 | bert_tokens = bert_tokens[:(max_seq_length - 4)] 355 | 356 | bert_tokens.insert(lex_pos, '[CLS]') 357 | bert_tokens.insert(lex_pos + lex_len + 1, '[SEP]') 358 | bert_tokens = ["[CLS]"] + bert_tokens + ["[SEP]"] 359 | segment_ids = [0] * len(bert_tokens) 360 | input_ids = tokenizer.convert_tokens_to_ids(bert_tokens) 361 | input_mask = [1] * len(input_ids) 362 | mask_array = np.zeros(256) 363 | mask_array[lex_pos + 1] = 1 364 | # mask_array = np.array([lex_pos, lex_len]) 365 | padding = [0] * (max_seq_length - len(input_ids)) 366 | padding1 = [0] * (max_seq_length - len(graph_node)) 367 | input_ids += padding 368 | input_mask += padding 369 | segment_ids += padding 370 | graph_node += padding1 371 | try: 372 | assert len(input_ids) == max_seq_length 373 | assert len(input_mask) == max_seq_length 374 | assert 
len(segment_ids) == max_seq_length 375 | except: 376 | print(len(input_ids)) 377 | 378 | label_id = example.label 379 | 380 | if ex_index < 5: 381 | print("*** Example ***") 382 | print("guid: %s" % (example.guid[0])) 383 | print("lex:{}".format(example.lex_name[0])) 384 | print("tokens: %s" % " ".join([str(x) for x in bert_tokens])) 385 | print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 386 | print("input_mask: %s" % " ".join([str(x) for x in input_mask])) 387 | print("mask_array: %s" % " ".join([str(x) for x in mask_array])) 388 | print("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 389 | print("label: %s (id = %d)" % (example.label, label_id)) 390 | 391 | features.append( 392 | InputFeatures1(input_ids=input_ids, 393 | input_mask=input_mask, 394 | segment_ids=segment_ids, 395 | label_id=label_id, 396 | mask_array=mask_array, 397 | graph_node=graph_node 398 | )) 399 | return features 400 | 401 | 402 | # 将多个字的tensor pool成一个(,768)的tensor 403 | def avg_pool(input_tensor): 404 | if input_tensor.size(0) >= 2: 405 | m = torch.nn.AdaptiveAvgPool2d((1, 768)) 406 | input_tensor = input_tensor.unsqueeze(0) 407 | # 变成一个(,768)的tensor 408 | output = m(input_tensor).squeeze(0).squeeze(0) 409 | return output 410 | else: 411 | return input_tensor 412 | 413 | 414 | def convert_examples_to_features_graph_gloss(examples, max_seq_length, tokenizer, ltp): 415 | """Loads a data file into a list of `InputBatch`s.""" 416 | 417 | features = [] 418 | for (ex_index, example) in enumerate(tqdm(examples)): 419 | if ex_index % 10000 == 0: 420 | print("Writing example %d of %d" % (ex_index, len(examples))) 421 | 422 | orig_tokens = example.sentence 423 | gloss = example.gloss 424 | seg, hidden = ltp.seg([orig_tokens[0]]) 425 | srl = ltp.srl(hidden) 426 | try: 427 | srl_lex_pos = seg[0].index(example.lex_name[0]) 428 | except: 429 | print(example.lex_name[0]) 430 | q_f.write("lex:{},sentence:{}".format(example.lex_name[0], example.sentence[0])) 431 | q_f.write('\n') 432 | continue 433 | graph_node = [] 434 | for detail in srl[0][srl_lex_pos]: 435 | # print(''.join(seg[0][detail[1]:detail[2] + 1])) 436 | # print(orig_tokens.index(''.join(seg[0][detail[1]:detail[2] + 1]))) 437 | graph_node.append(orig_tokens[0].index(''.join(seg[0][detail[1]:detail[2] + 1]))) 438 | graph_node.append(orig_tokens[0].index(''.join(seg[0][detail[1]:detail[2] + 1])) + len( 439 | ''.join(seg[0][detail[1]:detail[2] + 1]))) 440 | 441 | lex_pos = example.lex_pos[0] 442 | lex_len = len(example.lex_name[0]) 443 | # orig_tokens_list = list(orig_tokens) 444 | # orig_tokens_list.insert(lex_pos, '[CLS]') 445 | # orig_tokens_list.insert(lex_pos+lex_len+1, '[SEP]') 446 | # orig_tokens = "".join(orig_tokens_list) 447 | if lex_pos > 510: 448 | continue 449 | bert_tokens = [] 450 | gloss_tokens = [] 451 | gloss_tokens.extend(tokenizer.tokenize(gloss[0])) 452 | # token = tokenizer.tokenize(orig_tokens[0]) 453 | bert_tokens.extend(tokenizer.tokenize(orig_tokens[0])) 454 | 455 | 456 | bert_tokens.insert(lex_pos, '[CLS]') 457 | bert_tokens.insert(lex_pos + lex_len + 1, '[SEP]') 458 | bert_tokens = ["[CLS]"] + bert_tokens + ["[SEP]"] 459 | segment_ids = [0] * len(bert_tokens) + [1] * len(gloss_tokens) 460 | bert_tokens = bert_tokens + gloss_tokens 461 | if len(bert_tokens) > max_seq_length - 5: 462 | bert_tokens = bert_tokens[:(max_seq_length - 5)] 463 | segment_ids = segment_ids[:(max_seq_length - 5)] 464 | bert_tokens = bert_tokens + ["[SEP]"] 465 | segment_ids = segment_ids + [1] 466 | input_ids = 
tokenizer.convert_tokens_to_ids(bert_tokens) 467 | input_mask = [1] * len(input_ids) 468 | mask_array = np.zeros(512) 469 | mask_array[lex_pos + 1] = 1 470 | # mask_array = np.array([lex_pos, lex_len]) 471 | padding = [0] * (max_seq_length - len(input_ids)) 472 | padding1 = [0] * (max_seq_length - len(graph_node)) 473 | input_ids += padding 474 | input_mask += padding 475 | segment_ids += padding 476 | graph_node += padding1 477 | try: 478 | assert len(input_ids) == max_seq_length 479 | assert len(input_mask) == max_seq_length 480 | assert len(segment_ids) == max_seq_length 481 | except: 482 | print(len(input_ids)) 483 | 484 | label_id = example.label 485 | 486 | if ex_index < 5: 487 | print("*** Example ***") 488 | print("guid: %s" % (example.guid[0])) 489 | print("lex:{}".format(example.lex_name[0])) 490 | print("tokens: %s" % " ".join([str(x) for x in bert_tokens])) 491 | print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 492 | print("input_mask: %s" % " ".join([str(x) for x in input_mask])) 493 | print("mask_array: %s" % " ".join([str(x) for x in mask_array])) 494 | print("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 495 | print("label: %s (id = %d)" % (example.label, label_id)) 496 | 497 | features.append( 498 | InputFeatures1(input_ids=input_ids, 499 | input_mask=input_mask, 500 | segment_ids=segment_ids, 501 | label_id=label_id, 502 | mask_array=mask_array, 503 | graph_node=graph_node 504 | )) 505 | return features 506 | 507 | 508 | class FocalLoss(nn.Module): 509 | '''Multi-class Focal loss implementation''' 510 | def __init__(self, gamma=2, weight=None): 511 | super(FocalLoss, self).__init__() 512 | self.gamma = gamma 513 | self.weight = weight 514 | 515 | def forward(self, input, target): 516 | """ 517 | input: [N, C] 518 | target: [N, ] 519 | """ 520 | logpt = F.log_softmax(input, dim=1) 521 | pt = torch.exp(logpt) 522 | logpt = (1-pt)**self.gamma * logpt 523 | loss = F.nll_loss(logpt, target, self.weight) 524 | return loss 525 | 526 | 527 | def convert_examples_to_features_dep_graph(examples, max_seq_length, tokenizer, ltp): 528 | """Loads a data file into a list of `InputBatch`s.""" 529 | 530 | features = [] 531 | for (ex_index, example) in enumerate(tqdm(examples)): 532 | if ex_index % 10000 == 0: 533 | print("Writing example %d of %d" % (ex_index, len(examples))) 534 | 535 | orig_tokens = example.sentence 536 | seg, hidden = ltp.seg([orig_tokens[0]]) 537 | dep = ltp.dep(hidden) 538 | try: 539 | dep_lex_pos = seg[0].index(example.lex_name[0]) + 1 540 | except: 541 | print(example.lex_name[0]) 542 | q_f.write("lex:{},sentence:{}".format(example.lex_name[0], example.sentence[0])) 543 | q_f.write("\n") 544 | continue 545 | graph_node = [] 546 | for idx, value in enumerate(dep[0]): 547 | if (value[2] == 'SBV' or value[2] == 'VOB') and value[1] == dep_lex_pos: 548 | head = value[0] - 1 549 | graph_node.append(orig_tokens[0].index(seg[0][head])) 550 | graph_node.append(orig_tokens[0].index(seg[0][head])+len(seg[0][head])) 551 | 552 | lex_pos = example.lex_pos[0] 553 | lex_len = len(example.lex_name[0]) 554 | if lex_pos > 254: 555 | continue 556 | bert_tokens = [] 557 | # token = tokenizer.tokenize(orig_tokens[0]) 558 | bert_tokens.extend(tokenizer.tokenize(orig_tokens[0])) 559 | if len(bert_tokens) > max_seq_length - 4: 560 | bert_tokens = bert_tokens[:(max_seq_length - 4)] 561 | 562 | bert_tokens.insert(lex_pos, '[CLS]') 563 | bert_tokens.insert(lex_pos + lex_len + 1, '[SEP]') 564 | bert_tokens = ["[CLS]"] + bert_tokens + ["[SEP]"] 565 | segment_ids = 
[0] * len(bert_tokens)
566 |         input_ids = tokenizer.convert_tokens_to_ids(bert_tokens)
567 |         input_mask = [1] * len(input_ids)
568 |         mask_array = np.zeros(256)
569 |         mask_array[lex_pos + 1] = 1
570 |         # mask_array = np.array([lex_pos, lex_len])
571 |         padding = [0] * (max_seq_length - len(input_ids))
572 |         padding1 = [0] * (max_seq_length - len(graph_node))
573 |         input_ids += padding
574 |         input_mask += padding
575 |         segment_ids += padding
576 |         graph_node += padding1
577 |         try:
578 |             assert len(input_ids) == max_seq_length
579 |             assert len(input_mask) == max_seq_length
580 |             assert len(segment_ids) == max_seq_length
581 |         except:
582 |             print(len(input_ids))
583 | 
584 |         label_id = example.label
585 | 
586 |         if ex_index < 5:
587 |             print("*** Example ***")
588 |             print("task_id: %s" % (example.task_id[0]))  # Example defines task_id, not guid
589 |             print("lex:{}".format(example.lex_name[0]))
590 |             print("tokens: %s" % " ".join([str(x) for x in bert_tokens]))
591 |             print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
592 |             print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
593 |             print("mask_array: %s" % " ".join([str(x) for x in mask_array]))
594 |             print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
595 |             print("label: %s (id = %d)" % (example.label, label_id))
596 | 
597 |         features.append(
598 |             InputFeatures1(input_ids=input_ids,
599 |                            input_mask=input_mask,
600 |                            segment_ids=segment_ids,
601 |                            label_id=label_id,
602 |                            mask_array=mask_array,
603 |                            graph_node=graph_node
604 |                            ))
605 |     return features
606 | 
607 | 
608 | 
609 | 
--------------------------------------------------------------------------------