├── requirements.txt ├── data ├── label.txt ├── SemEval2010_task8 │ └── label.txt ├── test.tsv └── train.tsv ├── .idea ├── misc.xml ├── modules.xml ├── relation-cls.iml └── workspace.xml ├── README.md ├── official_eval.py ├── utils.py ├── model.py ├── main.py ├── data_loader.py ├── trainer.py └── eval ├── semeval2010_task8_scorer-v1.2.pl └── answer_keys.txt /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | transformers==2.11.0 3 | -------------------------------------------------------------------------------- /data/label.txt: -------------------------------------------------------------------------------- 1 | Other 2 | Whole-Component(e1,e2) 3 | Connection(e1,e2) 4 | 5 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Version 2 | python3 3 | pytorch>=1.4 4 | ## Installation 5 | pip install transformers 6 | ## How to train the model 7 | python main.py --do_train --do_eval 8 | #我们可以使用official_eval.py对SemEval-2010-task-8数据集进行测试 9 | ## to do 10 | add the prediction 11 | ## References 12 | * [Huggingface Transformers](https://github.com/huggingface/transformers) 13 | * [https://github.com/monologg/R-BERT](https://github.com/monologg/R-BERT) 14 | -------------------------------------------------------------------------------- /data/SemEval2010_task8/label.txt: -------------------------------------------------------------------------------- 1 | Other 2 | 
Cause-Effect(e1,e2) 3 | Cause-Effect(e2,e1) 4 | Instrument-Agency(e1,e2) 5 | Instrument-Agency(e2,e1) 6 | Product-Producer(e1,e2) 7 | Product-Producer(e2,e1) 8 | Content-Container(e1,e2) 9 | Content-Container(e2,e1) 10 | Entity-Origin(e1,e2) 11 | Entity-Origin(e2,e1) 12 | Entity-Destination(e1,e2) 13 | Entity-Destination(e2,e1) 14 | Component-Whole(e1,e2) 15 | Component-Whole(e2,e1) 16 | Member-Collection(e1,e2) 17 | Member-Collection(e2,e1) 18 | Message-Topic(e1,e2) 19 | Message-Topic(e2,e1) -------------------------------------------------------------------------------- /.idea/relation-cls.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /data/test.tsv: -------------------------------------------------------------------------------- 1 | Whole-Component(e1,e2) 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 。 2 | Whole-Component(e1,e2) 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 。 3 | Other 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 。 4 | -------------------------------------------------------------------------------- /data/train.tsv: -------------------------------------------------------------------------------- 1 | Whole-Component(e1,e2) 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 
。 2 | Whole-Component(e1,e2) 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 。 3 | Other 本 实 用 新 型 公 开 了 一 种 易 除 垢 的 引 风 机 , 包 括 风 箱 、 转 动 密 封 轴 承 、 转 轴 、 风 叶 、 电 机 、 固 定 筒 、 进 水 管 、 进 水 阀 门 、 底 座 、 引 水 腔 、 排 水 管 、 排 水 阀 门 、 第 一 过 滤 网 、 套 管 、 第 二 过 滤 网 、 除 尘 箱 、 进 风 管 、 清 理 门 、 除 尘 布 袋 和 灰 斗 。 4 | -------------------------------------------------------------------------------- /official_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | EVAL_DIR = 'eval' 4 | 5 | 6 | def official_f1(): 7 | # Run the perl script 8 | try: 9 | cmd = "perl {0}/semeval2010_task8_scorer-v1.2.pl {0}/proposed_answers.txt {0}/answer_keys.txt > {0}/result.txt".format(EVAL_DIR) 10 | os.system(cmd) 11 | except: 12 | raise Exception("perl is not installed or proposed_answers.txt is missing") 13 | 14 | with open(os.path.join(EVAL_DIR, 'result.txt'), 'r', encoding='utf-8') as f: 15 | macro_result = list(f)[-1] 16 | macro_result = macro_result.split(":")[1].replace(">>>", "").strip() 17 | macro_result = macro_result.split("=")[1].strip().replace("%", "") 18 | macro_result = float(macro_result) / 100 19 | 20 | return macro_result 21 | 22 | 23 | if __name__ == "__main__": 24 | print("macro-averaged F1 = {}%".format(official_f1() * 100)) 25 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | import torch 6 | import numpy as np 7 | from transformers import BertTokenizer, BertConfig 8 | 9 | # from official_eval import official_f1 10 | from model import RBERT 11 | from sklearn.metrics import precision_score, recall_score, f1_score 12 | 13 | MODEL_CLASSES = { 14 | 'bert':(BertConfig, RBERT, BertTokenizer) 15 | } 
16 | 17 | MODEL_PATH_MAP = { 18 | 'bert':'bert-base-chinese' 19 | } 20 | 21 | ADDITIONAL_SPECIAL_TOKENS = ["","","",""] 22 | 23 | def get_label(args): 24 | return [label.strip() for label in open(os.path.join(args.data_dir, args.label_file), 'r', encoding='utf-8')] 25 | 26 | def load_tokenizer(args): 27 | tokenizer = MODEL_CLASSES[args.model_type][2].from_pretrained(args.model_name_or_path) 28 | tokenizer.add_special_tokens({"additional_special_tokens":ADDITIONAL_SPECIAL_TOKENS}) 29 | return tokenizer 30 | 31 | def write_prediction(args, output_file, preds): 32 | """ 33 | For official evaluation script--来自于英文关系抽取标准数据集SemEval2010_task8_scorer 34 | 35 | :param output_file: prediction_file_path (e.g. eval/preposed_answer.txt) 36 | :param preds: [0,1,0,2,18,...] 37 | """ 38 | relation_labels = get_label(args) 39 | with open(output_file, 'w', encoding='utf-8') as f: 40 | for idx, pred in enumerate(preds): 41 | f.write("{}\t{}\n".format(8001 + idx,relation_labels[pred])) 42 | 43 | 44 | def init_logger(): 45 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 46 | datefmt='%m%d%Y %H:%M:%S', 47 | level=logging.INFO) 48 | 49 | def set_seed(args): 50 | random.seed(args.seed) 51 | np.random.seed(args.seed) 52 | torch.manual_seed(args.seed)#为 CPU 设置种子用于生成随机数,以使得结果是确定的。 53 | # torch.cuda.manual_seed(args.seed)# 为当前 GPU 设置种子用于生成随机数,以使得结果是确定的。 54 | if not args.no_cuda and torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(args.seed)#为所有的 GPU 设置种子用于生成随机数,以使得结果是确定的。 56 | 57 | 58 | def compute_metrics(preds, labels): 59 | assert len(preds) == len(labels) 60 | return acc_and_f1(preds, labels) 61 | 62 | def simple_accuracy(preds, labels): 63 | return (preds == labels).mean() 64 | 65 | def acc_and_f1(preds, labels): 66 | acc = simple_accuracy(preds, labels) 67 | sk_P_result_micro = precision_score(labels, preds, average='micro') 68 | # sk_R_result_micro = recall_score(labels, preds, average='micro') 69 | sk_f1_result = f1_score(labels, preds, 
average='macro') 70 | return { 71 | "acc": acc, 72 | "sk_pre": sk_P_result_micro, 73 | # "f1": official_f1(), 74 | # "sk_recall": sk_R_result_micro, 75 | "f1": sk_f1_result 76 | } 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel, BertPreTrainedModel 4 | 5 | PRETRAINED_MODEL_MAP = { 6 | 'bert':BertModel 7 | } 8 | 9 | class FCLayer(nn.Module): 10 | def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True): 11 | super(FCLayer, self).__init__() 12 | self.use_activation = use_activation 13 | self.dropout = nn.Dropout(dropout_rate) 14 | self.linear = nn.Linear(input_dim, output_dim) 15 | self.tanh = nn.Tanh() 16 | 17 | def forward(self, x): 18 | x = self.dropout(x) 19 | if self.use_activation: 20 | x = self.tanh(x) 21 | return self.linear(x) 22 | 23 | 24 | class RBERT(BertPreTrainedModel): 25 | def __init__(self, config, args): 26 | super(RBERT, self).__init__(config) 27 | self.bert = PRETRAINED_MODEL_MAP[args.model_type](config=config) # Load pretrained bert 28 | 29 | self.num_labels = config.num_labels 30 | 31 | self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate) 32 | self.e1_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate) 33 | self.e2_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate) 34 | self.label_classifier = FCLayer(config.hidden_size * 3, config.num_labels, args.dropout_rate, use_activation=False) 35 | 36 | @staticmethod 37 | def entity_averate(hidden_ooutput, e_mask): 38 | """ 39 | Average the entity hidden state vectors 40 | :param hidden_ooutput: [batch_size, max_seq_len, dim] 41 | :param e_mask: [batch_size, max_seq_len] 42 | e.g. 
e_mask[0] = [0,0,1,1,1,0,...,0] 43 | :return: [batch_size, dim] 44 | """ 45 | e_mask_unsqueeze = e_mask.unsqueeze(1) # [batch_size, 1, max_seq_len] 46 | length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1) # [batch_size, 1] 47 | 48 | sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_ooutput).squeeze(1) # [batch_size, 1, dim]-->[batch_size, dim] 49 | avg_vector = sum_vector.float() / length_tensor.float() 50 | 51 | return avg_vector 52 | 53 | def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask): 54 | outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) # sequence_output, pooled_output,(hidden_states), (attention) 55 | sequence_output = outputs[0] 56 | pooled_output = outputs[1] 57 | 58 | #Average 59 | e1_h = self.entity_averate(sequence_output, e1_mask) 60 | e2_h = self.entity_averate(sequence_output, e2_mask) 61 | 62 | # Dropout -> tanh -> fc_layer 63 | pooled_output = self.cls_fc_layer(pooled_output) 64 | e1_h = self.e1_fc_layer(e1_h) 65 | e2_h = self.e2_fc_layer(e2_h) 66 | 67 | #Concat -> fc_layer 68 | concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1) 69 | logits = self.label_classifier(concat_h) 70 | 71 | outputs = (logits,) + outputs[2:] 72 | 73 | # Softmax 74 | if labels is not None: 75 | if self.num_labels == 1: 76 | loss_fct = nn.MSELoss() 77 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 78 | else: 79 | loss_fct = nn.CrossEntropyLoss() 80 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 81 | 82 | outputs = (loss,) + outputs 83 | 84 | return outputs # loss, logits, hidden_states, attentions 85 | 86 | 87 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from trainer import Trainer 4 | from utils import init_logger, load_tokenizer, MODEL_CLASSES, MODEL_PATH_MAP 5 | from data_loader 
import load_and_cache_examples 6 | 7 | def main(args): 8 | init_logger() 9 | tokenizer = load_tokenizer(args) 10 | 11 | train_dataset = load_and_cache_examples(args, tokenizer, mode="train") 12 | test_dataset = load_and_cache_examples(args, tokenizer, mode="test") 13 | 14 | trainer = Trainer(args, train_dataset=train_dataset, test_dataset=test_dataset) 15 | 16 | if args.do_train: 17 | trainer.train() 18 | if args.do_eval: 19 | trainer.load_model() 20 | trainer.evaluate('test') 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument("--task", default="patent", type=str, help="The name of the task to train") 26 | parser.add_argument("--data_dir", default="./data", type=str, 27 | help="The input data dir. Should contain the .tsv files for the task") 28 | parser.add_argument("--model_dir", default="./model", type=str, help="Path to model") 29 | parser.add_argument("--eval_dir", default="./eval", type=str, help="Evaluation script, result directory") 30 | parser.add_argument("--train_file", default="train.tsv", type=str, help="Train file") 31 | parser.add_argument("--test_file", default="test.tsv", type=str, help="Test file") 32 | parser.add_argument("--label_file", default="label.txt", type=str, help="Label file") 33 | 34 | parser.add_argument("--model_type", default="bert", type=str, help="Model type selected in the list:" + ", ".join(MODEL_CLASSES.keys())) 35 | 36 | parser.add_argument("--seed", type=int, default=66, help="random seed for initializaion") 37 | parser.add_argument("--train_batch_size", default=16, type=int, help="Batch size for training") 38 | parser.add_argument("--eval_batch_size", default=16, type=int, help="Batch size for evaluation") 39 | parser.add_argument("--max_seq_len", default=500, type=int, help="The maximum total input sequence length after tokenization") 40 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam") 41 | 
parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epoch to perform") 42 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some") 43 | parser.add_argument("--gradient_accumulation_steps",default=1, type=int, 44 | help="Number of updates steps to accumulate before performing a backward/updata pass") 45 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer") 46 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm") 47 | parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform, Override num_train_epochs") 48 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps") 49 | parser.add_argument("--dropout_rate", default=0.3, type=float, help="Dropout for fully-connected layers") 50 | 51 | parser.add_argument("--logging_steps", default=50, type=int, help="Log every X updates steps") 52 | parser.add_argument("--save_steps", default=50, type=int, help="Save checkpoint every X updates steps") 53 | 54 | parser.add_argument("--do_train", action="store_true", help="Whether to run training") 55 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set") 56 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 57 | parser.add_argument("--add_sep_token", action="store_true", help="Add [SEP] token at the end of sentence") 58 | 59 | args = parser.parse_args() 60 | 61 | args.model_name_or_path = MODEL_PATH_MAP[args.model_type] 62 | main(args) 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import logging 6 | 7 | 
import torch 8 | from torch.utils.data import TensorDataset 9 | 10 | from utils import get_label 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class InputExample(object): 15 | def __init__(self, guid, text_a, label): 16 | self.guid = guid 17 | self.text_a = text_a 18 | self.label = label 19 | 20 | def __repr__(self): 21 | return str(self.to_json.string()) 22 | 23 | def to_dict(self): 24 | output = copy.deepcopy(self.__dict__) 25 | return output 26 | 27 | def to_json_string(self): 28 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 29 | 30 | 31 | class InputFeatures(object): 32 | def __init__(self, input_ids, attention_mask, token_type_ids, label_id, e1_mask, e2_mask): 33 | self.input_ids = input_ids 34 | self.attention_mask = attention_mask 35 | self.token_type_ids = token_type_ids 36 | self.label_id = label_id 37 | self.e1_mask = e1_mask 38 | self.e2_mask = e2_mask 39 | 40 | def __repr(self): 41 | return str(self.to_json_string()) 42 | 43 | def to_dict(self): 44 | output = copy.deepcopy(self.__dict__) 45 | 46 | def to_json_string(self): 47 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 48 | 49 | class REPorcessor(object): 50 | """Processor for the Relational Extraction""" 51 | def __init__(self, args): 52 | self.args = args 53 | self.relation_labels = get_label(args) 54 | 55 | @classmethod 56 | def _read_tsv(cls, input_file, quotechar=None): 57 | with open(input_file, 'r', encoding='utf-8') as f: 58 | reader = csv.reader(f, delimiter='\t', quotechar=quotechar) 59 | lines = [] 60 | for line in reader: 61 | lines.append(line) 62 | return lines 63 | 64 | def _create_examples(self, lines, set_type): 65 | examples = [] 66 | for (i, line) in enumerate(lines): 67 | guid = "%s-%s" % (set_type, i) 68 | text_a = line[1] 69 | label = self.relation_labels.index(line[0]) 70 | if i % 1000 == 0: 71 | logger.info(line) 72 | examples.append(InputExample(guid=guid, text_a=text_a, label=label)) 73 | return examples 74 | 75 | def 
get_examples(self, mode): 76 | """ 77 | :param mode: train、dev、test 78 | """ 79 | file_to_read = None 80 | if mode == 'train': 81 | file_to_read = self.args.train_file 82 | elif mode == 'dev': 83 | file_to_read = self.args.dev_file 84 | elif mode == 'test': 85 | file_to_read = self.args.test_file 86 | 87 | logger.info("LOOKING AT {}".format(os.path.join(self.args.data_dir, file_to_read))) 88 | return self._create_examples(self._read_tsv(os.path.join(self.args.data_dir, file_to_read)), mode) 89 | 90 | processors = { 91 | "patent": REPorcessor 92 | } 93 | 94 | def convert_examples_to_features(examples, max_seq_len, tokenizer, 95 | cls_token='[CLS]', 96 | cls_token_segment_id = 0, 97 | sep_token='[SEP]', 98 | pad_token=0, 99 | pad_token_segment_id=0, 100 | sequence_a_segment_id = 0, 101 | add_sep_token=False, 102 | mask_padding_with_zero=True): 103 | features = [] 104 | for (ex_index, example) in enumerate(examples): 105 | if ex_index % 5000 == 0: 106 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 107 | 108 | tokens_a = tokenizer.tokenize(example.text_a) 109 | 110 | e11_p = tokens_a.index("") # the start position of entity1 111 | e12_p = tokens_a.index("") # the end position of entity1 112 | e21_p = tokens_a.index("") 113 | e22_p = tokens_a.index("") 114 | 115 | # Replace the token 116 | tokens_a[e11_p] = "$" 117 | tokens_a[e12_p] = "$" 118 | tokens_a[e21_p] = "#" 119 | tokens_a[e22_p] = "#" 120 | 121 | # Add 1 because of the {CLS} token 122 | e11_p += 1 123 | e12_p += 1 124 | e21_p += 1 125 | e22_p += 1 126 | 127 | # Account for {CLS} and {SEP} with "-2" and with "-3" for RoBERTa. 
128 | if add_sep_token: 129 | special_tokens_count = 2 130 | else: 131 | special_tokens_count = 1 132 | if len(tokens_a) > max_seq_len - special_tokens_count: 133 | tokens_a = tokens_a[:(max_seq_len - special_tokens_count)] 134 | 135 | tokens = tokens_a 136 | if add_sep_token: 137 | tokens += [sep_token] 138 | 139 | token_type_ids = [sequence_a_segment_id] * len(tokens) 140 | 141 | tokens = [cls_token] + tokens 142 | token_type_ids = [cls_token_segment_id] + token_type_ids 143 | 144 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 145 | 146 | # The mask ha 1 for real tokens and 0 for padding toknes. Only real tokens are attended to. 147 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 148 | 149 | # Zero-pad up to the sequence length. 150 | padding_length = max_seq_len - len(input_ids) 151 | input_ids = input_ids + ([pad_token] * padding_length) 152 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 153 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 154 | 155 | # e1 mask, e1 mask 156 | e1_mask = [0] * len(attention_mask) 157 | e2_mask = [0] * len(attention_mask) 158 | 159 | for i in range(e11_p, e12_p + 1): 160 | e1_mask[i] = 1 161 | for i in range(e21_p, e22_p + 1): 162 | e2_mask[i] = 1 163 | 164 | assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len) 165 | assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len) 166 | assert len(token_type_ids) == max_seq_len, "Error with input token type length {} vs {}".format(len(token_type_ids), max_seq_len) 167 | 168 | label_id = int(example.label) 169 | 170 | if ex_index < 5: 171 | logger.info("*** Example ***") 172 | logger.info("guid: %s" % example.guid) 173 | logger.info("toknes: %s" % " ".join([str(x) for x in tokens])) 174 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 175 
| logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 176 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 177 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 178 | logger.info("e1_mask: %s" % " ".join([str(x) for x in e1_mask])) 179 | logger.info("e2_mask: %s" % " ".join([str(x) for x in e2_mask])) 180 | features.append( 181 | InputFeatures(input_ids=input_ids, 182 | attention_mask=attention_mask, 183 | token_type_ids=token_type_ids, 184 | label_id=label_id, 185 | e1_mask=e1_mask, 186 | e2_mask=e2_mask) 187 | ) 188 | 189 | return features 190 | 191 | def load_and_cache_examples(args, tokenizer, mode): 192 | precessor = processors[args.task](args) 193 | 194 | # Load data features from cache or dataset file 195 | cached_features_file = os.path.join( 196 | args.data_dir, 197 | 'cached_{}_{}_{}_{}'.format( 198 | mode, 199 | args.task, 200 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 201 | args.max_seq_len 202 | ) 203 | ) 204 | if os.path.exists(cached_features_file): 205 | logger.info("Loading features from cached file %s", cached_features_file) 206 | features = torch.load(cached_features_file) 207 | else: 208 | logger.info("Create features from dataset file at %s", args.data_dir) 209 | if mode == "train": 210 | examples = precessor.get_examples("train") 211 | elif mode == "dev": 212 | examples = precessor.get_examples("dev") 213 | elif mode == "test": 214 | examples = precessor.get_examples("test") 215 | else: 216 | raise Exception("For mode, only train, dev,test is available") 217 | features = convert_examples_to_features(examples, args.max_seq_len, tokenizer, add_sep_token=args.add_sep_token) 218 | logger.info("Saving features into cached file %s", cached_features_file) 219 | torch.save(features, cached_features_file) 220 | # Convert to Tensor and build dataset 221 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 222 | all_attention_mask = 
torch.tensor([f.attention_mask for f in features], dtype=torch.long) 223 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 224 | all_e1_mask = torch.tensor([f.e1_mask for f in features], dtype=torch.long) 225 | all_e2_mask = torch.tensor([f.e2_mask for f in features], dtype=torch.long) 226 | 227 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 228 | 229 | dataset = TensorDataset(all_input_ids, all_attention_mask, 230 | all_token_type_ids, all_label_ids, all_e1_mask, all_e2_mask) 231 | return dataset 232 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from tqdm import tqdm, trange 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 8 | from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup 9 | 10 | from model import RBERT 11 | from utils import set_seed, write_prediction, compute_metrics, get_label, MODEL_CLASSES 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class Trainer(object): 16 | def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None): 17 | self.args = args 18 | self.train_dataset = train_dataset 19 | self.dev_dataset = dev_dataset 20 | self.test_dataset = test_dataset 21 | 22 | self.label_lst = get_label(args) 23 | self.num_labels = len(self.label_lst) 24 | 25 | self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type] # BertConfig, RBERT, BertTokenizer 26 | self.config = self.config_class.from_pretrained(args.model_name_or_path, num_labels=self.num_labels, finetuning_task=args.task) 27 | self.model = self.model_class.from_pretrained(args.model_name_or_path, 28 | config=self.config, 29 | args=args) 30 | 31 | self.best_epoch = 0 32 | self.best_f1 = 0 33 | 34 | # GPU or CPU 35 | 
self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 36 | self.model.to(self.device) 37 | 38 | def train(self): 39 | train_sampler = RandomSampler(self.train_dataset) 40 | train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size) 41 | 42 | if self.args.max_steps > 0: 43 | t_total = self.args.max_steps 44 | self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 45 | else: 46 | t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs 47 | 48 | # Prepare optimizer and schedule (linear warmup and decay) 49 | no_decay = ['bias', 'LayerNorm.weight'] 50 | optimizer_grouped_parameters = [ 51 | {'params': [p for n,p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 52 | 'weight_decay': self.args.weight_decay}, 53 | {'params':[p for n,p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 54 | 'weight_decay':0.0} 55 | ] 56 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) 57 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total) 58 | 59 | # Train! 
60 | logger.info("***** Running training *****") 61 | logger.info(" Num examples = %d", len(self.train_dataset)) 62 | logger.info(" Num Epochs = %d", self.args.num_train_epochs) 63 | logger.info(" Total train batch size = %d", self.args.train_batch_size) 64 | logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) 65 | logger.info(" Total optimization steps =%d", t_total) 66 | logger.info(" Logging steps = %d", self.args.logging_steps) 67 | logger.info(" Save steps = %d", self.args.save_steps) 68 | 69 | global_step = 0 70 | tr_loss = 0.0 71 | self.model.zero_grad() 72 | 73 | train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch") 74 | set_seed(self.args) 75 | 76 | for _ in train_iterator: 77 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 78 | for step, batch in enumerate(epoch_iterator): 79 | self.model.train() 80 | batch = tuple(t.to(self.device) for t in batch) 81 | inputs = {'input_ids': batch[0], 82 | 'attention_mask': batch[1], 83 | 'token_type_ids': batch[2], 84 | 'labels': batch[3], 85 | 'e1_mask': batch[4], 86 | 'e2_mask': batch[5]} 87 | outputs = self.model(**inputs) 88 | loss = outputs[0] 89 | 90 | if self.args.gradient_accumulation_steps > 1: 91 | loss = loss / self.args.gradient_accumulation_steps 92 | 93 | loss.backward() 94 | 95 | tr_loss += loss.item() 96 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 97 | # 梯度裁剪 98 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 99 | 100 | # 反向传播,更新参数 101 | optimizer.step() 102 | # 更新学习率 103 | scheduler.step() 104 | # 清空学习率 105 | self.model.zero_grad() 106 | global_step += 1 107 | 108 | if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0: 109 | dev_result = self.evaluate('test') 110 | if dev_result['f1'] >= self.best_f1: 111 | self.best_f1 =dev_result['f1'] 112 | self.best_epoch = _ 113 | self.save_model() 114 | print('save model finished') 115 | print('best f1_score is ', 
self.best_f1) 116 | print('best epoch is ', self.best_epoch) 117 | 118 | # if self.args.save_steps > 0 and global_step % self.args.save_steps == 0: 119 | # self.save_model() 120 | if 0 < self.args.max_steps < global_step: 121 | epoch_iterator.close() 122 | break 123 | if 0 < self.args.max_steps < global_step: 124 | train_iterator.close() 125 | break 126 | return global_step, tr_loss / global_step 127 | 128 | 129 | def evaluate(self, mode): 130 | if mode == 'test': 131 | dataset = self.test_dataset 132 | elif mode == 'dev': 133 | dataset = self.dev_dataset 134 | else: 135 | raise Exception("Only dev and test dataset available") 136 | 137 | eval_sampler = SequentialSampler(dataset) 138 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size) 139 | 140 | # Eval! 141 | logger.info("***** Running evaluation on %s dataset *****", mode) 142 | logger.info(" Num examples = %d", len(dataset)) 143 | logger.info(" Batch size = %d", self.args.eval_batch_size) 144 | eval_loss = 0.0 145 | nb_eval_steps = 0 146 | preds = None 147 | out_label_ids = None 148 | 149 | self.model.eval() 150 | 151 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 152 | batch = tuple(t.to(self.device) for t in batch) 153 | with torch.no_grad(): 154 | inputs = {'input_ids': batch[0], 155 | 'attention_mask': batch[1], 156 | 'token_type_ids': batch[2], 157 | 'labels': batch[3], 158 | 'e1_mask': batch[4], 159 | 'e2_mask': batch[5]} 160 | outputs = self.model(**inputs) 161 | tmp_eval_loss, logits = outputs[:2] 162 | 163 | 164 | eval_loss += tmp_eval_loss.mean().item() 165 | nb_eval_steps += 1 166 | 167 | if preds is None: 168 | preds = logits.detach().cpu().numpy() 169 | out_label_ids = inputs['labels'].detach().cpu().numpy() 170 | else: 171 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 172 | out_label_ids = np.append( 173 | out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 174 | # print('out_label_ids', out_label_ids) 175 | 
eval_loss = eval_loss / nb_eval_steps 176 | results = { 177 | "loss": eval_loss 178 | } 179 | preds = np.argmax(preds, axis=1) 180 | write_prediction(self.args, os.path.join(self.args.eval_dir, "proposed_answers.txt"), preds) 181 | 182 | print('preds---',preds) 183 | print('out_label_ids---', out_label_ids) 184 | 185 | result = compute_metrics(preds, out_label_ids) 186 | results.update(result) 187 | 188 | if mode == 'test': 189 | print('best f1_score is ', self.best_f1) 190 | print('best epoch is ', self.best_epoch) 191 | 192 | logger.info("***** Eval results *****") 193 | for key in sorted(results.keys()): 194 | logger.info(" {} = {:.4f}".format(key, results[key])) 195 | 196 | return results 197 | 198 | def save_model(self): 199 | # Save model checkpoint (Overwrite) 200 | if not os.path.exists(self.args.model_dir): 201 | os.makedirs(self.args.model_dir) 202 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model 203 | model_to_save.save_pretrained(self.args.model_dir) 204 | 205 | # Save training arguments together with the trained model 206 | torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin')) 207 | logger.info("Saving model checkpoint to %s", self.args.model_dir) 208 | 209 | def load_model(self): 210 | # Check whether model exists 211 | if not os.path.exists(self.args.model_dir): 212 | raise Exception("Model doesn't exists! 
Train first!") 213 | 214 | try: 215 | self.args = torch.load(os.path.join(self.args.model_dir, 'training_args.bin')) 216 | self.config = self.config_class.from_pretrained(self.args.model_dir) 217 | self.model = self.model_class.from_pretrained(self.args.model_dir, config=self.config, args=self.args) 218 | 219 | self.model.to(self.device) 220 | logger.info("****** Model Loaded ******") 221 | except: 222 | raise Exception("Some model files might be missing...") 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /eval/semeval2010_task8_scorer-v1.2.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # 3 | # 4 | # Author: Preslav Nakov 5 | # nakov@comp.nus.edu.sg 6 | # National University of Singapore 7 | # 8 | # WHAT: This is the official scorer for SemEval-2010 Task #8. 9 | # 10 | # 11 | # Last modified: March 22, 2010 12 | # 13 | # Current version: 1.2 14 | # 15 | # Revision history: 16 | # - Version 1.2 (fixed a bug in the precision for the scoring of (iii)) 17 | # - Version 1.1 (fixed a bug in the calculation of accuracy) 18 | # 19 | # 20 | # Use: 21 | # semeval2010_task8_scorer-v1.1.pl 22 | # 23 | # Example2: 24 | # semeval2010_task8_scorer-v1.1.pl proposed_answer1.txt answer_key1.txt > result_scores1.txt 25 | # semeval2010_task8_scorer-v1.1.pl proposed_answer2.txt answer_key2.txt > result_scores2.txt 26 | # semeval2010_task8_scorer-v1.1.pl proposed_answer3.txt answer_key3.txt > result_scores3.txt 27 | # 28 | # Description: 29 | # The scorer takes as input a proposed classification file and an answer key file. 30 | # Both files should contain one prediction per line in the format " " 31 | # with a TAB as a separator, e.g., 32 | # 1 Component-Whole(e2,e1) 33 | # 2 Other 34 | # 3 Instrument-Agency(e2,e1) 35 | # ... 
36 | # The files do not have to be sorted in any way and the first file can have predictions 37 | # for a subset of the IDs in the second file only, e.g., because hard examples have been skipped. 38 | # Repetitions of IDs are not allowed in either of the files. 39 | # 40 | # The scorer calculates and outputs the following statistics: 41 | # (1) confusion matrix, which shows 42 | # - the sums for each row/column: -SUM- 43 | # - the number of skipped examples: skip 44 | # - the number of examples with correct relation, but wrong directionality: xDIRx 45 | # - the number of examples in the answer key file: ACTUAL ( = -SUM- + skip + xDIRx ) 46 | # (2) accuracy and coverage 47 | # (3) precision (P), recall (R), and F1-score for each relation 48 | # (4) micro-averaged P, R, F1, where the calculations ignore the Other category. 49 | # (5) macro-averaged P, R, F1, where the calculations ignore the Other category. 50 | # 51 | # Note that in scores (4) and (5), skipped examples are equivalent to those classified as Other. 52 | # So are examples classified as relations that do not exist in the key file (which is probably not optimal). 53 | # 54 | # The scoring is done three times: 55 | # (i) as a (2*9+1)-way classification 56 | # (ii) as a (9+1)-way classification, with directionality ignored 57 | # (iii) as a (9+1)-way classification, with directionality taken into account. 58 | # 59 | # The official score is the macro-averaged F1-score for (iii). 
#

use strict;


###############
### I/O ###
###############

if ($#ARGV != 1) {
    die "Usage:\nsemeval2010_task8_scorer.pl <PROPOSED_ANSWERS> <ANSWER_KEY>\n";
}

my $PROPOSED_ANSWERS_FILE_NAME = $ARGV[0];
my $ANSWER_KEYS_FILE_NAME = $ARGV[1];


################
### MAIN ###
################

my (%confMatrix19way, %confMatrix10wayNoDir, %confMatrix10wayWithDir) = ();
my (%idsProposed, %idsAnswer) = ();
my (%allLabels19waylAnswer, %allLabels10wayAnswer) = ();
my (%allLabels19wayProposed, %allLabels10wayNoDirProposed, %allLabels10wayWithDirProposed) = ();

### 1. Read the file contents
my $totalProposed = &readFileIntoHash($PROPOSED_ANSWERS_FILE_NAME, \%idsProposed);
my $totalAnswer = &readFileIntoHash($ANSWER_KEYS_FILE_NAME, \%idsAnswer);

### 2. Calculate the confusion matrices
foreach my $id (keys %idsProposed) {

    ### 2.1. Unexpected IDs are not allowed
    die "File $PROPOSED_ANSWERS_FILE_NAME contains a bad ID: '$id'"
        if (!defined($idsAnswer{$id}));

    ### 2.2. Update the 19-way confusion matrix
    my $labelProposed = $idsProposed{$id};
    my $labelAnswer = $idsAnswer{$id};
    $confMatrix19way{$labelProposed}{$labelAnswer}++;
    $allLabels19wayProposed{$labelProposed}++;

    ### 2.3. Update the 10-way confusion matrix *without* direction
    my $labelProposedNoDir = $labelProposed;
    my $labelAnswerNoDir = $labelAnswer;
    $labelProposedNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//;
    $labelAnswerNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//;
    $confMatrix10wayNoDir{$labelProposedNoDir}{$labelAnswerNoDir}++;
    $allLabels10wayNoDirProposed{$labelProposedNoDir}++;

    ### 2.4. Update the 10-way confusion matrix *with* direction
    if ($labelProposed eq $labelAnswer) { ## both relation and direction match
        $confMatrix10wayWithDir{$labelProposedNoDir}{$labelAnswerNoDir}++;
        $allLabels10wayWithDirProposed{$labelProposedNoDir}++;
    }
    elsif ($labelProposedNoDir eq $labelAnswerNoDir) { ## the relations match, but the direction is wrong
        $confMatrix10wayWithDir{'WRONG_DIR'}{$labelAnswerNoDir}++;
        $allLabels10wayWithDirProposed{'WRONG_DIR'}++;
    }
    else { ### Wrong relation
        $confMatrix10wayWithDir{$labelProposedNoDir}{$labelAnswerNoDir}++;
        $allLabels10wayWithDirProposed{$labelProposedNoDir}++;
    }
}

### 3. Calculate the ground truth distributions
foreach my $id (keys %idsAnswer) {

    ### 3.1. Update the 19-way answer distribution
    my $labelAnswer = $idsAnswer{$id};
    $allLabels19waylAnswer{$labelAnswer}++;

    ### 3.2. Update the 10-way answer distribution
    my $labelAnswerNoDir = $labelAnswer;
    $labelAnswerNoDir =~ s/\(e[12],e[12]\)[\n\r]*$//;
    $allLabels10wayAnswer{$labelAnswerNoDir}++;
}

### 4. Check for proposed classes that are not contained in the answer key file: this may happen in cross-validation
foreach my $labelProposed (sort keys %allLabels19wayProposed) {
    if (!defined($allLabels19waylAnswer{$labelProposed})) {
        print "!!!WARNING!!! The proposed file contains $allLabels19wayProposed{$labelProposed} label(s) of type '$labelProposed', which is NOT present in the key file.\n\n";
    }
}

### 5. 19-way evaluation with directionality
print "<<< (2*9+1)-WAY EVALUATION (USING DIRECTIONALITY)>>>:\n\n";
&evaluate(\%confMatrix19way, \%allLabels19wayProposed, \%allLabels19waylAnswer, $totalProposed, $totalAnswer, 0);

### 6. Evaluate without directionality
print "<<< (9+1)-WAY EVALUATION IGNORING DIRECTIONALITY >>>:\n\n";
&evaluate(\%confMatrix10wayNoDir, \%allLabels10wayNoDirProposed, \%allLabels10wayAnswer, $totalProposed, $totalAnswer, 0);

### 7. Evaluate with directionality taken into account (the official setting)
print "<<< (9+1)-WAY EVALUATION TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL >>>:\n\n";
my $officialScore = &evaluate(\%confMatrix10wayWithDir, \%allLabels10wayWithDirProposed, \%allLabels10wayAnswer, $totalProposed, $totalAnswer, 1);

### 8. Output the official score
printf "<<< The official score is (9+1)-way evaluation with directionality taken into account: macro-averaged F1 = %0.2f%s >>>\n", $officialScore, '%';


################
### SUBS ###
################

# Parse one "<id>TAB<label>" line.
# Returns ($id, $label) for a known label ('Other' is mapped to '_Other'),
# or (-1) when the line is malformed or the label is not recognised.
# (No empty prototype: the sub takes an argument, and every call uses the
# &name(...) form, which ignores prototypes anyway.)
sub getIDandLabel {
    my $line = shift;
    return (-1,()) if ($line !~ /^([0-9]+)\t([^\r]+)\r?\n$/);

    my ($id, $label) = ($1, $2);

    return ($id, '_Other') if ($label eq 'Other');

    return ($id, $label)
        if (($label eq 'Cause-Effect(e1,e2)') || ($label eq 'Cause-Effect(e2,e1)') ||
            ($label eq 'Component-Whole(e1,e2)') || ($label eq 'Component-Whole(e2,e1)') ||
            ($label eq 'Content-Container(e1,e2)') || ($label eq 'Content-Container(e2,e1)') ||
            ($label eq 'Entity-Destination(e1,e2)') || ($label eq 'Entity-Destination(e2,e1)') ||
            ($label eq 'Entity-Origin(e1,e2)') || ($label eq 'Entity-Origin(e2,e1)') ||
            ($label eq 'Instrument-Agency(e1,e2)') || ($label eq 'Instrument-Agency(e2,e1)') ||
            ($label eq 'Member-Collection(e1,e2)') || ($label eq 'Member-Collection(e2,e1)') ||
            ($label eq 'Message-Topic(e1,e2)') || ($label eq 'Message-Topic(e2,e1)') ||
            ($label eq 'Product-Producer(e1,e2)') || ($label eq 'Product-Producer(e2,e1)'));

    return (-1, ());
}


# Read file $fname into the hash referenced by $ids (id => label).
# Dies on a malformed line or a repeated ID. Returns the number of lines read.
sub readFileIntoHash {
    my ($fname, $ids) = @_;
    # 3-arg open with a lexical handle: the file name is never interpreted as
    # a mode/pipe, and the handle cannot clash with another global INPUT.
    open(my $input, '<', $fname) or die "Failed to open $fname for text reading.\n";
    my $lineNo = 0;
    # Read line by line into $_; an empty `while ()` would loop forever
    # without ever reading from the file.
    while (<$input>) {
        $lineNo++;
        my ($id, $label) = &getIDandLabel($_);
        die "Bad file format on line $lineNo: '$_'\n" if ($id < 0);
        if (defined $$ids{$id}) {
            s/[\n\r]*$//;
            die "Bad file format on line $lineNo (ID $id is already defined): '$_'\n";
        }
        $$ids{$id} = $label;
    }
    close($input) or die "Failed to close $fname.\n";
    return $lineNo;
}


# Print the confusion matrix, coverage, accuracy, per-relation P/R/F1 and the
# micro/macro averages (both excluding '_Other') for one scoring scheme.
#   $confMatrix        - hash ref: {proposed}{answer} => count
#   $allLabelsProposed - hash ref: proposed label => count
#   $allLabelsAnswer   - hash ref: answer-key label => count
#   $totalProposed     - number of lines in the proposed file
#   $totalAnswer       - number of lines in the answer key file
#   $useWrongDir       - true when the matrix has a 'WRONG_DIR' row
# Returns the macro-averaged F1.
sub evaluate {
    my ($confMatrix, $allLabelsProposed, $allLabelsAnswer, $totalProposed, $totalAnswer, $useWrongDir) = @_;

    ### 0. Create a merged list for the confusion matrix
    my @allLabels = ();
    &mergeLabelLists($allLabelsAnswer, $allLabelsProposed, \@allLabels);

    ### 1. Print the confusion matrix heading
    print "Confusion matrix:\n";
    print " ";
    foreach my $label (@allLabels) {
        printf " %4s", &getShortRelName($label, $allLabelsAnswer);
    }
    print " <-- classified as\n";
    print " +";
    foreach my $label (@allLabels) {
        print "-----";
    }
    if ($useWrongDir) {
        print "+ -SUM- xDIRx skip ACTUAL\n";
    }
    else {
        print "+ -SUM- skip ACTUAL\n";
    }

    ### 2. Print the rest of the confusion matrix
    my $freqCorrect = 0;
    my $otherSkipped = 0;
    foreach my $labelAnswer (sort keys %{$allLabelsAnswer}) {

        ### 2.1. Output the short relation label
        printf " %4s |", &getShortRelName($labelAnswer, $allLabelsAnswer);

        ### 2.2. Output a row of the confusion matrix
        my $sumProposed = 0;
        foreach my $labelProposed (@allLabels) {
            $$confMatrix{$labelProposed}{$labelAnswer} = 0
                if (!defined($$confMatrix{$labelProposed}{$labelAnswer}));
            printf "%4d ", $$confMatrix{$labelProposed}{$labelAnswer};
            $sumProposed += $$confMatrix{$labelProposed}{$labelAnswer};
        }

        ### 2.3. Output the horizontal sums
        if ($useWrongDir) {
            my $ans = defined($$allLabelsAnswer{$labelAnswer}) ? $$allLabelsAnswer{$labelAnswer} : 0;
            $$confMatrix{'WRONG_DIR'}{$labelAnswer} = 0 if (!defined $$confMatrix{'WRONG_DIR'}{$labelAnswer});
            printf "| %4d %4d %4d %6d\n", $sumProposed, $$confMatrix{'WRONG_DIR'}{$labelAnswer}, $ans - $sumProposed - $$confMatrix{'WRONG_DIR'}{$labelAnswer}, $ans;
            if ($labelAnswer eq '_Other') {
                $otherSkipped = $ans - $sumProposed - $$confMatrix{'WRONG_DIR'}{$labelAnswer};
            }
        }
        else {
            my $ans = defined($$allLabelsAnswer{$labelAnswer}) ? $$allLabelsAnswer{$labelAnswer} : 0;
            printf "| %4d %4d %4d\n", $sumProposed, $ans - $sumProposed, $ans;
            if ($labelAnswer eq '_Other') {
                $otherSkipped = $ans - $sumProposed;
            }
        }

        $$confMatrix{$labelAnswer}{$labelAnswer} = 0
            if (!defined($$confMatrix{$labelAnswer}{$labelAnswer}));
        $freqCorrect += $$confMatrix{$labelAnswer}{$labelAnswer};
    }
    print " +";
    foreach (@allLabels) {
        print "-----";
    }
    print "+\n";

    ### 3. Print the vertical sums
    print " -SUM- ";
    foreach my $labelProposed (@allLabels) {
        $$allLabelsProposed{$labelProposed} = 0
            if (!defined $$allLabelsProposed{$labelProposed});
        printf "%4d ", $$allLabelsProposed{$labelProposed};
    }
    if ($useWrongDir) {
        # No wrong-direction prediction at all leaves this entry undefined.
        my $wrongDirTotal = defined($$allLabelsProposed{'WRONG_DIR'}) ? $$allLabelsProposed{'WRONG_DIR'} : 0;
        printf " %4d %4d %4d %6d\n\n", $totalProposed - $wrongDirTotal, $wrongDirTotal, $totalAnswer - $totalProposed, $totalAnswer;
    }
    else {
        printf " %4d %4d %4d\n\n", $totalProposed, $totalAnswer - $totalProposed, $totalAnswer;
    }

    ### 4. Output the coverage (0 when the key file is empty, instead of dividing by zero)
    my $coverage = (0 == $totalAnswer) ? 0 : 100.0 * $totalProposed / $totalAnswer;
    printf "%s%d%s%d%s%5.2f%s", 'Coverage = ', $totalProposed, '/', $totalAnswer, ' = ', $coverage, "\%\n";

    ### 5. Output the accuracy (0 when nothing was proposed)
    my $accuracy = (0 == $totalProposed) ? 0 : 100.0 * $freqCorrect / $totalProposed;
    printf "%s%d%s%d%s%5.2f%s", 'Accuracy (calculated for the above confusion matrix) = ', $freqCorrect, '/', $totalProposed, ' = ', $accuracy, "\%\n";

    ### 6. Output the accuracy considering all skipped to be wrong
    $accuracy = (0 == $totalAnswer) ? 0 : 100.0 * $freqCorrect / $totalAnswer;
    printf "%s%d%s%d%s%5.2f%s", 'Accuracy (considering all skipped examples as Wrong) = ', $freqCorrect, '/', $totalAnswer, ' = ', $accuracy, "\%\n";

    ### 7. Calculate accuracy with all skipped examples considered Other
    my $accuracyWithOther = (0 == $totalAnswer) ? 0 : 100.0 * ($freqCorrect + $otherSkipped) / $totalAnswer;
    printf "%s%d%s%d%s%5.2f%s", 'Accuracy (considering all skipped examples as Other) = ', ($freqCorrect + $otherSkipped), '/', $totalAnswer, ' = ', $accuracyWithOther, "\%\n";

    ### 8. Output P, R, F1 for each relation
    my ($macroP, $macroR, $macroF1) = (0, 0, 0);
    my ($microCorrect, $microProposed, $microAnswer) = (0, 0, 0);
    print "\nResults for the individual relations:\n";
    foreach my $labelAnswer (sort keys %{$allLabelsAnswer}) {

        ### 8.1. Consider all wrong directionalities as wrong classification decisions
        my $wrongDirectionCnt = 0;
        if ($useWrongDir && defined $$confMatrix{'WRONG_DIR'}{$labelAnswer}) {
            $wrongDirectionCnt = $$confMatrix{'WRONG_DIR'}{$labelAnswer};
        }

        ### 8.2. Prevent Perl from complaining about uninitialized values
        if (!defined($$allLabelsProposed{$labelAnswer})) {
            $$allLabelsProposed{$labelAnswer} = 0;
        }

        ### 8.3. Calculate P/R/F1
        my $P = (0 == $$allLabelsProposed{$labelAnswer}) ? 0
                : 100.0 * $$confMatrix{$labelAnswer}{$labelAnswer} / ($$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt);
        my $R = (0 == $$allLabelsAnswer{$labelAnswer}) ? 0
                : 100.0 * $$confMatrix{$labelAnswer}{$labelAnswer} / $$allLabelsAnswer{$labelAnswer};
        my $F1 = (0 == $P + $R) ? 0 : 2 * $P * $R / ($P + $R);

        ### 8.4. Output P/R/F1
        if ($useWrongDir) {
            printf "%25s%s%4d%s(%4d +%4d)%s%6.2f", $labelAnswer,
                " : P = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', $$allLabelsProposed{$labelAnswer}, $wrongDirectionCnt, ' = ', $P;
        }
        else {
            printf "%25s%s%4d%s%4d%s%6.2f", $labelAnswer,
                " : P = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', ($$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt), ' = ', $P;
        }
        printf"%s%4d%s%4d%s%6.2f%s%6.2f%s\n",
            "% R = ", $$confMatrix{$labelAnswer}{$labelAnswer}, '/', $$allLabelsAnswer{$labelAnswer}, ' = ', $R,
            "% F1 = ", $F1, '%';

        ### 8.5. Accumulate statistics for micro/macro-averaging
        if ($labelAnswer ne '_Other') {
            $macroP += $P;
            $macroR += $R;
            $macroF1 += $F1;
            $microCorrect += $$confMatrix{$labelAnswer}{$labelAnswer};
            $microProposed += $$allLabelsProposed{$labelAnswer} + $wrongDirectionCnt;
            $microAnswer += $$allLabelsAnswer{$labelAnswer};
        }
    }

    ### 9. Output the micro-averaged P, R, F1
    my $microP = (0 == $microProposed) ? 0 : 100.0 * $microCorrect / $microProposed;
    my $microR = (0 == $microAnswer) ? 0 : 100.0 * $microCorrect / $microAnswer;
    my $microF1 = (0 == $microP + $microR) ? 0 : 2.0 * $microP * $microR / ($microP + $microR);
    print "\nMicro-averaged result (excluding Other):\n";
    printf "%s%4d%s%4d%s%6.2f%s%4d%s%4d%s%6.2f%s%6.2f%s\n",
        "P = ", $microCorrect, '/', $microProposed, ' = ', $microP,
        "% R = ", $microCorrect, '/', $microAnswer, ' = ', $microR,
        "% F1 = ", $microF1, '%';

    ### 10. Output the macro-averaged P, R, F1
    my $distinctLabelsCnt = keys %{$allLabelsAnswer};
    ## -1, if '_Other' exists
    $distinctLabelsCnt-- if (defined $$allLabelsAnswer{'_Other'});

    # Divide by the number of non-Other categories; with none present the
    # sums are already 0, so skip the division rather than divide by zero.
    if ($distinctLabelsCnt > 0) {
        $macroP /= $distinctLabelsCnt;
        $macroR /= $distinctLabelsCnt;
        $macroF1 /= $distinctLabelsCnt;
    }
    print "\nMACRO-averaged result (excluding Other):\n";
    printf "%s%6.2f%s%6.2f%s%6.2f%s\n\n\n\n", "P = ", $macroP, "%\tR = ", $macroR, "%\tF1 = ", $macroF1, '%';

    ### 11. Return the official score
    return $macroF1;
}


# Build the short column/row tag for a relation: '_O_' for '_Other', otherwise
# the first letter of each half of the name joined by '-' (or prefixed with
# '*' when the label is absent from $hashToCheck), plus the e1/e2 digit of the
# first argument when the label carries a direction.
sub getShortRelName {
    my ($relName, $hashToCheck) = @_;
    return '_O_' if ($relName eq '_Other');
    die "relName='$relName'" if ($relName !~ /^(.)[^\-]+\-(.)/);
    my $result = (defined $$hashToCheck{$relName}) ? "$1\-$2" : "*$1$2";
    if ($relName =~ /\(e([12])/) {
        $result .= $1;
    }
    return $result;
}

# Append to @$mergedList the sorted keys of %$hash1, then the sorted keys of
# %$hash2 that are not in %$hash1; the pseudo-label 'WRONG_DIR' is skipped.
sub mergeLabelLists {
    my ($hash1, $hash2, $mergedList) = @_;
    foreach my $key (sort keys %{$hash1}) {
        push @{$mergedList}, $key if ($key ne 'WRONG_DIR');
    }
    foreach my $key (sort keys %{$hash2}) {
        push @{$mergedList}, $key if (($key ne 'WRONG_DIR') && !defined($$hash1{$key}));
    }
}

-------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 83 | 84 | 85 | 86 | model_name_or 87 | 88 | 89 | 90 | 101 | 102 | 103 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 
118 | 119 | 120 | 121 | 122 |