├── __init__.py ├── Engine ├── __init__.py ├── opt.py ├── loss.py ├── encoder.py ├── base_trainer.py ├── util.py ├── input_reader.py ├── terms.py ├── sampling.py ├── evaluator.py ├── trainer.py └── models.py ├── data ├── log │ └── __init__.py └── save │ └── __init__.py ├── figures ├── fig_1.png ├── fig_2.png └── fig_3.png ├── configs ├── eval.conf └── train.conf ├── main.py ├── config_reader.py ├── README.md └── args.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/log/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/save/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/fig_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scofield7419/UABSA-SyMux/HEAD/figures/fig_1.png -------------------------------------------------------------------------------- /figures/fig_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scofield7419/UABSA-SyMux/HEAD/figures/fig_2.png -------------------------------------------------------------------------------- /figures/fig_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scofield7419/UABSA-SyMux/HEAD/figures/fig_3.png -------------------------------------------------------------------------------- /Engine/opt.py: -------------------------------------------------------------------------------- 1 | # optional packages 2 | 3 | try: 4 | import tensorboardX as tensorboardx # module name is case-sensitive on PyPI ("tensorboardX") 5 | except ImportError: 6 | tensorboardx = None 7 | 8 | 9 | try: 10 | import jinja2 11 | except ImportError: 12 | jinja2 = None 13 | -------------------------------------------------------------------------------- /configs/eval.conf: -------------------------------------------------------------------------------- 1 | [1] 2 | label = 16res 3 | model_type = symux 4 | model_path = data/models/16res 5 | tokenizer_path = data/models/16res 6 | dataset_path = data/datasets/16res/test.json 7 | eval_batch_size = 1 8 | rel_filter_threshold = 0.4 9 | size_embedding = 25 10 | prop_drop = 0.1 11 | max_span_size = 10 12 | store_predictions = true 13 | store_examples = true 14 | sampling_processes = 4 15 | sampling_limit = 100 16 | max_pairs = 1000 17 | log_path = data/log/ -------------------------------------------------------------------------------- /configs/train.conf: -------------------------------------------------------------------------------- 1 | [1] 2 | label = 16res 3 | model_type = symux 4 | model_path = roberta-base 5 | tokenizer_path = bert-base-cased 6 | train_path = data/datasets/ 7 | valid_path = data/datasets/ 8 | train_batch_size = 4 9 | eval_batch_size = 1 10 | neg_term_count = 100 11 | neg_relation_count = 100 12 | epochs = 20 13 | lr = 4e-5 14 | lr_warmup = 0.1 15 | weight_decay = 0.01 16 | max_grad_norm = 1.0 17 | rel_filter_threshold = 0.4 18 | size_embedding = 25 19 | prop_drop = 0.4 20 |
max_span_size = 8 21 | store_predictions = true 22 | store_examples = true 23 | sampling_processes = 4 24 | sampling_limit = 100 25 | max_pairs = 800 26 | final_eval = false 27 | log_path = data/log/ 28 | save_path = data/save/ 29 | bert_dim = 768 30 | dep_dim = 300 31 | dep_num = 42 32 | pos_num = 50 33 | pos_dim = 300 34 | w_size = 5 35 | bert_dropout = 0.1 36 | output_dropout = 0.1 37 | num_layer = 3 38 | alpha = 1.0 39 | beta = 0.4 40 | sigma = 1.0 41 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from args import train_argparser, eval_argparser 4 | from config_reader import process_configs 5 | from Engine import input_reader 6 | from Engine.trainer import SyMuxTrainer 7 | 8 | 9 | def __train(run_args): 10 | trainer = SyMuxTrainer(run_args) 11 | trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path, 12 | types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader) 13 | 14 | 15 | def _train(): 16 | arg_parser = train_argparser() 17 | process_configs(target=__train, arg_parser=arg_parser) 18 | 19 | 20 | def __eval(run_args): 21 | trainer = SyMuxTrainer(run_args) 22 | trainer.eval(dataset_path=run_args.dataset_path, types_path=run_args.types_path, 23 | input_reader_cls=input_reader.JsonInputReader) 24 | 25 | 26 | def _eval(): 27 | arg_parser = eval_argparser() 28 | process_configs(target=__eval, arg_parser=arg_parser) 29 | 30 | 31 | if __name__ == '__main__': 32 | arg_parser = argparse.ArgumentParser(add_help=False) 33 | arg_parser.add_argument('train_mode', type=str, default='train', help="Mode: 'train' or 'eval'") 34 | args, _ = arg_parser.parse_known_args() 35 | 36 | if args.train_mode == 'train': 37 | _train() 38 | elif args.train_mode == 'eval': 39 | _eval() 40 | else: 41 | raise Exception("Mode not in ['train', 'eval'], e.g. 
'python main.py train ...'") 42 | -------------------------------------------------------------------------------- /Engine/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import torch 3 | 4 | 5 | class Loss(ABC): 6 | def compute(self, *args, **kwargs): 7 | pass 8 | 9 | 10 | class SyMuxLoss(Loss): 11 | def __init__(self, pol_criterion, term_criterion, model, optimizer, scheduler, max_grad_norm): 12 | self._pol_criterion = pol_criterion 13 | self._term_criterion = term_criterion 14 | self._model = model 15 | self._optimizer = optimizer 16 | self._scheduler = scheduler 17 | self._max_grad_norm = max_grad_norm 18 | 19 | def compute(self, term_logits, pol_logits, term_types, pol_types, term_sample_masks, pol_sample_masks): 20 | # term loss 21 | term_logits = term_logits.view(-1, term_logits.shape[-1]) 22 | term_types = term_types.view(-1) 23 | term_sample_masks = term_sample_masks.view(-1).float() 24 | 25 | term_loss = self._term_criterion(term_logits, term_types) 26 | term_loss = (term_loss * term_sample_masks).sum() / term_sample_masks.sum() 27 | 28 | # polarity loss 29 | pol_sample_masks = pol_sample_masks.view(-1).float() 30 | pol_count = pol_sample_masks.sum() 31 | 32 | if pol_count.item() != 0: 33 | pol_logits = pol_logits.view(-1, pol_logits.shape[-1]) 34 | pol_types = pol_types.view(-1, pol_types.shape[-1]) 35 | 36 | pol_loss = self._pol_criterion(pol_logits, pol_types) 37 | pol_loss = pol_loss.sum(-1) / pol_loss.shape[-1] 38 | pol_loss = (pol_loss * pol_sample_masks).sum() / pol_count 39 | 40 | # joint loss 41 | train_loss = term_loss + pol_loss 42 | else: 43 | # corner case: no positive/negative polarity samples 44 | train_loss = term_loss 45 | 46 | train_loss.backward() 47 | torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._max_grad_norm) 48 | self._optimizer.step() 49 | self._scheduler.step() 50 | self._model.zero_grad() 51 | return train_loss.item() 52 | -------------------------------------------------------------------------------- /config_reader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import multiprocessing as mp 3 | 4 | 5 | def process_configs(target, arg_parser): 6 | args, _ = arg_parser.parse_known_args() 7 | ctx = mp.get_context('spawn') 8 | 9 | for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args): 10 | p = ctx.Process(target=target, args=(run_args,)) 11 | p.start() 12 | p.join() 13 | 14 | 15 | def _read_config(path): 16 | lines = open(path).readlines() 17 | 18 | runs = [] 19 | run = [1, dict()] 20 | for line in lines: 21 | stripped_line = line.strip() 22 | 23 | # continue in case of comment 24 | if stripped_line.startswith('#'): 25 | continue 26 | 27 | if not stripped_line: 28 | if run[1]: 29 | runs.append(run) 30 | 31 | run = [1, dict()] 32 | continue 33 | 34 | if stripped_line.startswith('[') and stripped_line.endswith(']'): 35 | repeat = int(stripped_line[1:-1]) 36 | run[0] = repeat 37 | else: 38 | key, value = stripped_line.split('=') 39 | key, value = (key.strip(), value.strip()) 40 | run[1][key] = value 41 | 42 | if run[1]: 43 | runs.append(run) 44 | 45 | return runs 46 | 47 | 48 | def _convert_config(config): 49 | config_list = [] 50 | for k, v in config.items(): 51 | if v.lower() == 'true': 52 | config_list.append('--' + k) 53 | elif v.lower() != 'false': 54 | config_list.extend(['--' + k] + v.split(' ')) 55 | 56 | return config_list 57 | 58 | 59 | def _yield_configs(arg_parser, args, verbose=True):
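    """Build an argparse namespace for each key/value block of the config file and yield it once per requested repeat (the [n] block header)."""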
60 | _print = (lambda x: print(x)) if verbose else lambda x: x 61 | 62 | if args.config: 63 | config = _read_config(args.config) 64 | 65 | for run_repeat, run_config in config: 66 | print("-" * 50) 67 | print("Config:") 68 | print(run_config) 69 | 70 | args_copy = copy.deepcopy(args) 71 | config_list = _convert_config(run_config) 72 | run_args = arg_parser.parse_args(config_list, namespace=args_copy) 73 | run_args_dict = vars(run_args) 74 | 75 | # set boolean values 76 | for k, v in run_config.items(): 77 | if v.lower() == 'false': 78 | run_args_dict[k] = False 79 | 80 | print("Repeat %s times" % run_repeat) 81 | print("-" * 50) 82 | 83 | for iteration in range(run_repeat): 84 | _print("Iteration %s" % iteration) 85 | _print("-" * 50) 86 | 87 | yield run_args, run_config, run_repeat 88 | 89 | else: 90 | yield args, None, None 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Resources for IJCAI2022 paper: [Inheriting the Wisdom of Predecessors: A Multiplex Cascade Framework for Unified Aspect-based Sentiment Analysis 3 | ](https://www.ijcai.org/proceedings/2022/0572.pdf) 4 | 5 | ---- 6 | See the [project page](https://haofei.vip/UABSA/) for more details. 7 | 8 | 9 | ---- 10 | 11 | 12 | ## Seven subtasks of ABSA by unification 13 | 14 | In the ABSA community, there are at least the following seven representative subtasks: 15 | 16 | ``` 17 | AE: aspect term extraction; 18 | OE: opinion term extraction; 19 | AOE: aspect-oriented opinion extraction; 20 | AOPE: aspect-opinion pair extraction; 21 | ALSC: aspect-level sentiment classification; 22 | AESC: aspect extraction and sentiment classification; 23 | TE: triplet extraction 24 | 25 | ``` 26 | 27 | All these subtasks are related by revolving around the predictions of three elements: the aspect term, the opinion term, and the sentiment polarity. 28 | 29 | 30 |
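For intuition, here is a minimal sketch of what each subtask predicts on one toy sentence (an invented example, not drawn from the datasets):

```python
# Illustrative only: the outputs each subtask would produce for this sentence.
sentence = "The pizza was great but the service was slow"

AE   = ["pizza", "service"]                               # aspect term extraction
OE   = ["great", "slow"]                                  # opinion term extraction
AOE  = {"pizza": ["great"], "service": ["slow"]}          # aspect-oriented opinion extraction
AOPE = [("pizza", "great"), ("service", "slow")]          # aspect-opinion pair extraction
ALSC = {"pizza": "positive", "service": "negative"}       # aspect-level sentiment classification
AESC = [("pizza", "positive"), ("service", "negative")]   # aspect extraction + sentiment classification
TE   = [("pizza", "great", "positive"),                   # triplet extraction
        ("service", "slow", "negative")]
```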
31 | <img src="figures/fig_1.png"> 32 |
33 | 34 | In this project, we consider unified ABSA. 35 | We try to enhance the ABSA subtasks by making full use of the interactions between all subtasks, with a **multiplex cascade** framework. 36 | 37 | 38 | ---- 39 | 40 | ## Re-ensembled data for unified ABSA 41 | 42 | * Wang et al. (2017) [1] annotate the unpaired opinion terms (denoted as D_17); 43 | * Fan et al. (2019) [2] pair the aspects with opinion terms (D_19); 44 | * Peng et al. (2020) [3] further provide the labels for triplet extraction (D_20). 45 | 46 | 47 | To enable multi-task training, we re-ensemble the existing ABSA datasets so that most of the sentences’ annotations cover all seven subtasks. 48 | 49 | 50 | 51 | 52 | [1] Coupled Multi-Layer Attentions for Co-Extraction of Aspect and Opinion Terms. In AAAI. 2017. 53 | 54 | [2] Target-oriented Opinion Words Extraction with Target-fused Neural Sequence Labeling. In NAACL. 2019. 55 | 56 | [3] Knowing What, How and Why: A Near Complete Solution for Aspect-Based Sentiment Analysis. In AAAI. 2020. 57 | 58 | 59 | ---- 60 | 61 | ## Multiplex Cascade Framework 62 | 63 | 64 | 65 | The schematic of hierarchical dependency (HD): 66 | 67 |
68 | <img src="figures/fig_2.png"> 69 |
70 | 71 | 72 | The multiplex cascade framework: 73 | 74 |
75 | <img src="figures/fig_3.png"> 76 |
77 | 78 | 79 | 80 | ---- 81 | 82 | 83 | ## Environments 84 | 85 | 86 | ``` 87 | - python (3.8.12) 88 | - cuda (11.4) 89 | - numpy (1.21.4) 90 | - torch (1.10.0) 91 | - gensim (4.1.2) 92 | - transformers (4.13.0) 93 | - pandas (1.3.4) 94 | - scikit-learn (1.0.1) 95 | - corenlp (4.2) 96 | ``` 97 | 98 | 99 | 100 | ---- 101 | 102 | 103 | ## Usage 104 | 105 | ### Preprocessing 106 | 107 | First, parse out the dependency trees and POS tags for each sentence, and save them in JSON format. 108 | We recommend the Stanford [CoreNLP](https://stanfordnlp.github.io/CoreNLP/) tool, 109 | using the NLTK package to wrap the parsing process. 110 | 111 | Download the `RoBERTa` PLM. 112 | 113 | ### Configuration 114 | 115 | Configure the `configs/train.conf` and `configs/eval.conf` files. 116 | 117 | ### Running 118 | 119 | ``` 120 | python main.py train --config configs/train.conf 121 | ``` 122 | 123 | ---- 124 | If you use this work or code, please kindly cite: 125 | 126 | 127 | ``` 128 | @inproceedings{fei2022unifiedABSA, 129 | author = {Hao Fei and Fei Li and Chenliang Li and Shengqiong Wu and Jingye Li and Donghong Ji}, 130 | title = {Inheriting the Wisdom of Predecessors: A Multiplex Cascade Framework for Unified Aspect-based Sentiment Analysis}, 131 | booktitle = {Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence, {IJCAI}}, 132 | pages = {4121--4128}, 133 | year = {2022}, 134 | } 135 | ``` -------------------------------------------------------------------------------- /Engine/encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | class SyntaxAwareGCN(nn.Module): 10 | """ 11 | Simple GCN layer 12 | """ 13 | def __init__(self, dep_dim, in_features, out_features, pos_dim=None, bias=True): 14 | super(SyntaxAwareGCN, self).__init__() 15 | self.dep_dim = dep_dim 16 | self.pos_dim = pos_dim 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | 20 | self.dep_attn = nn.Linear(dep_dim + pos_dim + in_features, out_features) 21 | self.dep_fc = nn.Linear(dep_dim, out_features) 22 | self.pos_fc = nn.Linear(pos_dim, out_features) 23 | 24 | def forward(self, text, adj, dep_embed, pos_embed=None): 25 | """ 26 | 27 | :param text: [batch size, seq_len, feat_dim] 28 | :param adj: [batch size, seq_len, seq_len] 29 | :param dep_embed: [batch size, seq_len, seq_len, dep_type_dim] 30 | :param pos_embed: [batch size, seq_len, pos_dim] 31 | :return: [batch size, seq_len, feat_dim] 32 | """ 33 | batch_size, seq_len, feat_dim = text.shape 34 | 35 | val_us = text.unsqueeze(dim=2) 36 | val_us = val_us.repeat(1, 1, seq_len, 1) 37 | pos_us = pos_embed.unsqueeze(dim=2).repeat(1, 1, seq_len, 1) 38 | # [batch size, seq_len, seq_len, feat_dim+pos_dim+dep_dim] 39 | val_sum = torch.cat([val_us, pos_us, dep_embed], dim=-1) 40 | 41 | r = self.dep_attn(val_sum) 42 | 43 | p = torch.sum(r, dim=-1) 44 | mask = (adj == 0).float() * (-1e30) 45 | p = p + mask 46 | p = torch.softmax(p, dim=2) 47 | p_us = p.unsqueeze(3).repeat(1, 1, 1, feat_dim) 48 | 49 | output = val_us + self.pos_fc(pos_us) + self.dep_fc(dep_embed) 50 | output = torch.mul(p_us, output) 51 | 52 | output_sum = torch.sum(output, dim=2) 53 | 54 | return r, output_sum, p 55 | 56 | 57 | class SGCN(nn.Module): 58 | def __init__(self, opt): 59 | super(SGCN, self).__init__() 60 | self.opt = opt 61 | self.model = nn.ModuleList([SyntaxAwareGCN(opt.dep_dim,
opt.bert_dim, 62 | opt.bert_dim, opt.pos_dim) 63 | for i in range(self.opt.num_layer)]) 64 | self.dep_embedding = nn.Embedding(opt.dep_num, opt.dep_dim, padding_idx=0) 65 | 66 | def forward(self, x, simple_graph, graph, pos_embed=None, output_attention=False): 67 | 68 | dep_embed = self.dep_embedding(graph) 69 | 70 | attn_list = [] 71 | for lagcn in self.model: 72 | r, x, attn = lagcn(x, simple_graph, dep_embed, pos_embed=pos_embed) 73 | attn_list.append(attn) 74 | 75 | if output_attention is True: 76 | return x, r, attn_list 77 | else: 78 | return x, r 79 | 80 | 81 | class SyMuxEncoder(nn.Module): 82 | def __init__(self, bert, opt): 83 | super(SyMuxEncoder, self).__init__() 84 | self.opt = opt 85 | self.bert = bert 86 | self.sgcn = SGCN(opt) 87 | 88 | self.fc = nn.Linear(opt.bert_dim*2 + opt.pos_dim, opt.bert_dim*2) 89 | self.bert_dropout = nn.Dropout(opt.bert_dropout) 90 | self.output_dropout = nn.Dropout(opt.output_dropout) 91 | 92 | self.pod_embedding = nn.Embedding(opt.pos_num, opt.pos_dim, padding_idx=0) 93 | 94 | def forward(self, input_ids, input_masks, simple_graph, graph, pos=None, output_attention=False): 95 | 96 | pos_embed = self.pod_embedding(pos) 97 | sequence_output, pooled_output = self.bert(input_ids) 98 | x = self.bert_dropout(sequence_output) 99 | 100 | lagcn_output = self.sgcn(x, simple_graph, graph, pos_embed, output_attention) 101 | 102 | pos_output = self.local_attn(x, pos_embed, self.opt.num_layer, self.opt.w_size) 103 | 104 | output = torch.cat((lagcn_output[0], pos_output, sequence_output), dim=-1) 105 | output = self.fc(output) 106 | output = self.output_dropout(output) 107 | return output, lagcn_output[1] 108 | 109 | def local_attn(self, x, pos_embed, num_layer, w_size): 110 | """ 111 | 112 | :param x: 113 | :param pos_embed: 114 | :return: 115 | """ 116 | batch_size, seq_len, feat_dim = x.shape 117 | pos_dim = pos_embed.size(-1) 118 | output = pos_embed 119 | for i in range(num_layer): 120 | val_sum = torch.cat([x, output], dim=-1) # [batch size, seq_len, feat_dim+pos_dim] 121 | attn = torch.matmul(val_sum, val_sum.transpose(1, 2)) # [batch size, seq_len, seq_len] 122 | # pad size = seq_len + (window_size - 1) // 2 * 2 123 | pad_size = seq_len + w_size * 2 124 | mask = torch.zeros((batch_size, seq_len, pad_size), dtype=torch.float).to( 125 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 126 | for i in range(seq_len): 127 | mask[:, i, i:i + w_size] = 1.0 128 | pad_attn = torch.full((batch_size, seq_len, w_size), -1e18).to( 129 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 130 | attn = torch.cat([pad_attn, attn, pad_attn], dim=-1) 131 | local_attn = torch.softmax(torch.mul(attn, mask), dim=-1) 132 | local_attn = local_attn[:, :, w_size:pad_size - w_size] # [batch size, seq_len, seq_len] 133 | local_attn = local_attn.unsqueeze(dim=3).repeat(1, 1, 1, pos_dim) 134 | output = output.unsqueeze(dim=2).repeat(1, 1, seq_len, 1) 135 | output = torch.sum(torch.mul(output, local_attn), dim=2) # [batch size, seq_len, pos_dim] 136 | return output -------------------------------------------------------------------------------- /Engine/base_trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | import sys 6 | from typing import List, Dict, Tuple 7 | 8 | import torch 9 | from torch.nn import DataParallel 10 | from torch.optim import optimizer 11 | from transformers import PreTrainedModel 12 | from transformers import 
PreTrainedTokenizer 13 | 14 | from Engine import util 15 | from Engine.opt import tensorboardx 16 | 17 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | 20 | class BaseTrainer: 21 | """ Trainer base class with common methods """ 22 | 23 | def __init__(self, args: argparse.Namespace): 24 | self.args = args 25 | self._debug = self.args.debug 26 | 27 | # logging 28 | name = str(datetime.datetime.now()).replace(' ', '_') 29 | self._log_path = os.path.join(self.args.log_path, self.args.label, name) 30 | util.create_directories_dir(self._log_path) 31 | 32 | if hasattr(args, 'save_path'): 33 | self._save_path = os.path.join(self.args.save_path, self.args.label, name) 34 | util.create_directories_dir(self._save_path) 35 | 36 | self._log_paths = dict() 37 | 38 | # file + console logging 39 | log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") 40 | self._logger = logging.getLogger() 41 | util.reset_logger(self._logger) 42 | 43 | file_handler = logging.FileHandler(os.path.join(self._log_path, 'all.log')) 44 | file_handler.setFormatter(log_formatter) 45 | self._logger.addHandler(file_handler) 46 | 47 | console_handler = logging.StreamHandler(sys.stdout) 48 | console_handler.setFormatter(log_formatter) 49 | self._logger.addHandler(console_handler) 50 | 51 | if self._debug: 52 | self._logger.setLevel(logging.DEBUG) 53 | else: 54 | self._logger.setLevel(logging.INFO) 55 | 56 | # tensorboard summary 57 | self._summary_writer = tensorboardx.SummaryWriter(self._log_path) if tensorboardx is not None else None 58 | 59 | self._best_results = dict() 60 | self._log_arguments() 61 | 62 | # CUDA devices 63 | self._device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu") 64 | self._gpu_count = torch.cuda.device_count() 65 | 66 | # set seed 67 | if args.seed is not None: 68 | util.set_seed(args.seed) 69 | 70 | def _add_dataset_logging(self, *labels, data: Dict[str, List[str]]): 71 | for label in labels: 72 | dic = dict() 73 | 74 | for key, columns in data.items(): 75 | path = os.path.join(self._log_path, '%s_%s.csv' % (key, label)) 76 | util.create_csv(path, *columns) 77 | dic[key] = path 78 | 79 | self._log_paths[label] = dic 80 | self._best_results[label] = 0 81 | 82 | def _log_arguments(self): 83 | util.save_dict(self._log_path, self.args, 'args') 84 | if self._summary_writer is not None: 85 | util.summarize_dict(self._summary_writer, self.args, 'args') 86 | 87 | def _log_tensorboard(self, dataset_label: str, data_label: str, data: object, iteration: int): 88 | if self._summary_writer is not None: 89 | self._summary_writer.add_scalar('data/%s/%s' % (dataset_label, data_label), data, iteration) 90 | 91 | def _log_csv(self, dataset_label: str, data_label: str, *data: Tuple[object]): 92 | logs = self._log_paths[dataset_label] 93 | util.append_csv(logs[data_label], *data) 94 | 95 | def _save_best(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, optimizer: optimizer, 96 | accuracy: float, iteration: int, label: str, extra=None): 97 | if accuracy > self._best_results[label]: 98 | self._logger.info("[%s] Best model in iteration %s: %s%% accuracy" % (label, iteration, accuracy)) 99 | self._save_model(self._save_path, model, tokenizer, iteration, 100 | optimizer=optimizer if self.args.save_optimizer else None, 101 | save_as_best=True, name='model_%s' % label, extra=extra) 102 | self._best_results[label] = accuracy 103 | 104 | def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: 
PreTrainedTokenizer, 105 | iteration: int, optimizer: optimizer = None, save_as_best: bool = False, 106 | extra: dict = None, include_iteration: int = True, name: str = 'model'): 107 | extra_state = dict(iteration=iteration) 108 | 109 | if optimizer: 110 | extra_state['optimizer'] = optimizer.state_dict() 111 | 112 | if extra: 113 | extra_state.update(extra) 114 | 115 | if save_as_best: 116 | dir_path = os.path.join(save_path, '%s_best' % name) 117 | else: 118 | dir_name = '%s_%s' % (name, iteration) if include_iteration else name 119 | dir_path = os.path.join(save_path, dir_name) 120 | 121 | util.create_directories_dir(dir_path) 122 | 123 | # save model 124 | if isinstance(model, DataParallel): 125 | model.module.save_pretrained(dir_path) 126 | else: 127 | model.save_pretrained(dir_path) 128 | 129 | # save vocabulary 130 | tokenizer.save_pretrained(dir_path) 131 | 132 | # save extra 133 | state_path = os.path.join(dir_path, 'extra.state') 134 | torch.save(extra_state, state_path) 135 | 136 | def _get_lr(self, optimizer): 137 | lrs = [] 138 | for group in optimizer.param_groups: 139 | lr_scheduled = group['lr'] 140 | lrs.append(lr_scheduled) 141 | return lrs 142 | 143 | def _close_summary_writer(self): 144 | if self._summary_writer is not None: 145 | self._summary_writer.close() 146 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def _add_common_args(arg_parser): 5 | arg_parser.add_argument('--config', type=str) 6 | 7 | # Input 8 | arg_parser.add_argument('--types_path', type=str, help="Path to type specifications") 9 | 10 | # Preprocessing 11 | arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer") 12 | arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans") 13 | arg_parser.add_argument('--lowercase', action='store_true', default=False, 14 | help="If true, input is lowercased during preprocessing") 15 | arg_parser.add_argument('--sampling_processes', type=int, default=4, 16 | help="Number of sampling processes. 0 = no multiprocessing for sampling") 17 | arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue") 18 | 19 | # Logging 20 | arg_parser.add_argument('--label', type=str, help="Label of run. 
Used as the directory name of logs/models") 21 | arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored") 22 | arg_parser.add_argument('--store_predictions', action='store_true', default=False, 23 | help="If true, store predictions on disc (in log directory)") 24 | arg_parser.add_argument('--store_examples', action='store_true', default=False, 25 | help="If true, store evaluation examples on disc (in log directory)") 26 | arg_parser.add_argument('--example_count', type=int, default=None, 27 | help="Count of evaluation examples to store (if store_examples == True)") 28 | arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off") 29 | 30 | # Model / Training / Evaluation 31 | arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints") 32 | arg_parser.add_argument('--model_type', type=str, default="SynFue", help="Type of model") 33 | arg_parser.add_argument('--cpu', action='store_true', default=False, 34 | help="If true, train/evaluate on CPU even if a CUDA device is available") 35 | arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size") 36 | arg_parser.add_argument('--max_pairs', type=int, default=1000, 37 | help="Maximum term pairs to process during training/evaluation") 38 | arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations") 39 | arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding") 40 | arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT") 41 | arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights") 42 | arg_parser.add_argument('--no_overlapping', action='store_true', default=False, 43 | help="If true, do not evaluate on overlapping terms " 44 | "and relations with overlapping terms") 45 | 46 | # SGCN 47 | arg_parser.add_argument('--bert_dim', type=int, default=768) 48 | arg_parser.add_argument('--dep_dim', type=int, default=100) 49 | arg_parser.add_argument('--bert_dropout', type=float, default=0.5) 50 | arg_parser.add_argument('--output_dropout', type=float, default=0.5) 51 | arg_parser.add_argument('--dep_num', type=int, default=42) # 40 + 1(None) + 1(self-loop) 52 | arg_parser.add_argument('--num_layer', type=int, default=3) 53 | 54 | # POS 55 | # arg_parser.add_argument('--use_pos', type=bool, default=True) 56 | arg_parser.add_argument('--pos_num', type=int, default=45) 57 | arg_parser.add_argument('--pos_dim', type=int, default=100) 58 | arg_parser.add_argument('--w_size', type=int, default=1, help='the window size of local attention') 59 | 60 | # Misc 61 | arg_parser.add_argument('--seed', type=int, default=58986, help="Seed") 62 | arg_parser.add_argument('--cache_path', type=str, default=None, 63 | help="Path to cache transformer models (for HuggingFace transformers library)") 64 | 65 | 66 | def train_argparser(): 67 | arg_parser = argparse.ArgumentParser() 68 | 69 | # Input 70 | arg_parser.add_argument('--train_path', type=str, help="Path to train dataset") 71 | arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset") 72 | 73 | # Logging 74 | arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored") 75 | arg_parser.add_argument('--init_eval', action='store_true', default=False, 76 | help="If
true, evaluate validation set before training") 77 | arg_parser.add_argument('--save_optimizer', action='store_true', default=False, 78 | help="Save optimizer alongside model") 79 | arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations") 80 | arg_parser.add_argument('--final_eval', action='store_true', default=False, 81 | help="Evaluate the model only after training, not at every epoch") 82 | 83 | # Model / Training 84 | arg_parser.add_argument('--train_batch_size', type=int, default=2, help="Training batch size") 85 | arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs") 86 | arg_parser.add_argument('--neg_term_count', type=int, default=100, 87 | help="Number of negative term samples per document (sentence)") 88 | arg_parser.add_argument('--neg_relation_count', type=int, default=100, 89 | help="Number of negative relation samples per document (sentence)") 90 | arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate") 91 | arg_parser.add_argument('--lr_warmup', type=float, default=0.1, 92 | help="Proportion of total train iterations to warmup in linear increase/decrease schedule") 93 | arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply") 94 | arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm") 95 | 96 | _add_common_args(arg_parser) 97 | 98 | return arg_parser 99 | 100 | 101 | def eval_argparser(): 102 | arg_parser = argparse.ArgumentParser() 103 | 104 | # Input 105 | arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset") 106 | 107 | _add_common_args(arg_parser) 108 | 109 | return arg_parser 110 | -------------------------------------------------------------------------------- /Engine/util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import random 5 | import shutil 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from Engine.terms import TokenSpan 11 | 12 | CSV_DELIMETER = ';' 13 | 14 | 15 | TAG_1 = {0: 'O', 1: 'I'} 16 | TAG_2 = {0: 'O', 1: 'P', 2: 'N', 3: 'M'} 17 | 18 | def get_tag_set_size(type=0): 19 | if type == 0: 20 | return len(TAG_1) 21 | elif type == 1: 22 | return len(TAG_2) 23 | 24 | def create_directories_file(f): 25 | d = os.path.dirname(f) 26 | 27 | if d and not os.path.exists(d): 28 | os.makedirs(d) 29 | 30 | return f 31 | 32 | 33 | def create_directories_dir(d): 34 | if d and not os.path.exists(d): 35 | os.makedirs(d) 36 | 37 | return d 38 | 39 | 40 | def create_csv(file_path, *column_names): 41 | if not os.path.exists(file_path): 42 | with open(file_path, 'w', newline='') as csv_file: 43 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL) 44 | 45 | if column_names: 46 | writer.writerow(column_names) 47 | 48 | 49 | def append_csv(file_path, *row): 50 | if not os.path.exists(file_path): 51 | raise Exception("The given file doesn't exist") 52 | 53 | with open(file_path, 'a', newline='') as csv_file: 54 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL) 55 | writer.writerow(row) 56 | 57 | 58 | def append_csv_multiple(file_path, *rows): 59 | if not os.path.exists(file_path): 60 | raise Exception("The given file doesn't exist") 61 | 62 | with open(file_path, 'a', newline='') as csv_file: 63 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', 
quoting=csv.QUOTE_MINIMAL) 64 | for row in rows: 65 | writer.writerow(row) 66 | 67 | 68 | def read_csv(file_path): 69 | lines = [] 70 | with open(file_path, 'r') as csv_file: 71 | reader = csv.reader(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL) 72 | for row in reader: 73 | lines.append(row) 74 | 75 | return lines[0], lines[1:] 76 | 77 | 78 | def copy_python_directory(source, dest, ignore_dirs=None): 79 | source = source if source.endswith('/') else source + '/' 80 | for (dir_path, dir_names, file_names) in os.walk(source): 81 | tail = '/'.join(dir_path.split(source)[1:]) 82 | new_dir = os.path.join(dest, tail) 83 | 84 | if ignore_dirs and True in [(ignore_dir in tail) for ignore_dir in ignore_dirs]: 85 | continue 86 | 87 | create_directories_dir(new_dir) 88 | 89 | for file_name in file_names: 90 | if file_name.endswith('.py'): 91 | file_path = os.path.join(dir_path, file_name) 92 | shutil.copy2(file_path, new_dir) 93 | 94 | 95 | def save_dict(log_path, dic, name): 96 | # save arguments 97 | # 1. as json 98 | path = os.path.join(log_path, '%s.json' % name) 99 | f = open(path, 'w') 100 | json.dump(vars(dic), f) 101 | f.close() 102 | 103 | # 2. as string 104 | path = os.path.join(log_path, '%s.txt' % name) 105 | f = open(path, 'w') 106 | args_str = ["%s = %s" % (key, value) for key, value in vars(dic).items()] 107 | f.write('\n'.join(args_str)) 108 | f.close() 109 | 110 | 111 | def summarize_dict(summary_writer, dic, name): 112 | table = 'Argument|Value\n-|-' 113 | 114 | for k, v in vars(dic).items(): 115 | row = '\n%s|%s' % (k, v) 116 | table += row 117 | summary_writer.add_text(name, table) 118 | 119 | 120 | def set_seed(seed): 121 | random.seed(seed) 122 | np.random.seed(seed) 123 | torch.manual_seed(seed) 124 | torch.cuda.manual_seed_all(seed) 125 | 126 | 127 | def reset_logger(logger): 128 | for handler in logger.handlers[:]: 129 | logger.removeHandler(handler) 130 | 131 | for f in logger.filters[:]: 132 | logger.removeFilter(f) 133 | 134 | 135 | def flatten(l): 136 | return [i for p in l for i in p] 137 | 138 | 139 | def get_as_list(dic, key): 140 | if key in dic: 141 | return [dic[key]] 142 | else: 143 | return [] 144 | 145 | 146 | def extend_tensor(tensor, extended_shape, fill=0): 147 | tensor_shape = tensor.shape 148 | 149 | extended_tensor = torch.zeros(extended_shape, dtype=tensor.dtype).to(tensor.device) 150 | extended_tensor = extended_tensor.fill_(fill) 151 | 152 | if len(tensor_shape) == 1: 153 | extended_tensor[:tensor_shape[0]] = tensor 154 | elif len(tensor_shape) == 2: 155 | extended_tensor[:tensor_shape[0], :tensor_shape[1]] = tensor 156 | elif len(tensor_shape) == 3: 157 | extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2]] = tensor 158 | elif len(tensor_shape) == 4: 159 | extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2], :tensor_shape[3]] = tensor 160 | 161 | return extended_tensor 162 | 163 | 164 | def padded_stack(tensors, padding=0): 165 | dim_count = len(tensors[0].shape) 166 | 167 | max_shape = [max([t.shape[d] for t in tensors]) for d in range(dim_count)] 168 | padded_tensors = [] 169 | 170 | for t in tensors: 171 | e = extend_tensor(t, max_shape, fill=padding) 172 | padded_tensors.append(e) 173 | 174 | stacked = torch.stack(padded_tensors) 175 | return stacked 176 | 177 | 178 | def batch_index(tensor, index, pad=False): 179 | if tensor.shape[0] != index.shape[0]: 180 | raise Exception() 181 | 182 | if not pad: 183 | return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])]) 184 |
else: 185 | return padded_stack([tensor[i][index[i]] for i in range(index.shape[0])]) 186 | 187 | 188 | def padded_nonzero(tensor, padding=0): 189 | indices = padded_stack([tensor[i].nonzero().view(-1) for i in range(tensor.shape[0])], padding) 190 | return indices 191 | 192 | 193 | def swap(v1, v2): 194 | return v2, v1 195 | 196 | 197 | def get_span_tokens(tokens, span): 198 | inside = False 199 | span_tokens = [] 200 | 201 | for t in tokens: 202 | if t.span[0] == span[0]: 203 | inside = True 204 | 205 | if inside: 206 | span_tokens.append(t) 207 | 208 | if inside and t.span[1] == span[1]: 209 | return TokenSpan(span_tokens) 210 | 211 | return None 212 | 213 | 214 | def to_device(batch, device): 215 | converted_batch = dict() 216 | for key in batch.keys(): 217 | converted_batch[key] = batch[key].to(device) 218 | 219 | return converted_batch 220 | 221 | 222 | def check_version(config, model_class, model_path): 223 | if os.path.exists(model_path): 224 | model_path = model_path if model_path.endswith('.bin') else os.path.join(model_path, 'pytorch_model.bin') 225 | state_dict = torch.load(model_path, map_location=torch.device('cpu')) 226 | config_dict = config.to_dict() 227 | 228 | # version check 229 | loaded_version = config_dict.get('symux_version', '1.0') 230 | if 'rel_classifier.weight' in state_dict and loaded_version != model_class.VERSION: 231 | msg = ("Current Engine version (%s) does not match the version of the loaded model (%s). " 232 | % (model_class.VERSION, loaded_version)) 233 | msg += "Use the code matching your version or train a new model." 234 | raise Exception(msg) 235 | 236 | 237 | -------------------------------------------------------------------------------- /Engine/input_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from abc import abstractmethod, ABC 3 | from collections import OrderedDict 4 | from logging import Logger 5 | from typing import Iterable, List 6 | 7 | from tqdm import tqdm 8 | from transformers import BertTokenizer 9 | 10 | from Engine import util 11 | from Engine.terms import Dataset, TermType, RelationType, Term, Relation, Document 12 | 13 | 14 | class BaseInputReader(ABC): 15 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None, 16 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None): 17 | types = json.load(open(types_path), object_pairs_hook=OrderedDict) # term + relation types 18 | 19 | self._term_types = OrderedDict() 20 | self._idx2term_type = OrderedDict() 21 | self._relation_types = OrderedDict() 22 | self._idx2relation_type = OrderedDict() 23 | 24 | # terms 25 | # add 'None' term type 26 | none_term_type = TermType('None', 0, 'None', 'No Term') 27 | self._term_types['None'] = none_term_type 28 | self._idx2term_type[0] = none_term_type 29 | 30 | # specified term types 31 | for i, (key, v) in enumerate(types['terms'].items()): 32 | term_type = TermType(key, i + 1, v['short'], v['verbose']) 33 | self._term_types[key] = term_type 34 | self._idx2term_type[i + 1] = term_type 35 | 36 | # relations 37 | # add 'None' relation type 38 | none_relation_type = RelationType('None', 0, 'None', 'No Relation') 39 | self._relation_types['None'] = none_relation_type 40 | self._idx2relation_type[0] = none_relation_type 41 | 42 | # specified relation types 43 | for i, (key, v) in enumerate(types['relations'].items()): 44 | relation_type = RelationType(key, i + 1, v['short'], v['verbose'], v['symmetric']) 45 | self._relation_types[key] 
= relation_type 46 | self._idx2relation_type[i + 1] = relation_type 47 | 48 | self._neg_term_count = neg_term_count 49 | self._neg_rel_count = neg_rel_count 50 | self._max_span_size = max_span_size 51 | 52 | self._datasets = dict() 53 | 54 | self._tokenizer = tokenizer 55 | self._logger = logger 56 | 57 | self._vocabulary_size = tokenizer.vocab_size 58 | self._context_size = -1 59 | 60 | @abstractmethod 61 | def read(self, datasets): 62 | pass 63 | 64 | def get_dataset(self, label) -> Dataset: 65 | return self._datasets[label] 66 | 67 | def get_term_type(self, idx) -> TermType: 68 | term = self._idx2term_type[idx] 69 | return term 70 | 71 | def get_relation_type(self, idx) -> RelationType: 72 | relation = self._idx2relation_type[idx] 73 | return relation 74 | 75 | def _calc_context_size(self, datasets: Iterable[Dataset]): 76 | sizes = [] 77 | 78 | for dataset in datasets: 79 | for doc in dataset.documents: 80 | sizes.append(len(doc.encoding)) 81 | 82 | context_size = max(sizes) 83 | return context_size 84 | 85 | def _log(self, text): 86 | if self._logger is not None: 87 | self._logger.info(text) 88 | 89 | @property 90 | def datasets(self): 91 | return self._datasets 92 | 93 | @property 94 | def term_types(self): 95 | return self._term_types 96 | 97 | @property 98 | def relation_types(self): 99 | return self._relation_types 100 | 101 | @property 102 | def relation_type_count(self): 103 | return len(self._relation_types) 104 | 105 | @property 106 | def term_type_count(self): 107 | return len(self._term_types) 108 | 109 | @property 110 | def vocabulary_size(self): 111 | return self._vocabulary_size 112 | 113 | @property 114 | def context_size(self): 115 | return self._context_size 116 | 117 | def __str__(self): 118 | string = "" 119 | for dataset in self._datasets.values(): 120 | string += "Dataset: %s\n" % dataset 121 | string += str(dataset) 122 | 123 | return string 124 | 125 | def __repr__(self): 126 | return self.__str__() 127 | 128 | 129 | class JsonInputReader(BaseInputReader): 130 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None, 131 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None): 132 | super().__init__(types_path, tokenizer, neg_term_count, neg_rel_count, max_span_size, logger) 133 | 134 | def read(self, dataset_paths): 135 | for dataset_label, dataset_path in dataset_paths.items(): 136 | dataset = Dataset(dataset_label, self._relation_types, self._term_types, self._neg_term_count, 137 | self._neg_rel_count, self._max_span_size) 138 | self._parse_dataset(dataset_path, dataset) 139 | self._datasets[dataset_label] = dataset 140 | 141 | self._context_size = self._calc_context_size(self._datasets.values()) 142 | 143 | def _parse_dataset(self, dataset_path, dataset): 144 | documents = json.load(open(dataset_path)) 145 | for document in tqdm(documents, desc="Parse dataset '%s'" % dataset.label): 146 | self._parse_document(document, dataset) 147 | 148 | def _parse_document(self, doc, dataset) -> Document: 149 | jtokens = doc['tokens'] 150 | jrelations = doc['relations'] 151 | jterms = doc['entities'] 152 | jpols = doc['polarities'] 153 | jdep_label = doc['dep_label'] 154 | jdep_label_indices = doc['dep_label_indices'] 155 | jdep = doc['dep'] 156 | jpos = doc['pos'] 157 | jpos_indices = doc['pos_indices'] 158 | 159 | # parse tokens 160 | doc_tokens, doc_encoding = self._parse_tokens(jtokens, dataset) 161 | 162 | # parse term mentions 163 | terms = self._parse_terms(jterms, doc_tokens, dataset) 164 | 165 | # parse 
relations 166 | relations = self._parse_relations(jrelations, terms, dataset) 167 | 168 | # create document 169 | document = dataset.create_document(doc_tokens, terms, relations, jpols, doc_encoding, jdep_label, 170 | jdep_label_indices, jdep, jpos, jpos_indices) 171 | 172 | return document 173 | 174 | def _parse_tokens(self, jtokens, dataset): 175 | doc_tokens = [] 176 | 177 | # full document encoding including special tokens ([CLS] and [SEP]) and byte-pair encodings of original tokens 178 | doc_encoding = [self._tokenizer.convert_tokens_to_ids('[CLS]')] 179 | 180 | # parse tokens 181 | for i, token_phrase in enumerate(jtokens): 182 | token_encoding = self._tokenizer.encode(token_phrase, add_special_tokens=False) 183 | span_start, span_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding)) 184 | 185 | token = dataset.create_token(i, span_start, span_end, token_phrase) 186 | 187 | doc_tokens.append(token) 188 | doc_encoding += token_encoding 189 | 190 | doc_encoding += [self._tokenizer.convert_tokens_to_ids('[SEP]')] 191 | 192 | return doc_tokens, doc_encoding 193 | 194 | def _parse_terms(self, jterms, doc_tokens, dataset) -> List[Term]: 195 | terms = [] 196 | 197 | for term_idx, jterm in enumerate(jterms): 198 | term_type = self._term_types[jterm['type']] 199 | start, end = jterm['start'], jterm['end'] 200 | 201 | # create term mention 202 | tokens = doc_tokens[start:end+1] 203 | phrase = " ".join([t.phrase for t in tokens]) 204 | term = dataset.create_term(term_type, tokens, phrase) 205 | terms.append(term) 206 | 207 | return terms 208 | 209 | def _parse_relations(self, jrelations, terms, dataset) -> List[Relation]: 210 | relations = [] 211 | 212 | for jrelation in jrelations: 213 | relation_type = self._relation_types[jrelation['type']] 214 | 215 | head_idx = jrelation['head'] 216 | tail_idx = jrelation['tail'] 217 | 218 | # create relation 219 | head = terms[head_idx] 220 | tail = terms[tail_idx] 221 | 222 | reverse = int(tail.tokens[0].index) < int(head.tokens[0].index) 223 | 224 | # for symmetric relations: head occurs before tail in sentence 225 | if relation_type.symmetric and reverse: 226 | head, tail = util.swap(head, tail) 227 | 228 | relation = dataset.create_relation(relation_type, head_term=head, tail_term=tail, reverse=reverse) 229 | relations.append(relation) 230 | 231 | return relations 232 | -------------------------------------------------------------------------------- /Engine/terms.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import List 3 | from torch.utils.data import Dataset as TorchDataset 4 | 5 | from Engine import sampling 6 | 7 | 8 | class RelationType: 9 | def __init__(self, identifier, index, short_name, verbose_name, symmetric=False): 10 | self._identifier = identifier 11 | self._index = index 12 | self._short_name = short_name 13 | self._verbose_name = verbose_name 14 | self._symmetric = symmetric 15 | 16 | @property 17 | def identifier(self): 18 | return self._identifier 19 | 20 | @property 21 | def index(self): 22 | return self._index 23 | 24 | @property 25 | def short_name(self): 26 | return self._short_name 27 | 28 | @property 29 | def verbose_name(self): 30 | return self._verbose_name 31 | 32 | @property 33 | def symmetric(self): 34 | return self._symmetric 35 | 36 | def __int__(self): 37 | return self._index 38 | 39 | def __eq__(self, other): 40 | if isinstance(other, RelationType): 41 | return self._identifier == other._identifier 42 | return 
False 43 | 44 | def __hash__(self): 45 | return hash(self._identifier) 46 | 47 | class Polarity: 48 | def __init__(self, identifier, index, short_name, verbose_name, symmetric=False): 49 | self._identifier = identifier 50 | self._index = index 51 | self._short_name = short_name 52 | self._verbose_name = verbose_name 53 | self._symmetric = symmetric 54 | 55 | @property 56 | def identifier(self): 57 | return self._identifier 58 | 59 | @property 60 | def index(self): 61 | return self._index 62 | 63 | @property 64 | def short_name(self): 65 | return self._short_name 66 | 67 | @property 68 | def verbose_name(self): 69 | return self._verbose_name 70 | 71 | @property 72 | def symmetric(self): 73 | return self._symmetric 74 | 75 | def __int__(self): 76 | return self._index 77 | 78 | def __eq__(self, other): 79 | if isinstance(other, Polarity): 80 | return self._identifier == other._identifier 81 | return False 82 | 83 | def __hash__(self): 84 | return hash(self._identifier) 85 | 86 | 87 | class TermType: 88 | def __init__(self, identifier, index, short_name, verbose_name): 89 | self._identifier = identifier 90 | self._index = index 91 | self._short_name = short_name 92 | self._verbose_name = verbose_name 93 | 94 | @property 95 | def identifier(self): 96 | return self._identifier 97 | 98 | @property 99 | def index(self): 100 | return self._index 101 | 102 | @property 103 | def short_name(self): 104 | return self._short_name 105 | 106 | @property 107 | def verbose_name(self): 108 | return self._verbose_name 109 | 110 | def __int__(self): 111 | return self._index 112 | 113 | def __eq__(self, other): 114 | if isinstance(other, TermType): 115 | return self._identifier == other._identifier 116 | return False 117 | 118 | def __hash__(self): 119 | return hash(self._identifier) 120 | 121 | 122 | class Token: 123 | def __init__(self, tid: int, index: int, span_start: int, span_end: int, phrase: str): 124 | self._tid = tid # ID within the corresponding dataset 125 | self._index = index # original token index in document 126 | 127 | self._span_start = span_start # start of token span in document (inclusive) 128 | self._span_end = span_end # end of token span in document (exclusive) 129 | self._phrase = phrase 130 | 131 | @property 132 | def index(self): 133 | return self._index 134 | 135 | @property 136 | def span_start(self): 137 | return self._span_start 138 | 139 | @property 140 | def span_end(self): 141 | return self._span_end 142 | 143 | @property 144 | def span(self): 145 | return self._span_start, self._span_end 146 | 147 | @property 148 | def phrase(self): 149 | return self._phrase 150 | 151 | def __eq__(self, other): 152 | if isinstance(other, Token): 153 | return self._tid == other._tid 154 | return False 155 | 156 | def __hash__(self): 157 | return hash(self._tid) 158 | 159 | def __str__(self): 160 | return self._phrase 161 | 162 | def __repr__(self): 163 | return self._phrase 164 | 165 | 166 | class TokenSpan: 167 | def __init__(self, tokens): 168 | self._tokens = tokens 169 | 170 | @property 171 | def span_start(self): 172 | return self._tokens[0].span_start 173 | 174 | @property 175 | def span_end(self): 176 | return self._tokens[-1].span_end 177 | 178 | @property 179 | def span(self): 180 | return self.span_start, self.span_end 181 | 182 | def __getitem__(self, s): 183 | if isinstance(s, slice): 184 | return TokenSpan(self._tokens[s.start:s.stop:s.step]) 185 | else: 186 | try: 187 | return self._tokens[s] 188 | except: 189 | print(self._tokens) 190 | print(len(self._tokens)) 191 | print(s)
192 | 193 | def __iter__(self): 194 | return iter(self._tokens) 195 | 196 | def __len__(self): 197 | return len(self._tokens) 198 | 199 | 200 | class Term: 201 | def __init__(self, eid: int, term_type: TermType, tokens: List[Token], phrase: str): 202 | self._eid = eid # ID within the corresponding dataset 203 | 204 | self._term_type = term_type 205 | 206 | self._tokens = tokens 207 | self._phrase = phrase 208 | 209 | def as_tuple(self): 210 | return self.span_start, self.span_end, self._term_type 211 | 212 | @property 213 | def term_type(self): 214 | return self._term_type 215 | 216 | @property 217 | def tokens(self): 218 | return TokenSpan(self._tokens) 219 | 220 | @property 221 | def span_start(self): 222 | return self._tokens[0].span_start 223 | 224 | @property 225 | def span_end(self): 226 | return self._tokens[-1].span_end 227 | 228 | @property 229 | def span(self): 230 | return self.span_start, self.span_end 231 | 232 | @property 233 | def phrase(self): 234 | return self._phrase 235 | 236 | def __eq__(self, other): 237 | if isinstance(other, Term): 238 | return self._eid == other._eid 239 | return False 240 | 241 | def __hash__(self): 242 | return hash(self._eid) 243 | 244 | def __str__(self): 245 | return self._phrase 246 | 247 | 248 | class Relation: 249 | def __init__(self, rid: int, relation_type: RelationType, head_term: Term, 250 | tail_term: Term, reverse: bool = False): 251 | self._rid = rid # ID within the corresponding dataset 252 | self._relation_type = relation_type 253 | 254 | self._head_term = head_term 255 | self._tail_term = tail_term 256 | 257 | self._reverse = reverse 258 | 259 | self._first_term = head_term if not reverse else tail_term 260 | self._second_term = tail_term if not reverse else head_term 261 | 262 | def as_tuple(self): 263 | head = self._head_term 264 | tail = self._tail_term 265 | head_start, head_end = (head.span_start, head.span_end) 266 | tail_start, tail_end = (tail.span_start, tail.span_end) 267 | 268 | t = ((head_start, head_end, head.term_type), 269 | (tail_start, tail_end, tail.term_type), self._relation_type) 270 | return t 271 | 272 | @property 273 | def relation_type(self): 274 | return self._relation_type 275 | 276 | @property 277 | def head_term(self): 278 | return self._head_term 279 | 280 | @property 281 | def tail_term(self): 282 | return self._tail_term 283 | 284 | @property 285 | def first_term(self): 286 | return self._first_term 287 | 288 | @property 289 | def second_term(self): 290 | return self._second_term 291 | 292 | @property 293 | def reverse(self): 294 | return self._reverse 295 | 296 | def __eq__(self, other): 297 | if isinstance(other, Relation): 298 | return self._rid == other._rid 299 | return False 300 | 301 | def __hash__(self): 302 | return hash(self._rid) 303 | 304 | 305 | class Document: 306 | def __init__(self, doc_id: int, tokens: List[Token], terms: List[Term], relations: List[Relation], polarities: List[Polarity], 307 | encoding: List[int], dep_label: List[int], dep_label_indices: List[int], dep: List[int], 308 | pos: List[str], pos_indices: List[int]): 309 | self._doc_id = doc_id # ID within the corresponding dataset 310 | 311 | self._tokens = tokens 312 | self._terms = terms 313 | self._relations = relations 314 | 315 | # byte-pair document encoding including special tokens ([CLS] and [SEP]) 316 | self._encoding = encoding 317 | 318 | self._dep_label = dep_label 319 | self._dep_label_indices = dep_label_indices 320 | self._dep = dep 321 | 322 | self._pos = pos 323 | self._pos_indices = pos_indices 324 | 325 | 
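        # assumption: the polarity labels passed to the constructor are meant to be
        # kept on the document (they are otherwise accepted but never stored)
        self._polarities = polarities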
@property 326 | def doc_id(self): 327 | return self._doc_id 328 | 329 | @property 330 | def terms(self): 331 | return self._terms 332 | 333 | @property 334 | def relations(self): 335 | return self._relations 336 | 337 | @property 338 | def tokens(self): 339 | return TokenSpan(self._tokens) 340 | 341 | @property 342 | def encoding(self): 343 | return self._encoding 344 | 345 | @property 346 | def dep_label(self): 347 | return self._dep_label 348 | 349 | @property 350 | def dep_label_indices(self): 351 | return self._dep_label_indices 352 | 353 | @property 354 | def dep(self): 355 | return self._dep 356 | 357 | @property 358 | def pos_indices(self): 359 | return self._pos_indices 360 | 361 | @property 362 | def pos(self): 363 | return self._pos 364 | 365 | @encoding.setter 366 | def encoding(self, value): 367 | self._encoding = value 368 | 369 | def __eq__(self, other): 370 | if isinstance(other, Document): 371 | return self._doc_id == other._doc_id 372 | return False 373 | 374 | def __hash__(self): 375 | return hash(self._doc_id) 376 | 377 | 378 | class BatchIterator: 379 | def __init__(self, terms, batch_size, order=None, truncate=False): 380 | self._terms = terms 381 | self._batch_size = batch_size 382 | self._truncate = truncate 383 | self._length = len(self._terms) 384 | self._order = order 385 | 386 | if order is None: 387 | self._order = list(range(len(self._terms))) 388 | 389 | self._i = 0 390 | 391 | def __iter__(self): 392 | return self 393 | 394 | def __next__(self): 395 | if self._truncate and self._i + self._batch_size > self._length: 396 | raise StopIteration 397 | elif not self._truncate and self._i >= self._length: 398 | raise StopIteration 399 | else: 400 | terms = [self._terms[n] for n in self._order[self._i:self._i + self._batch_size]] 401 | self._i += self._batch_size 402 | return terms 403 | 404 | 405 | class Dataset(TorchDataset): 406 | TRAIN_MODE = 'train' 407 | EVAL_MODE = 'eval' 408 | 409 | def __init__(self, label, rel_types, term_types, neg_term_count, 410 | neg_rel_count, max_span_size): 411 | self._label = label 412 | self._rel_types = rel_types 413 | self._term_types = term_types 414 | self._neg_term_count = neg_term_count 415 | self._neg_rel_count = neg_rel_count 416 | self._max_span_size = max_span_size 417 | self._mode = Dataset.TRAIN_MODE 418 | 419 | self._documents = OrderedDict() 420 | self._terms = OrderedDict() 421 | self._relations = OrderedDict() 422 | 423 | # current ids 424 | self._doc_id = 0 425 | self._rid = 0 426 | self._eid = 0 427 | self._tid = 0 428 | 429 | def iterate_documents(self, batch_size, order=None, truncate=False): 430 | return BatchIterator(self.documents, batch_size, order=order, truncate=truncate) 431 | 432 | def iterate_relations(self, batch_size, order=None, truncate=False): 433 | return BatchIterator(self.relations, batch_size, order=order, truncate=truncate) 434 | 435 | def create_token(self, idx, span_start, span_end, phrase) -> Token: 436 | token = Token(self._tid, idx, span_start, span_end, phrase) 437 | self._tid += 1 438 | return token 439 | 440 | def create_document(self, tokens, term_mentions, relations, polarities, doc_encoding, dep_label, dep_label_indices, dep, 441 | pos, pos_indices) -> Document: 442 | document = Document(self._doc_id, tokens, term_mentions, relations, polarities, doc_encoding, dep_label, 443 | dep_label_indices, dep, pos, pos_indices) 444 | self._documents[self._doc_id] = document 445 | self._doc_id += 1 446 | 447 | return document 448 | 449 | def create_term(self, term_type, tokens, phrase) -> 
Term: 450 | mention = Term(self._eid, term_type, tokens, phrase) 451 | self._terms[self._eid] = mention 452 | self._eid += 1 453 | return mention 454 | 455 | def create_relation(self, relation_type, head_term, tail_term, reverse=False) -> Relation: 456 | relation = Relation(self._rid, relation_type, head_term, tail_term, reverse) 457 | self._relations[self._rid] = relation 458 | self._rid += 1 459 | return relation 460 | 461 | def __len__(self): 462 | return len(self._documents) 463 | 464 | def __getitem__(self, index: int): 465 | doc = self._documents[index] 466 | 467 | if self._mode == Dataset.TRAIN_MODE: 468 | return sampling.create_train_sample(doc, self._neg_term_count, self._neg_rel_count, 469 | self._max_span_size, len(self._rel_types)) 470 | else: 471 | return sampling.create_eval_sample(doc, self._max_span_size) 472 | 473 | def switch_mode(self, mode): 474 | self._mode = mode 475 | 476 | @property 477 | def label(self): 478 | return self._label 479 | 480 | @property 481 | def input_reader(self): 482 | return self._input_reader 483 | 484 | @property 485 | def documents(self): 486 | return list(self._documents.values()) 487 | 488 | @property 489 | def terms(self): 490 | return list(self._terms.values()) 491 | 492 | @property 493 | def relations(self): 494 | return list(self._relations.values()) 495 | 496 | @property 497 | def document_count(self): 498 | return len(self._documents) 499 | 500 | @property 501 | def term_count(self): 502 | return len(self._terms) 503 | 504 | @property 505 | def relation_count(self): 506 | return len(self._relations) 507 | -------------------------------------------------------------------------------- /Engine/sampling.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from Engine import util 6 | 7 | 8 | def create_train_sample(doc, neg_term_count: int, neg_rel_count: int, max_span_size: int, rel_type_count: int): 9 | encodings = doc.encoding 10 | token_count = len(doc.tokens) 11 | context_size = len(encodings) 12 | 13 | # positive terms 14 | pos_term_spans, pos_term_types, pos_term_masks, pos_term_sizes = [], [], [], [] 15 | for e in doc.terms: 16 | pos_term_spans.append(e.span) 17 | pos_term_types.append(e.term_type.index) 18 | pos_term_masks.append(create_term_mask(*e.span, context_size)) 19 | pos_term_sizes.append(len(e.tokens)) 20 | 21 | # positive relations 22 | pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], [] 23 | pos_rels3, pos_rel_masks3, pos_rel_spans3 = [], [], [] 24 | pos_pair_mask = [] # which triplet rel is true 25 | for rel in doc.relations: 26 | s1, s2 = rel.head_term.span, rel.tail_term.span 27 | pos_rels.append((pos_term_spans.index(s1), pos_term_spans.index(s2))) 28 | pos_rel_spans.append((s1, s2)) 29 | pos_rel_types.append(rel.relation_type) 30 | pos_rel_masks.append(create_rel_mask(s1, s2, context_size)) 31 | 32 | def is_in_relation(head, tail, relations): 33 | for rel in relations: 34 | s1, s2 = rel.head_term, rel.tail_term 35 | if s1 == head and s2 == tail: 36 | return 1 37 | return 0 38 | 39 | for x in range(len(doc.relations)): 40 | s1, s2 = doc.relations[x].head_term, doc.relations[x].tail_term 41 | x1, x2 = pos_term_spans.index(s1.span), pos_term_spans.index(s2.span) 42 | t_p_rels3 = [] 43 | t_p_rels3_mask = [] 44 | 45 | t_p_rel_span3 = [] 46 | for idx, e in enumerate(doc.terms): 47 | if idx != x1 and idx != x2: 48 | if is_in_relation(s1, e, doc.relations) or is_in_relation(s2, e, doc.relations) or is_in_relation(e, s1, 
doc.relations) or is_in_relation(e, s2, doc.relations): 49 | t_p_rels3.append((x1, x2, idx)) 50 | t_p_rels3_mask.append(create_rel_mask3(s1.span, s2.span, e.span, context_size)) 51 | t_p_rel_span3.append((s1.span, s2.span, e.span)) 52 | # t_p_rel_types3.append(1) 53 | if len(t_p_rels3) > 0: 54 | pos_rels3.append(t_p_rels3) 55 | pos_rel_masks3.append(t_p_rels3_mask) 56 | pos_pair_mask.append(1) 57 | pos_rel_spans3.append(t_p_rel_span3) 58 | # pos_rel_types3.append(t_p_rel_types3) 59 | else: 60 | pos_rels3.append([(x1, x2, 0)]) 61 | pos_rel_masks3.append([(create_rel_mask3(s1.span, s2.span, (0, 0), context_size))]) 62 | pos_pair_mask.append(0) 63 | 64 | assert len(pos_rels) == len(pos_rels3) == len(pos_pair_mask) 65 | 66 | # negative terms 67 | neg_term_spans, neg_term_sizes = [], [] 68 | for size in range(1, max_span_size + 1): 69 | for i in range(0, (token_count - size) + 1): 70 | span = doc.tokens[i:i + size].span 71 | if span not in pos_term_spans: 72 | neg_term_spans.append(span) 73 | neg_term_sizes.append(size) 74 | 75 | # sample negative terms 76 | neg_term_samples = random.sample(list(zip(neg_term_spans, neg_term_sizes)), 77 | min(len(neg_term_spans), neg_term_count)) 78 | neg_term_spans, neg_term_sizes = zip(*neg_term_samples) if neg_term_samples else ([], []) 79 | 80 | neg_term_masks = [create_term_mask(*span, context_size) for span in neg_term_spans] 81 | neg_term_types = [0] * len(neg_term_spans) 82 | 83 | # negative relations 84 | # use only strong negative relations, i.e. pairs of actual (labeled) terms that are not related 85 | # neg_rels3 = [] 86 | neg_rel_spans = [] 87 | neg_rel_spans3 = [] 88 | neg_pair_mask = [] 89 | 90 | for i1, s1 in enumerate(pos_term_spans): 91 | for i2, s2 in enumerate(pos_term_spans): 92 | rev = (s2, s1) 93 | rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric 94 | 95 | # do not add as negative relation sample: 96 | # neg. 
relations from a term to itself 97 | # term pairs that are related according to gt 98 | # term pairs whose reverse exists as a symmetric relation in gt 99 | if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric: 100 | neg_rel_spans.append((s1, s2)) 101 | 102 | p_rel_span3 = [] 103 | for i3, s3 in enumerate(pos_term_spans): 104 | # the three spans must all differ and the triple must not already be a positive triple 105 | if s1 != s2 and s1 != s3 and s2 != s3 and not any((s1, s2, s3) in group for group in pos_rel_spans3): # pos_rel_spans3 groups triples per positive pair, so membership is checked group-wise 106 | p_rel_span3.append((s1, s2, s3)) 107 | if len(p_rel_span3) > 0: 108 | neg_rel_spans3.append(p_rel_span3) 109 | neg_pair_mask.append(1) 110 | else: 111 | neg_rel_spans3.append([(s1, s2, (0, 0))]) 112 | neg_pair_mask.append(0) 113 | 114 | # sample negative relations 115 | 116 | assert len(neg_rel_spans) == len(neg_rel_spans3) == len(neg_pair_mask) 117 | 118 | neg_rel_spans_samples = random.sample(list(zip(neg_rel_spans, neg_rel_spans3, neg_pair_mask)), min(len(neg_rel_spans), neg_rel_count)) 119 | neg_rel_spans, neg_rel_spans3, neg_pair_mask = zip(*neg_rel_spans_samples) if neg_rel_spans_samples else ([], [], []) 120 | 121 | neg_rels = [(pos_term_spans.index(s1), pos_term_spans.index(s2)) for s1, s2 in neg_rel_spans] 122 | neg_rels3 = [[(pos_term_spans.index(s1), pos_term_spans.index(s2), pos_term_spans.index(s3)) for s1, s2, s3 in x] for x in neg_rel_spans3] 123 | 124 | assert len(neg_rels3) == len(neg_rel_spans3) == len(neg_pair_mask) 125 | 126 | neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans] 127 | neg_rel_masks3 = [[create_rel_mask3(*sps, context_size) for sps in spans] for spans in neg_rel_spans3] 128 | neg_rel_types = [0] * len(neg_rel_spans) 129 | # neg_rel_types3 = [0] * len(neg_rel_spans3) 130 | 131 | # merge positive and negative samples 132 | term_types = pos_term_types + neg_term_types 133 | term_masks = pos_term_masks + neg_term_masks 134 | term_sizes = pos_term_sizes + list(neg_term_sizes) 135 | term_spans = pos_term_spans + list(neg_term_spans) 136 | 137 | rels = pos_rels + neg_rels 138 | rel_types = [r.index for r in pos_rel_types] + neg_rel_types 139 | rel_masks = pos_rel_masks + neg_rel_masks 140 | 141 | rels3 = pos_rels3 + neg_rels3 142 | # rel_types3 = pos_rel_types3 + neg_rel_types3 143 | rel_masks3 = pos_rel_masks3 + neg_rel_masks3 144 | pair_mask = pos_pair_mask + list(neg_pair_mask) 145 | 146 | assert len(term_masks) == len(term_sizes) == len(term_types) 147 | try: 148 | assert len(rels) == len(rel_masks) == len(rel_types) == len(rels3) == len(pair_mask) 149 | except AssertionError: # mismatched sample lists; dump the lengths for debugging 150 | print(len(rels)) 151 | print(len(rels3)) 152 | print(len(pair_mask))
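# A self-contained sketch of the strong-negative pairing used above (toy
# spans, not repo code): with gold terms A=(0, 2) and B=(3, 5) and a single
# gold pair (A, B), the only strong negative is the unrelated ordering (B, A):
import random
gold_spans = [(0, 2), (3, 5)]
gold_pairs = {((0, 2), (3, 5))}
neg_pairs = [(s1, s2) for s1 in gold_spans for s2 in gold_spans
             if s1 != s2 and (s1, s2) not in gold_pairs]
neg_pairs = random.sample(neg_pairs, min(len(neg_pairs), 100))  # cf. neg_rel_count
assert neg_pairs == [((3, 5), (0, 2))]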
153 | 154 | encodings = torch.tensor(encodings, dtype=torch.long) 155 | 156 | # masking of tokens 157 | context_masks = torch.ones(context_size, dtype=torch.bool) 158 | 159 | # also create samples_masks: 160 | # tensors to mask term/relation samples of batch 161 | # since samples are stacked into batches, "padding" terms/relations may have to be created 162 | # these are later masked during loss computation 163 | if term_masks: 164 | term_types = torch.tensor(term_types, dtype=torch.long) 165 | term_masks = torch.stack(term_masks) 166 | term_sizes = torch.tensor(term_sizes, dtype=torch.long) 167 | term_sample_masks = torch.ones([term_masks.shape[0]], dtype=torch.bool) 168 | term_spans = torch.tensor(term_spans, dtype=torch.long) 169 | else: 170 | # corner case handling (no pos/neg terms) 171 | term_types = torch.zeros([1], dtype=torch.long) 172 | term_masks = torch.zeros([1, context_size], dtype=torch.bool) 173 | term_sizes = torch.zeros([1], dtype=torch.long) 174 | term_sample_masks = torch.zeros([1], dtype=torch.bool) 175 | term_spans = torch.zeros([1, 2], dtype=torch.long) # shaped like the non-empty case ([num_spans, 2]) so batching does not break 176 | 177 | if rels: 178 | rels = torch.tensor(rels, dtype=torch.long) 179 | rel_masks = torch.stack(rel_masks) 180 | rel_types = torch.tensor(rel_types, dtype=torch.long) 181 | rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool) 182 | else: 183 | # corner case handling (no pos/neg relations) 184 | rels = torch.zeros([1, 2], dtype=torch.long) 185 | rel_types = torch.zeros([1], dtype=torch.long) 186 | rel_masks = torch.zeros([1, context_size], dtype=torch.bool) 187 | rel_sample_masks = torch.zeros([1], dtype=torch.bool) 188 | 189 | if rels3: 190 | max_tri = max([len(x) for x in rels3]) 191 | for idx, r in enumerate(rels3): 192 | r_len = len(r) 193 | if r_len < max_tri: 194 | rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len)) 195 | rel_masks3[idx].extend([rel_masks3[idx][0]] * (max_tri - r_len)) 196 | rels3 = torch.tensor(rels3, dtype=torch.long) 197 | try: 198 | rel_masks3 = torch.stack([torch.stack(x) for x in rel_masks3]) 199 | except Exception: # ragged mask lists; dump them for debugging 200 | print(rel_masks3) 201 | rel_sample_masks3 = torch.ones([rels3.shape[0]], dtype=torch.bool) 202 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool) 203 | else: 204 | rels3 = torch.zeros([1, 3], dtype=torch.long) 205 | rel_masks3 = torch.zeros([1, 1, context_size], dtype=torch.bool) # [num_pairs, num_triples, context] to match the stacked case 206 | rel_sample_masks3 = torch.zeros([1], dtype=torch.bool) 207 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool) 208 | 209 | # relation types to one-hot encoding 210 | rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32) 211 | rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1) 212 | rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation 213 | 214 | simple_graph = None 215 | graph = None 216 | try: 217 | simple_graph = torch.tensor(get_simple_graph(context_size, doc.dep), dtype=torch.long) # only the relation 218 | except Exception: # malformed dependency heads; dump the inputs for debugging 219 | print(context_size) 220 | print(token_count) 221 | print(encodings) 222 | print(doc.dep) 223 | print(doc.dep_label_indices) 224 | try: 225 | graph = torch.tensor(get_graph(context_size, doc.dep, doc.dep_label_indices), 226 | dtype=torch.long) # relation and the type of relation 227 | except Exception: # malformed dependency labels; dump the inputs for debugging 228 | print(context_size) 229 | print(token_count) 230 | print(encodings) 231 | print(doc.dep) 232 | print(doc.dep_label_indices) 233 | 234 | pos = torch.tensor(get_pos(context_size, doc.pos_indices), dtype=torch.long) 235 | 236 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks, 237 | term_sizes=term_sizes, term_types=term_types, term_spans=term_spans, 238 | rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot, 239 | rels3=rels3, rel_sample_masks3=rel_sample_masks3, rel_masks3=rel_masks3, 240 | pair_mask=pair_mask, 241 | term_sample_masks=term_sample_masks, rel_sample_masks=rel_sample_masks, 242 | simple_graph=simple_graph, graph=graph, pos=pos) 243 | 244 | 245 | def create_eval_sample(doc, max_span_size: int): 246 | encodings = doc.encoding 247 | token_count = len(doc.tokens) 248 | context_size = len(encodings) 249 | 250 | # create term candidates 251 | term_spans = [] 252 | term_masks = [] 253 | term_sizes = [] 254 | 255 | for size in range(1, max_span_size + 1): 256 | for i in range(0, (token_count - size) + 1): 257 | span = doc.tokens[i:i + size].span 258 | term_spans.append(span) 259 | term_masks.append(create_term_mask(*span, context_size)) 260 | 
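# Worked example for the candidate enumeration above: with token_count = 5 and
# max_span_size = 3 the two loops emit (5-1+1) + (5-2+1) + (5-3+1) = 5 + 4 + 3
# = 12 spans, i.e. the sum over sizes s of (token_count - s + 1). Unlike
# training, nothing is sampled away here: every candidate span is scored and
# the term classifier decides which ones survive.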
term_sizes.append(size) 261 | 262 | # create tensors 263 | # token indices 264 | _encoding = encodings 265 | encodings = torch.zeros(context_size, dtype=torch.long) 266 | encodings[:len(_encoding)] = torch.tensor(_encoding, dtype=torch.long) 267 | 268 | # masking of tokens 269 | context_masks = torch.zeros(context_size, dtype=torch.bool) 270 | context_masks[:len(_encoding)] = 1 271 | 272 | # terms 273 | if term_masks: 274 | term_masks = torch.stack(term_masks) 275 | term_sizes = torch.tensor(term_sizes, dtype=torch.long) 276 | term_spans = torch.tensor(term_spans, dtype=torch.long) 277 | 278 | # tensors to mask term samples of batch 279 | # since samples are stacked into batches, "padding" terms possibly must be created 280 | # these are later masked during evaluation 281 | term_sample_masks = torch.tensor([1] * term_masks.shape[0], dtype=torch.bool) 282 | else: 283 | # corner case handling (no terms) 284 | term_masks = torch.zeros([1, context_size], dtype=torch.bool) 285 | term_sizes = torch.zeros([1], dtype=torch.long) 286 | term_spans = torch.zeros([1, 2], dtype=torch.long) 287 | term_sample_masks = torch.zeros([1], dtype=torch.bool) 288 | 289 | simple_graph = torch.tensor(get_simple_graph(context_size, doc.dep), dtype=torch.long) # only the relation 290 | graph = torch.tensor(get_graph(context_size, doc.dep, doc.dep_label_indices), 291 | dtype=torch.long) # relation and the type of relation 292 | pos = torch.tensor(get_pos(context_size, doc.pos_indices), dtype=torch.long) 293 | 294 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks, 295 | term_sizes=term_sizes, term_spans=term_spans, term_sample_masks=term_sample_masks, 296 | simple_graph=simple_graph, graph=graph, pos=pos) 297 | 298 | 299 | def create_term_mask(start, end, context_size): 300 | mask = torch.zeros(context_size, dtype=torch.bool) 301 | mask[start:end] = 1 302 | return mask 303 | 304 | 305 | def create_rel_mask(s1, s2, context_size): 306 | start = s1[1] if s1[1] < s2[0] else s2[1] 307 | end = s2[0] if s1[1] < s2[0] else s1[0] 308 | mask = create_term_mask(start, end, context_size) 309 | return mask 310 | 311 | 312 | def create_rel_mask3(s1, s2, s3, context_size): 313 | mask = torch.zeros(context_size, dtype=torch.bool) 314 | start = min(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1]) 315 | end = max(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1]) 316 | mask[start:end] = 1 317 | return mask 318 | 319 | 320 | def collate_fn_padding(batch): 321 | padded_batch = dict() 322 | keys = batch[0].keys() 323 | 324 | for key in keys: 325 | samples = [s[key] for s in batch] 326 | if not batch[0][key].shape: 327 | padded_batch[key] = torch.stack(samples) 328 | else: 329 | padded_batch[key] = util.padded_stack([s[key] for s in batch]) 330 | 331 | return padded_batch 332 | 333 | 334 | def get_graph(seq_len, feature_data, feature2id): 335 | ret = [[0] * seq_len for _ in range(seq_len)] 336 | for i, item in enumerate(feature_data): 337 | if int(item) > seq_len-1 or int(item) == 0: 338 | continue 339 | ret[i + 1][int(item) - 1] = feature2id[i] + 2 340 | ret[int(item) - 1][i + 1] = feature2id[i] + 2 341 | ret[i + 1][i + 1] = 1 342 | return ret 343 | 344 | 345 | def get_simple_graph(seq_len, feature_data): 346 | ret = [[0] * seq_len for _ in range(seq_len)] 347 | for i, item in enumerate(feature_data): 348 | if int(item) > seq_len-1: 349 | continue 350 | ret[i + 1][int(item) - 1] = 1 351 | ret[int(item) - 1][i + 1] = 1 352 | ret[i + 1][i + 1] = 1 353 | return ret 354 | 355 | 356 | def get_pos(seq_len, pos_indices): 357 
| ret = [0] * seq_len 358 | for i, item in enumerate(pos_indices): 359 | ret[i + 1] = pos_indices[i] + 1 360 | return ret -------------------------------------------------------------------------------- /Engine/evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | from typing import List, Tuple, Dict 5 | 6 | import torch 7 | from sklearn.metrics import precision_recall_fscore_support as prfs 8 | from transformers import RobertaTokenizer 9 | 10 | from Engine import util 11 | from Engine.terms import Document, Dataset, TermType 12 | from Engine.input_reader import JsonInputReader 13 | from Engine.opt import jinja2 14 | 15 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, dataset: Dataset, input_reader: JsonInputReader, text_encoder: RobertaTokenizer, 20 | rel_filter_threshold: float, no_overlapping: bool, 21 | predictions_path: str, examples_path: str, example_count: int, epoch: int, dataset_label: str): 22 | self._text_encoder = text_encoder 23 | self._input_reader = input_reader 24 | self._dataset = dataset 25 | self._rel_filter_threshold = rel_filter_threshold 26 | self._no_overlapping = no_overlapping 27 | 28 | self._epoch = epoch 29 | self._dataset_label = dataset_label 30 | 31 | self._predictions_path = predictions_path 32 | 33 | self._examples_path = examples_path 34 | self._example_count = example_count 35 | 36 | # relations 37 | self._gt_relations = [] # ground truth 38 | self._pred_relations = [] # prediction 39 | 40 | # terms 41 | self._gt_terms = [] # ground truth 42 | self._pred_terms = [] # prediction 43 | 44 | self._pseudo_term_type = TermType('Term', 1, 'Term', 'Term') # for span-only evaluation 45 | 46 | self._convert_gt(self._dataset.documents) 47 | 48 | def eval_batch(self, AE_clf, OE_clf, AOE_clf, AOPE_clf, ALSC_clf, AESC_clf, TE_clf, batch: dict): 49 | batch_term_clf, batch_rel_clf, batch_rels = TE_clf; batch_size = batch_rel_clf.shape[0] # the TE output is assumed to pack (term_clf, rel_clf, rels); cf. the 'term_clf, rel_clf, rels = result' note in trainer.py 50 | rel_class_count = batch_rel_clf.shape[2] 51 | # get maximum activation (index of predicted term type) 52 | batch_term_types = batch_term_clf.argmax(dim=-1) 53 | # apply term sample mask 54 | batch_term_types *= batch['term_sample_masks'].long() 55 | 56 | batch_rel_clf = batch_rel_clf.view(batch_size, -1) 57 | 58 | # apply threshold to relations 59 | if self._rel_filter_threshold > 0: 60 | batch_rel_clf[batch_rel_clf < self._rel_filter_threshold] = 0 61 | 62 | for i in range(batch_size): 63 | # get model predictions for sample 64 | rel_clf = batch_rel_clf[i] 65 | term_types = batch_term_types[i] 66 | 67 | # get predicted relation labels and corresponding term pairs 68 | rel_nonzero = rel_clf.nonzero().view(-1) 69 | rel_scores = rel_clf[rel_nonzero] 70 | 71 | rel_types = (rel_nonzero % rel_class_count) + 1 # model does not predict None class (+1) 72 | rel_indices = rel_nonzero // rel_class_count 73 | 74 | rels = batch_rels[i][rel_indices] 75 | 76 | # get masks of terms in relation 77 | rel_term_spans = batch['term_spans'][i][rels].long() 78 | 79 | # get predicted term types 80 | rel_term_types = torch.zeros([rels.shape[0], 2]) 81 | if rels.shape[0] != 0: 82 | rel_term_types = torch.stack([term_types[rels[j]] for j in range(rels.shape[0])]) 83 | 84 | # convert predicted relations for evaluation 85 | sample_pred_relations = self._convert_pred_relations(rel_types, rel_term_spans, 86 | rel_term_types, rel_scores) 87 | 88 | # get terms that are not classified as 'None' 89 | valid_term_indices = term_types.nonzero().view(-1)
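# A minimal sketch (toy tensors, not repo code) of the threshold-and-decode
# arithmetic above: once sub-threshold scores are zeroed, each surviving flat
# index encodes both the candidate pair and the relation label:
import torch
rel_clf = torch.tensor([[0.0, 0.7], [0.0, 0.0], [0.9, 0.0]]).view(-1)  # 3 pairs x 2 classes
nonzero = rel_clf.nonzero().view(-1)  # -> tensor([1, 4])
labels = (nonzero % 2) + 1            # +1 because class 0 is the implicit 'none'
pairs = nonzero // 2                  # which candidate pair each label belongs to
assert labels.tolist() == [2, 1] and pairs.tolist() == [0, 2]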
90 | valid_term_types = term_types[valid_term_indices] 91 | valid_term_spans = batch['term_spans'][i][valid_term_indices] 92 | valid_term_scores = torch.gather(batch_term_clf[i][valid_term_indices], 1, 93 | valid_term_types.unsqueeze(1)).view(-1) 94 | sample_pred_terms = self._convert_pred_terms(valid_term_types, valid_term_spans, 95 | valid_term_scores) 96 | 97 | if self._no_overlapping: 98 | sample_pred_terms, sample_pred_relations = self._remove_overlapping(sample_pred_terms, 99 | sample_pred_relations) 100 | 101 | self._pred_terms.append(sample_pred_terms) 102 | self._pred_relations.append(sample_pred_relations) 103 | 104 | def compute_scores(self): 105 | print("Evaluation") 106 | 107 | print("") 108 | print("--- Terms (named term recognition, NTR) ---") 109 | print("A term is considered correct if the term type and span are predicted correctly") 110 | print("") 111 | gt, pred = self._convert_by_setting(self._gt_terms, self._pred_terms, include_term_types=True) 112 | ner_eval = self._score(gt, pred, print_results=True) 113 | 114 | print("") 115 | print("Without named term classification (NTC)") 116 | print("A relation is considered correct if the relation type and the spans of the two " 117 | "related terms are predicted correctly (term type is not considered)") 118 | print("") 119 | gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=False) 120 | rel_eval = self._score(gt, pred, print_results=True) 121 | 122 | print("") 123 | print("With named term classification (NTC)") 124 | print("A relation is considered correct if the relation type and the two " 125 | "related terms are predicted correctly (in span and term type)") 126 | print("") 127 | gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=True) 128 | rel_nec_eval = self._score(gt, pred, print_results=True) 129 | 130 | return ner_eval, rel_eval, rel_nec_eval 131 | 132 | def store_predictions(self): 133 | predictions = [] 134 | 135 | for i, doc in enumerate(self._dataset.documents): 136 | tokens = doc.tokens 137 | pred_terms = self._pred_terms[i] 138 | pred_relations = self._pred_relations[i] 139 | 140 | # convert terms 141 | converted_terms = [] 142 | for term in pred_terms: 143 | term_span = term[:2] 144 | span_tokens = util.get_span_tokens(tokens, term_span) 145 | term_type = term[2].identifier 146 | # if term_type == 'None': 147 | # continue 148 | converted_term = dict(type=term_type, start=span_tokens[0].index, end=span_tokens[-1].index + 1) 149 | converted_terms.append(converted_term) 150 | converted_terms = sorted(converted_terms, key=lambda e: e['start']) 151 | 152 | # print('converted_terms: ', converted_terms) 153 | 154 | # convert relations 155 | converted_relations = [] 156 | for relation in pred_relations: 157 | head, tail = relation[:2] 158 | head_span, head_type = head[:2], head[2].identifier 159 | tail_span, tail_type = tail[:2], tail[2].identifier 160 | head_span_tokens = util.get_span_tokens(tokens, head_span) 161 | tail_span_tokens = util.get_span_tokens(tokens, tail_span) 162 | relation_type = relation[2].identifier 163 | 164 | converted_head = dict(type=head_type, start=head_span_tokens[0].index, 165 | end=head_span_tokens[-1].index + 1) 166 | converted_tail = dict(type=tail_type, start=tail_span_tokens[0].index, 167 | end=tail_span_tokens[-1].index + 1) 168 | 169 | # print(converted_tail) 170 | head_idx = converted_terms.index(converted_head) 171 | tail_idx = converted_terms.index(converted_tail) 172 | 173 | converted_relation 
= dict(type=relation_type, head=head_idx, tail=tail_idx) 174 | converted_relations.append(converted_relation) 175 | converted_relations = sorted(converted_relations, key=lambda r: r['head']) 176 | 177 | doc_predictions = dict(tokens=[t.phrase for t in tokens], terms=converted_terms, 178 | relations=converted_relations) 179 | predictions.append(doc_predictions) 180 | 181 | # store as json 182 | label, epoch = self._dataset_label, self._epoch 183 | with open(self._predictions_path % (label, epoch), 'w') as predictions_file: 184 | json.dump(predictions, predictions_file) 185 | 186 | def store_examples(self): 187 | if jinja2 is None: 188 | warnings.warn("Examples cannot be stored since Jinja2 is not installed.") 189 | return 190 | 191 | term_examples = [] 192 | rel_examples = [] 193 | rel_examples_nec = [] 194 | 195 | for i, doc in enumerate(self._dataset.documents): 196 | # terms 197 | term_example = self._convert_example(doc, self._gt_terms[i], self._pred_terms[i], 198 | include_term_types=True, to_html=self._term_to_html) 199 | term_examples.append(term_example) 200 | 201 | # relations 202 | # without term types 203 | rel_example = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i], 204 | include_term_types=False, to_html=self._rel_to_html) 205 | rel_examples.append(rel_example) 206 | 207 | # with term types 208 | rel_example_nec = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i], 209 | include_term_types=True, to_html=self._rel_to_html) 210 | rel_examples_nec.append(rel_example_nec) 211 | 212 | label, epoch = self._dataset_label, self._epoch 213 | 214 | # terms 215 | self._store_examples(term_examples[:self._example_count], 216 | file_path=self._examples_path % ('terms', label, epoch), 217 | template='term_examples.html') 218 | 219 | self._store_examples(sorted(term_examples[:self._example_count], 220 | key=lambda k: k['length']), 221 | file_path=self._examples_path % ('terms_sorted', label, epoch), 222 | template='term_examples.html') 223 | 224 | # without term types 225 | self._store_examples(rel_examples[:self._example_count], 226 | file_path=self._examples_path % ('rel', label, epoch), 227 | template='relation_examples.html') 228 | 229 | self._store_examples(sorted(rel_examples[:self._example_count], 230 | key=lambda k: k['length']), 231 | file_path=self._examples_path % ('rel_sorted', label, epoch), 232 | template='relation_examples.html') 233 | 234 | # with term types 235 | self._store_examples(rel_examples_nec[:self._example_count], 236 | file_path=self._examples_path % ('rel_nec', label, epoch), 237 | template='relation_examples.html') 238 | 239 | self._store_examples(sorted(rel_examples_nec[:self._example_count], 240 | key=lambda k: k['length']), 241 | file_path=self._examples_path % ('rel_nec_sorted', label, epoch), 242 | template='relation_examples.html') 243 | 244 | def _convert_gt(self, docs: List[Document]): 245 | for doc in docs: 246 | gt_relations = doc.relations 247 | gt_terms = doc.terms 248 | 249 | # convert ground truth relations and terms for precision/recall/f1 evaluation 250 | sample_gt_terms = [term.as_tuple() for term in gt_terms] 251 | sample_gt_relations = [rel.as_tuple() for rel in gt_relations] 252 | 253 | if self._no_overlapping: 254 | sample_gt_terms, sample_gt_relations = self._remove_overlapping(sample_gt_terms, 255 | sample_gt_relations) 256 | 257 | self._gt_terms.append(sample_gt_terms) 258 | self._gt_relations.append(sample_gt_relations) 259 | 260 | def _convert_pred_terms(self, pred_types: torch.tensor, 
pred_spans: torch.tensor, pred_scores: torch.tensor): 261 | converted_preds = [] 262 | 263 | for i in range(pred_types.shape[0]): 264 | label_idx = pred_types[i].item() 265 | term_type = self._input_reader.get_term_type(label_idx) 266 | 267 | start, end = pred_spans[i].tolist() 268 | score = pred_scores[i].item() 269 | 270 | converted_pred = (start, end, term_type, score) 271 | converted_preds.append(converted_pred) 272 | 273 | return converted_preds 274 | 275 | def _convert_pred_relations(self, pred_rel_types: torch.tensor, pred_term_spans: torch.tensor, 276 | pred_term_types: torch.tensor, pred_scores: torch.tensor): 277 | converted_rels = [] 278 | check = set() 279 | 280 | for i in range(pred_rel_types.shape[0]): 281 | label_idx = pred_rel_types[i].item() 282 | pred_rel_type = self._input_reader.get_relation_type(label_idx) 283 | pred_head_type_idx, pred_tail_type_idx = pred_term_types[i][0].item(), pred_term_types[i][1].item() 284 | pred_head_type = self._input_reader.get_term_type(pred_head_type_idx) 285 | pred_tail_type = self._input_reader.get_term_type(pred_tail_type_idx) 286 | score = pred_scores[i].item() 287 | 288 | spans = pred_term_spans[i] 289 | head_start, head_end = spans[0].tolist() 290 | tail_start, tail_end = spans[1].tolist() 291 | 292 | converted_rel = ((head_start, head_end, pred_head_type), 293 | (tail_start, tail_end, pred_tail_type), pred_rel_type) 294 | converted_rel = self._adjust_rel(converted_rel) 295 | 296 | if converted_rel not in check: 297 | check.add(converted_rel) 298 | converted_rels.append(tuple(list(converted_rel) + [score])) 299 | 300 | return converted_rels 301 | 302 | def _remove_overlapping(self, terms, relations): 303 | non_overlapping_terms = [] 304 | non_overlapping_relations = [] 305 | 306 | for term in terms: 307 | if not self._is_overlapping(term, terms): 308 | non_overlapping_terms.append(term) 309 | 310 | for rel in relations: 311 | e1, e2 = rel[0], rel[1] 312 | if not self._check_overlap(e1, e2): 313 | non_overlapping_relations.append(rel) 314 | 315 | return non_overlapping_terms, non_overlapping_relations 316 | 317 | def _is_overlapping(self, e1, terms): 318 | for e2 in terms: 319 | if self._check_overlap(e1, e2): 320 | return True 321 | 322 | return False 323 | 324 | def _check_overlap(self, e1, e2): 325 | if e1 == e2 or e1[1] <= e2[0] or e2[1] <= e1[0]: 326 | return False 327 | else: 328 | return True 329 | 330 | def _adjust_rel(self, rel: Tuple): 331 | adjusted_rel = rel 332 | if rel[-1].symmetric: 333 | head, tail = rel[:2] 334 | if tail[0] < head[0]: 335 | adjusted_rel = tail, head, rel[-1] 336 | 337 | return adjusted_rel 338 | 339 | def _convert_by_setting(self, gt: List[List[Tuple]], pred: List[List[Tuple]], 340 | include_term_types: bool = True, include_score: bool = False): 341 | assert len(gt) == len(pred) 342 | 343 | # either include or remove term types based on setting 344 | def convert(t): 345 | if not include_term_types: 346 | # remove term type and score for evaluation 347 | if type(t[0]) == int: # term 348 | c = [t[0], t[1], self._pseudo_term_type] 349 | else: # relation 350 | c = [(t[0][0], t[0][1], self._pseudo_term_type), 351 | (t[1][0], t[1][1], self._pseudo_term_type), t[2]] 352 | else: 353 | c = list(t[:3]) 354 | 355 | if include_score and len(t) > 3: 356 | # include prediction scores 357 | c.append(t[3]) 358 | 359 | return tuple(c) 360 | 361 | converted_gt, converted_pred = [], [] 362 | 363 | for sample_gt, sample_pred in zip(gt, pred): 364 | converted_gt.append([convert(t) for t in sample_gt]) 365 | 
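# e.g. with include_term_types=False a relation tuple such as
# ((0, 2, ASPECT), (3, 5, OPINION), POS) collapses to
# ((0, 2, Term), (3, 5, Term), POS) via the pseudo term type above, so the
# span-only setting neither credits nor penalizes term typing.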
converted_pred.append([convert(t) for t in sample_pred]) 366 | 367 | return converted_gt, converted_pred 368 | 369 | def _score(self, gt: List[List[Tuple]], pred: List[List[Tuple]], print_results: bool = False): 370 | assert len(gt) == len(pred) 371 | 372 | gt_flat = [] 373 | pred_flat = [] 374 | types = set() 375 | 376 | for (sample_gt, sample_pred) in zip(gt, pred): 377 | union = set() 378 | union.update(sample_gt) 379 | union.update(sample_pred) 380 | 381 | for s in union: 382 | if s in sample_gt: 383 | t = s[2] 384 | gt_flat.append(t.index) 385 | types.add(t) 386 | else: 387 | gt_flat.append(0) 388 | 389 | if s in sample_pred: 390 | t = s[2] 391 | pred_flat.append(t.index) 392 | types.add(t) 393 | else: 394 | pred_flat.append(0) 395 | 396 | metrics = self._compute_metrics(gt_flat, pred_flat, types, print_results) 397 | return metrics 398 | 399 | def _compute_metrics(self, gt_all, pred_all, types, print_results: bool = False): 400 | labels = [t.index for t in types] 401 | per_type = prfs(gt_all, pred_all, labels=labels, average=None) 402 | micro = prfs(gt_all, pred_all, labels=labels, average='micro')[:-1] 403 | macro = prfs(gt_all, pred_all, labels=labels, average='macro')[:-1] 404 | total_support = sum(per_type[-1]) 405 | 406 | if print_results: 407 | self._print_results(per_type, list(micro) + [total_support], list(macro) + [total_support], types) 408 | 409 | return [m * 100 for m in micro + macro] 410 | 411 | def _print_results(self, per_type: List, micro: List, macro: List, types: List): 412 | columns = ('type', 'precision', 'recall', 'f1-score', 'support') 413 | 414 | row_fmt = "%20s" + (" %12s" * (len(columns) - 1)) 415 | results = [row_fmt % columns, '\n'] 416 | 417 | metrics_per_type = [] 418 | for i, t in enumerate(types): 419 | metrics = [] 420 | for j in range(len(per_type)): 421 | metrics.append(per_type[j][i]) 422 | metrics_per_type.append(metrics) 423 | 424 | for m, t in zip(metrics_per_type, types): 425 | results.append(row_fmt % self._get_row(m, t.short_name)) 426 | results.append('\n') 427 | 428 | results.append('\n') 429 | 430 | # micro 431 | results.append(row_fmt % self._get_row(micro, 'micro')) 432 | results.append('\n') 433 | 434 | # macro 435 | results.append(row_fmt % self._get_row(macro, 'macro')) 436 | 437 | results_str = ''.join(results) 438 | print(results_str) 439 | 440 | def _get_row(self, data, label): 441 | row = [label] 442 | for i in range(len(data) - 1): 443 | row.append("%.2f" % (data[i] * 100)) 444 | row.append(data[3]) 445 | return tuple(row) 446 | 447 | def _convert_example(self, doc: Document, gt: List[Tuple], pred: List[Tuple], 448 | include_term_types: bool, to_html): 449 | encoding = doc.encoding 450 | 451 | gt, pred = self._convert_by_setting([gt], [pred], include_term_types=include_term_types, include_score=True) 452 | gt, pred = gt[0], pred[0] 453 | 454 | # get micro precision/recall/f1 scores 455 | if gt or pred: 456 | pred_s = [p[:3] for p in pred] # remove score 457 | precision, recall, f1 = self._score([gt], [pred_s])[:3] 458 | else: 459 | # corner case: no ground truth and no predictions 460 | precision, recall, f1 = [100] * 3 461 | 462 | scores = [p[-1] for p in pred] 463 | pred = [p[:-1] for p in pred] 464 | union = set(gt + pred) 465 | 466 | # true positives 467 | tp = [] 468 | # false negatives 469 | fn = [] 470 | # false positives 471 | fp = [] 472 | 473 | for s in union: 474 | type_verbose = s[2].verbose_name 475 | 476 | if s in gt: 477 | if s in pred: 478 | score = scores[pred.index(s)] 479 | tp.append((to_html(s, encoding), 
type_verbose, score)) 480 | else: 481 | fn.append((to_html(s, encoding), type_verbose, -1)) 482 | else: 483 | score = scores[pred.index(s)] 484 | fp.append((to_html(s, encoding), type_verbose, score)) 485 | 486 | tp = sorted(tp, key=lambda p: p[-1], reverse=True) 487 | fp = sorted(fp, key=lambda p: p[-1], reverse=True) 488 | 489 | text = self._prettify(self._text_encoder.decode(encoding)) 490 | return dict(text=text, tp=tp, fn=fn, fp=fp, precision=precision, recall=recall, f1=f1, length=len(doc.tokens)) 491 | 492 | def _term_to_html(self, term: Tuple, encoding: List[int]): 493 | start, end = term[:2] 494 | term_type = term[2].verbose_name 495 | 496 | tag_start = ' ' 497 | tag_start += '%s' % term_type 498 | 499 | ctx_before = self._text_encoder.decode(encoding[:start]) 500 | e1 = self._text_encoder.decode(encoding[start:end]) 501 | ctx_after = self._text_encoder.decode(encoding[end:]) 502 | 503 | html = ctx_before + tag_start + e1 + ' ' + ctx_after 504 | html = self._prettify(html) 505 | 506 | return html 507 | 508 | def _rel_to_html(self, relation: Tuple, encoding: List[int]): 509 | head, tail = relation[:2] 510 | head_tag = ' %s' 511 | tail_tag = ' %s' 512 | 513 | if head[0] < tail[0]: 514 | e1, e2 = head, tail 515 | e1_tag, e2_tag = head_tag % head[2].verbose_name, tail_tag % tail[2].verbose_name 516 | else: 517 | e1, e2 = tail, head 518 | e1_tag, e2_tag = tail_tag % tail[2].verbose_name, head_tag % head[2].verbose_name 519 | 520 | segments = [encoding[:e1[0]], encoding[e1[0]:e1[1]], encoding[e1[1]:e2[0]], 521 | encoding[e2[0]:e2[1]], encoding[e2[1]:]] 522 | 523 | ctx_before = self._text_encoder.decode(segments[0]) 524 | e1 = self._text_encoder.decode(segments[1]) 525 | ctx_between = self._text_encoder.decode(segments[2]) 526 | e2 = self._text_encoder.decode(segments[3]) 527 | ctx_after = self._text_encoder.decode(segments[4]) 528 | 529 | html = (ctx_before + e1_tag + e1 + ' ' 530 | + ctx_between + e2_tag + e2 + ' ' + ctx_after) 531 | html = self._prettify(html) 532 | 533 | return html 534 | 535 | def _prettify(self, text: str): 536 | text = text.replace('_start_', '').replace('_classify_', '').replace('', '').replace('⁇', '') 537 | text = text.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '') 538 | return text 539 | 540 | def _store_examples(self, examples: List[Dict], file_path: str, template: str): 541 | template_path = os.path.join(SCRIPT_PATH, 'templates', template) 542 | 543 | # read template 544 | with open(os.path.join(SCRIPT_PATH, template_path)) as f: 545 | template = jinja2.Template(f.read()) 546 | 547 | # write to disc 548 | template.stream(examples=examples).dump(file_path) 549 | -------------------------------------------------------------------------------- /Engine/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import torch 6 | from torch.nn import DataParallel 7 | from torch.optim import Optimizer 8 | import transformers 9 | from torch.utils.data import DataLoader 10 | from transformers import AdamW, BertConfig 11 | from transformers import BertTokenizer 12 | 13 | from Engine import models 14 | from Engine import sampling 15 | from Engine import util 16 | from Engine.terms import Dataset 17 | from Engine.evaluator import Evaluator 18 | from Engine.input_reader import JsonInputReader, BaseInputReader 19 | from Engine.loss import SyMuxLoss, Loss 20 | from tqdm import tqdm 21 | from Engine.base_trainer import BaseTrainer 22 | 23 | SCRIPT_PATH = 
os.path.dirname(os.path.realpath(__file__)) 24 | 25 | 26 | class SyMuxTrainer(BaseTrainer): 27 | """ Joint term and relation extraction training and evaluation """ 28 | 29 | def __init__(self, args: argparse.Namespace): 30 | super().__init__(args) 31 | 32 | # byte-pair encoding 33 | self._tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path, 34 | do_lower_case=args.lowercase, 35 | cache_dir=args.cache_path) 36 | 37 | # path to export predictions to 38 | self._predictions_path = os.path.join(self._log_path, 'predictions_%s_epoch_%s.json') 39 | 40 | # path to export relation extraction examples to 41 | self._examples_path = os.path.join(self._log_path, 'examples_%s_%s_epoch_%s.html') 42 | 43 | def train(self, train_path: str, valid_path: str, types_path: str, input_reader_cls: BaseInputReader): 44 | args = self.args 45 | train_label, valid_label = 'train', 'valid' 46 | 47 | self._logger.info("Datasets: %s, %s" % (train_path, valid_path)) 48 | self._logger.info("Model type: %s" % args.model_type) 49 | 50 | # create log csv files 51 | self._init_train_logging(train_label) 52 | self._init_eval_logging(valid_label) 53 | 54 | # read datasets 55 | input_reader = input_reader_cls(types_path, self._tokenizer, args.neg_term_count, 56 | args.neg_relation_count, args.max_span_size, self._logger) 57 | input_reader.read({train_label: train_path, valid_label: valid_path}) 58 | self._log_datasets(input_reader) 59 | 60 | train_dataset = input_reader.get_dataset(train_label) 61 | train_sample_count = train_dataset.document_count 62 | updates_epoch = train_sample_count // args.train_batch_size 63 | updates_total = updates_epoch * args.epochs 64 | 65 | validation_dataset = input_reader.get_dataset(valid_label) 66 | 67 | self._logger.info("Updates per epoch: %s" % updates_epoch) 68 | self._logger.info("Updates total: %s" % updates_total) 69 | 70 | # create model 71 | model_class = models.get_model(self.args.model_type) 72 | 73 | # load model 74 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path) 75 | util.check_version(config, model_class, self.args.model_path) 76 | 77 | config.model_version = model_class.VERSION 78 | model = model_class.from_pretrained(self.args.model_path, 79 | config=config, 80 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'), 81 | relation_types=input_reader.relation_type_count - 1, 82 | term_types=input_reader.term_type_count, 83 | max_pairs=self.args.max_pairs, 84 | prop_drop=self.args.prop_drop, 85 | size_embedding=self.args.size_embedding, 86 | freeze_transformer=self.args.freeze_transformer, 87 | args=self.args, 88 | beta=self.args.beta, 89 | alpha=self.args.alpha, 90 | sigma=self.args.sigma) 91 | 92 | model.to(self._device) 93 | 94 | # create optimizer 95 | optimizer_params = self._get_optimizer_params(model) 96 | optimizer = AdamW(optimizer_params, lr=args.lr, weight_decay=args.weight_decay, correct_bias=False) 97 | # create scheduler 98 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 99 | num_warmup_steps=args.lr_warmup * updates_total, 100 | num_training_steps=updates_total) 101 | 102 | # "AE", 103 | # "ALSC", 104 | # "AESC", 105 | # "OE", 106 | # "AOE", 107 | # "AOPE", 108 | # "TE" 109 | 110 | # create loss function 111 | pol_criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 112 | term_criterion = torch.nn.CrossEntropyLoss(reduction='none') 113 | compute_loss_AE = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 114 | compute_loss_OE = 
SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 115 | compute_loss_AOE = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 116 | compute_loss_AOPE = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 117 | compute_loss_ALSC = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 118 | compute_loss_AESC = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 119 | compute_loss_TE = SyMuxLoss(pol_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 120 | 121 | # eval validation set 122 | if args.init_eval: 123 | self._eval(model, validation_dataset, input_reader, 0, updates_epoch) 124 | 125 | # train 126 | best_f1 = 0.0 127 | for epoch in range(args.epochs): 128 | # train epoch 129 | self._train_epoch(model, compute_loss_AE, compute_loss_OE, compute_loss_AOE, 130 | compute_loss_AOPE, compute_loss_ALSC, compute_loss_AESC, compute_loss_TE, 131 | optimizer, train_dataset, updates_epoch, epoch) 132 | 133 | # eval validation sets 134 | if not args.final_eval or (epoch == args.epochs - 1): 135 | rel_nec_eval = self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch) 136 | if best_f1 < rel_nec_eval[-1]: 137 | # save final model 138 | best_f1 = rel_nec_eval[-1] 139 | extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0) 140 | global_iteration = args.epochs * updates_epoch 141 | self._save_model(self._save_path, model, self._tokenizer, global_iteration, 142 | optimizer=optimizer if self.args.save_optimizer else None, save_as_best=True, 143 | extra=extra, include_iteration=False) 144 | 145 | self._logger.info("Logged in: %s" % self._log_path) 146 | self._logger.info("Saved in: %s" % self._save_path) 147 | self._close_summary_writer() 148 | 149 | def eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader): 150 | args = self.args 151 | dataset_label = 'test' 152 | 153 | self._logger.info("Dataset: %s" % dataset_path) 154 | self._logger.info("Model: %s" % args.model_type) 155 | 156 | # create log csv files 157 | self._init_eval_logging(dataset_label) 158 | 159 | # read datasets 160 | input_reader = input_reader_cls(types_path, self._tokenizer, 161 | max_span_size=args.max_span_size, logger=self._logger) 162 | input_reader.read({dataset_label: dataset_path}) 163 | self._log_datasets(input_reader) 164 | 165 | # create model 166 | model_class = models.get_model(self.args.model_type) 167 | 168 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path) 169 | util.check_version(config, model_class, self.args.model_path) 170 | 171 | model = model_class.from_pretrained(self.args.model_path, 172 | config=config, 173 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'), 174 | relation_types=input_reader.relation_type_count - 1, 175 | term_types=input_reader.term_type_count, 176 | max_pairs=self.args.max_pairs, 177 | prop_drop=self.args.prop_drop, 178 | size_embedding=self.args.size_embedding, 179 | freeze_transformer=self.args.freeze_transformer, 180 | args=self.args, 181 | beta=self.args.beta, 182 | alpha=self.args.alpha, 183 | sigma=self.args.sigma) 184 | 185 | model.to(self._device) 186 | 187 | # evaluate 188 | self._eval(model, input_reader.get_dataset(dataset_label), input_reader) 189 | 190 | self._logger.info("Logged in: %s" % self._log_path) 191 | self._close_summary_writer() 192 | 193 
| def _train_epoch(self, model: torch.nn.Module, compute_loss_AE: Loss, compute_loss_OE: Loss, compute_loss_AOE: Loss, 194 | compute_loss_AOPE: Loss, compute_loss_ALSC: Loss, compute_loss_AESC: Loss, compute_loss_TE: Loss, 195 | optimizer: Optimizer, dataset: Dataset, 196 | updates_epoch: int, epoch: int): 197 | self._logger.info("Train epoch: %s" % epoch) 198 | 199 | # create data loader 200 | dataset.switch_mode(Dataset.TRAIN_MODE) 201 | data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, drop_last=True, 202 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding) 203 | 204 | model.zero_grad() 205 | 206 | iteration = 0 207 | total = dataset.document_count // self.args.train_batch_size 208 | for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch): 209 | model.train() 210 | batch = util.to_device(batch, self._device) 211 | 212 | # forward step 213 | AE_clf, OE_clf, AOE_clf, AOPE_clf, ALSC_clf, AESC_clf, TE_clf = model(encodings=batch['encodings'], context_masks=batch['context_masks'], 214 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'], 215 | term_spans=batch['term_spans'], term_types=batch['term_types'], 216 | relations=batch['rels'], rel_masks=batch['rel_masks'], 217 | simple_graph=batch['simple_graph'], graph=batch['graph'], 218 | relations3=batch['rels3'], rel_masks3=batch['rel_masks3'], 219 | pair_mask=batch['pair_mask'], pos=batch['pos']) 220 | 221 | # compute loss for each subtasks 222 | batch_loss_AE = compute_loss_AE.compute(term_logits=AE_clf, pol_logits=None, 223 | rel_types=batch['rel_types'], term_types=batch['term_types'], 224 | term_sample_masks=batch['term_sample_masks'], 225 | rel_sample_masks=batch['rel_sample_masks']) 226 | batch_loss_OE = compute_loss_OE.compute(term_logits=OE_clf, pol_logits=None, 227 | rel_types=batch['rel_types'], term_types=batch['term_types'], 228 | term_sample_masks=batch['term_sample_masks'], 229 | rel_sample_masks=batch['rel_sample_masks']) 230 | batch_loss_AOE = compute_loss_AOE.compute(term_logits=AOE_clf, pol_logits=None, 231 | rel_types=batch['rel_types'], term_types=batch['term_types'], 232 | term_sample_masks=batch['term_sample_masks'], 233 | rel_sample_masks=batch['rel_sample_masks']) 234 | batch_loss_AOPE = compute_loss_AOPE.compute(term_logits=AOPE_clf, pol_logits=None, 235 | rel_types=batch['rel_types'], term_types=batch['term_types'], 236 | term_sample_masks=batch['term_sample_masks'], 237 | rel_sample_masks=batch['rel_sample_masks']) 238 | batch_loss_ALSC = compute_loss_ALSC.compute(term_logits=ALSC_clf, pol_logits=None, 239 | rel_types=batch['rel_types'], term_types=batch['term_types'], 240 | term_sample_masks=batch['term_sample_masks'], 241 | rel_sample_masks=batch['rel_sample_masks']) 242 | batch_loss_AESC = compute_loss_AESC.compute(term_logits=AESC_clf, pol_logits=None, 243 | rel_types=batch['rel_types'], term_types=batch['term_types'], 244 | term_sample_masks=batch['term_sample_masks'], 245 | rel_sample_masks=batch['rel_sample_masks']) 246 | batch_loss_TE = compute_loss_TE.compute(term_logits=TE_clf, pol_logits=None, 247 | rel_types=batch['rel_types'], term_types=batch['term_types'], 248 | term_sample_masks=batch['term_sample_masks'], 249 | rel_sample_masks=batch['rel_sample_masks']) 250 | 251 | batch_loss = (batch_loss_AE + batch_loss_OE + batch_loss_AOE + batch_loss_AOPE + 252 | batch_loss_ALSC + batch_loss_AESC + batch_loss_TE) 253 | 254 | # logging 255 | iteration += 1 256 | global_iteration = epoch * updates_epoch + iteration 
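# The lr logged below follows the linear warmup/decay schedule created in
# train(); a minimal sketch of that multiplier (toy numbers, not repo code,
# with warmup = lr_warmup * updates_total):
def lr_factor(step, warmup=100, total=1000):
    if step < warmup:
        return step / warmup  # linear ramp-up
    return max(0.0, (total - step) / (total - warmup))  # linear decay to zero
assert lr_factor(50) == 0.5 and lr_factor(550) == 0.5 and lr_factor(1000) == 0.0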
257 | 258 | if global_iteration % self.args.train_log_iter == 0: 259 | self._log_train(optimizer, batch_loss, epoch, iteration, global_iteration, dataset.label) 260 | 261 | return iteration 262 | 263 | def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader, 264 | epoch: int = 0, updates_epoch: int = 0, iteration: int = 0): 265 | self._logger.info("Evaluate: %s" % dataset.label) 266 | 267 | if isinstance(model, DataParallel): 268 | # currently no multi GPU support during evaluation 269 | model = model.module 270 | 271 | # create evaluator 272 | evaluator = Evaluator(dataset, input_reader, self._tokenizer, 273 | self.args.rel_filter_threshold, self.args.no_overlapping, self._predictions_path, 274 | self._examples_path, self.args.example_count, epoch, dataset.label) 275 | 276 | # create data loader 277 | dataset.switch_mode(Dataset.EVAL_MODE) 278 | data_loader = DataLoader(dataset, batch_size=self.args.eval_batch_size, shuffle=False, drop_last=False, 279 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding) 280 | 281 | with torch.no_grad(): 282 | model.eval() 283 | 284 | # iterate batches 285 | total = math.ceil(dataset.document_count / self.args.eval_batch_size) 286 | for batch in tqdm(data_loader, total=total, desc='Evaluate epoch %s' % epoch): 287 | # move batch to selected device 288 | batch = util.to_device(batch, self._device) 289 | 290 | # run model (forward pass) 291 | AE_out, OE_out, AOE_out, AOPE_out, ALSC_out, AESC_out, TE_out = model(encodings=batch['encodings'], context_masks=batch['context_masks'], 292 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'], 293 | term_spans=batch['term_spans'], term_sample_masks=batch['term_sample_masks'], 294 | evaluate=True, simple_graph=batch['simple_graph'], graph=batch['graph'], 295 | pos=batch['pos']) # pos=batch['pos'] 296 | # term_clf, rel_clf, rels = result 297 | 298 | # evaluate batch 299 | evaluator.eval_batch(AE_out, OE_out, AOE_out, AOPE_out, ALSC_out, AESC_out, TE_out, batch) 300 | 301 | global_iteration = epoch * updates_epoch + iteration 302 | ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores() 303 | self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval, 304 | epoch, iteration, global_iteration, dataset.label) 305 | 306 | if self.args.store_predictions and not self.args.no_overlapping: 307 | evaluator.store_predictions() 308 | 309 | if self.args.store_examples: 310 | evaluator.store_examples() 311 | return rel_nec_eval 312 | 313 | def _get_optimizer_params(self, model): 314 | param_optimizer = list(model.named_parameters()) 315 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 316 | optimizer_params = [ 317 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 318 | 'weight_decay': self.args.weight_decay}, 319 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 320 | 321 | return optimizer_params 322 | 323 | def _log_train(self, optimizer: Optimizer, loss: float, epoch: int, 324 | iteration: int, global_iteration: int, label: str): 325 | # average loss 326 | avg_loss = loss / self.args.train_batch_size 327 | # get current learning rate 328 | lr = self._get_lr(optimizer)[0] 329 | 330 | # log to tensorboard 331 | self._log_tensorboard(label, 'loss', loss, global_iteration) 332 | self._log_tensorboard(label, 'loss_avg', avg_loss, global_iteration) 333 | self._log_tensorboard(label, 'lr', lr, global_iteration) 334 | 335 | # log to csv 336 | self._log_csv(label, 
'loss', loss, epoch, iteration, global_iteration) 337 | self._log_csv(label, 'loss_avg', avg_loss, epoch, iteration, global_iteration) 338 | self._log_csv(label, 'lr', lr, epoch, iteration, global_iteration) 339 | 340 | def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: float, 341 | ner_prec_macro: float, ner_rec_macro: float, ner_f1_macro: float, 342 | 343 | rel_prec_micro: float, rel_rec_micro: float, rel_f1_micro: float, 344 | rel_prec_macro: float, rel_rec_macro: float, rel_f1_macro: float, 345 | 346 | rel_nec_prec_micro: float, rel_nec_rec_micro: float, rel_nec_f1_micro: float, 347 | rel_nec_prec_macro: float, rel_nec_rec_macro: float, rel_nec_f1_macro: float, 348 | epoch: int, iteration: int, global_iteration: int, label: str): 349 | 350 | # log to tensorboard 351 | self._log_tensorboard(label, 'eval/ner_prec_micro', ner_prec_micro, global_iteration) 352 | self._log_tensorboard(label, 'eval/ner_recall_micro', ner_rec_micro, global_iteration) 353 | self._log_tensorboard(label, 'eval/ner_f1_micro', ner_f1_micro, global_iteration) 354 | self._log_tensorboard(label, 'eval/ner_prec_macro', ner_prec_macro, global_iteration) 355 | self._log_tensorboard(label, 'eval/ner_recall_macro', ner_rec_macro, global_iteration) 356 | self._log_tensorboard(label, 'eval/ner_f1_macro', ner_f1_macro, global_iteration) 357 | 358 | self._log_tensorboard(label, 'eval/pol_prec_micro', rel_prec_micro, global_iteration) 359 | self._log_tensorboard(label, 'eval/pol_recall_micro', rel_rec_micro, global_iteration) 360 | self._log_tensorboard(label, 'eval/pol_f1_micro', rel_f1_micro, global_iteration) 361 | self._log_tensorboard(label, 'eval/pol_prec_macro', rel_prec_macro, global_iteration) 362 | self._log_tensorboard(label, 'eval/pol_recall_macro', rel_rec_macro, global_iteration) 363 | self._log_tensorboard(label, 'eval/pol_f1_macro', rel_f1_macro, global_iteration) 364 | 365 | self._log_tensorboard(label, 'eval/pol_nec_prec_micro', rel_nec_prec_micro, global_iteration) 366 | self._log_tensorboard(label, 'eval/pol_nec_recall_micro', rel_nec_rec_micro, global_iteration) 367 | self._log_tensorboard(label, 'eval/pol_nec_f1_micro', rel_nec_f1_micro, global_iteration) 368 | self._log_tensorboard(label, 'eval/pol_nec_prec_macro', rel_nec_prec_macro, global_iteration) 369 | self._log_tensorboard(label, 'eval/pol_nec_recall_macro', rel_nec_rec_macro, global_iteration) 370 | self._log_tensorboard(label, 'eval/pol_nec_f1_macro', rel_nec_f1_macro, global_iteration) 371 | 372 | # log to csv 373 | self._log_csv(label, 'eval', ner_prec_micro, ner_rec_micro, ner_f1_micro, 374 | ner_prec_macro, ner_rec_macro, ner_f1_macro, 375 | 376 | rel_prec_micro, rel_rec_micro, rel_f1_micro, 377 | rel_prec_macro, rel_rec_macro, rel_f1_macro, 378 | 379 | rel_nec_prec_micro, rel_nec_rec_micro, rel_nec_f1_micro, 380 | rel_nec_prec_macro, rel_nec_rec_macro, rel_nec_f1_macro, 381 | epoch, iteration, global_iteration) 382 | 383 | def _log_datasets(self, input_reader): 384 | self._logger.info("Relation type count: %s" % input_reader.relation_type_count) 385 | self._logger.info("Term type count: %s" % input_reader.term_type_count) 386 | 387 | self._logger.info("Terms:") 388 | for e in input_reader.term_types.values(): 389 | self._logger.info(e.verbose_name + '=' + str(e.index)) 390 | 391 | self._logger.info("Relations:") 392 | for r in input_reader.relation_types.values(): 393 | self._logger.info(r.verbose_name + '=' + str(r.index)) 394 | 395 | for k, d in input_reader.datasets.items(): 396 | self._logger.info('Dataset: 
%s' % k) 397 | self._logger.info("Document count: %s" % d.document_count) 398 | self._logger.info("Relation count: %s" % d.relation_count) 399 | self._logger.info("Term count: %s" % d.term_count) 400 | 401 | self._logger.info("Context size: %s" % input_reader.context_size) 402 | 403 | def _init_train_logging(self, label): 404 | self._add_dataset_logging(label, 405 | data={'lr': ['lr', 'epoch', 'iteration', 'global_iteration'], 406 | 'loss': ['loss', 'epoch', 'iteration', 'global_iteration'], 407 | 'loss_avg': ['loss_avg', 'epoch', 'iteration', 'global_iteration']}) 408 | 409 | def _init_eval_logging(self, label): 410 | self._add_dataset_logging(label, 411 | data={'eval': ['ner_prec_micro', 'ner_rec_micro', 'ner_f1_micro', 412 | 'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro', 413 | 'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro', 414 | 'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro', 415 | 'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro', 416 | 'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro', 417 | 'epoch', 'iteration', 'global_iteration']}) 418 | -------------------------------------------------------------------------------- /Engine/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn as nn 4 | from transformers import RobertaConfig 5 | from transformers import RobertaModel 6 | from transformers import RobertaTokenizer 7 | from Engine import sampling 8 | from Engine import util 9 | from Engine import encoder 10 | 11 | 12 | def get_token(h: torch.tensor, x: torch.tensor, token: int): 13 | """ Get specific token embedding (e.g. [CLS]) """ 14 | emb_size = h.shape[-1] 15 | 16 | token_h = h.view(-1, emb_size) 17 | flat = x.contiguous().view(-1) 18 | 19 | # get contextualized embedding of given token 20 | token_h = token_h[flat == token, :] 21 | 22 | return token_h 23 | 24 | 25 | def get_head_tail_rep(h, head_tail_index): 26 | """ 27 | 28 | :param h: torch.tensor [batch size, seq_len, feat_dim] 29 | :param head_tail_index: [batch size, term_num, 2] 30 | :return: 31 | """ 32 | res = [] 33 | batch_size = head_tail_index.size(0) 34 | term_num = head_tail_index.size(1) 35 | for b in range(batch_size): 36 | temp = [] 37 | for t in range(term_num): 38 | temp.append(torch.index_select(h[b], 0, head_tail_index[b][t]).view(-1)) 39 | res.append(torch.stack(temp, dim=0)) 40 | res = torch.stack(res) 41 | return res 42 | 43 | 44 | class SyMuxRoBERTa(RobertaModel): 45 | VERSION = '1.1' 46 | 47 | def __init__(self, config: RobertaConfig, cls_token: int, relation_types: int, term_types: int, 48 | size_embedding: int, prop_drop: float, freeze_transformer: bool, args, max_pairs: int = 100, 49 | beta: float = 0.3, alpha: float = 1.0, sigma: float = 1.0): 50 | super(SyMuxRoBERTa, self).__init__(config) 51 | 52 | # RoBERTa encoder 53 | self.roberta = RobertaModel(config) 54 | # Syntax encoder 55 | self.syntaxencoder = encoder.SyMuxEncoder(self.roberta, opt=args) 56 | 57 | # layers 58 | self.AE2OE = nn.Linear(config.hidden_size, config.hidden_size) 59 | self.pairing_classifier = nn.Linear(config.hidden_size * 6 + size_embedding * 2, relation_types) 60 | self.polarity_classifier = nn.Linear(config.hidden_size * 6 + size_embedding * 3, relation_types) 61 | self.term_classifier = nn.Linear(config.hidden_size * 8 + size_embedding, term_types) 62 | self.dep_linear = nn.Linear(config.hidden_size, relation_types) 63 | self.size_embeddings = nn.Embedding(100, size_embedding) 64 | 
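# A small sketch of the span-size feature defined just above (toy values, not
# repo code): each candidate's width in tokens indexes a learned embedding
# that is later concatenated into the span representation:
import torch
from torch import nn
size_emb = nn.Embedding(100, 25)  # width indices 0..99, dim = size_embedding
widths = torch.tensor([1, 3, 2])  # token widths of three candidate spans
assert size_emb(widths).shape == (3, 25)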
        self.dropout = nn.Dropout(prop_drop)

        # invariant multiplex kernel
        self.multiplex_kernel = nn.Linear(config.hidden_size, config.hidden_size)

        # variant private feats
        self.private_aoe = nn.Linear(config.hidden_size, config.hidden_size)
        self.private_aope = nn.Linear(config.hidden_size, config.hidden_size)
        self.private_alsc = nn.Linear(config.hidden_size, config.hidden_size)
        self.private_aesc = nn.Linear(config.hidden_size, config.hidden_size)  # used by the AESC step of the cascade
        self.private_te = nn.Linear(config.hidden_size, config.hidden_size)

        # tagging_1_1d
        self.lb_cls_linear1_1d = nn.Linear(config.hidden_size, args.class_num1)
        # tagging_1_2d
        self.lb_cls_linear1_2d = nn.Linear(config.hidden_size, args.class_num1 * args.class_num1)
        # tagging_2_2d
        self.lb_cls_linear2_2d = nn.Linear(config.hidden_size, args.class_num2 * args.class_num2)

        self._cls_token = cls_token
        self._relation_types = relation_types
        self._term_types = term_types
        self._max_pairs = max_pairs
        self._beta = beta
        self._alpha = alpha
        self._sigma = sigma

        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")

            # the transformer encoder is stored as self.roberta
            for param in self.roberta.parameters():
                param.requires_grad = False

    def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
                       term_sizes: torch.tensor, term_spans: torch.tensor, term_types: torch.tensor,
                       relations: torch.tensor, rel_masks: torch.tensor,
                       simple_graph: torch.tensor, graph: torch.tensor,
                       pol_claz: torch.tensor, rel_masks3: torch.tensor, pair_mask: torch.tensor,
                       pos: torch.tensor = None):
        context_masks = context_masks.float()

        h, dep_output = self.syntaxencoder(input_ids=encodings, input_masks=context_masks, simple_graph=simple_graph,
                                           graph=graph, pos=pos)

        batch_size = encodings.shape[0]

        # classify terms
        size_embeddings = self.size_embeddings(term_sizes)  # embed term candidate sizes
        term_clf, term_spans_pool = self._classify_terms(encodings, h, term_masks, size_embeddings)

        # pairing
        h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        pol_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
            self.pairing_classifier.weight.device)

        # "AE",
        # "OE",
        # get term representation
        a_term_repr, mapping_list = self.get_term_repr(term_spans, term_types, dep_output)
        o_term_repr = self.AE2OE(a_term_repr)
        AE_clf = self.lb_cls_linear1_1d(a_term_repr)
        OE_clf = self.lb_cls_linear1_1d(o_term_repr)

        # obtain pairing rep
        rep_m = self._syntax_guided_pairing(a_term_repr, o_term_repr,
                                            term_spans_pool,
                                            size_embeddings,
                                            relations, rel_masks,
                                            h_large,
                                            pol_claz, mapping_list)

        # representation multiplexing
        # "AOE",
        AOE_rep = self.multiplex_kernel(rep_m) + self.private_aoe(pol_clf)
        AOE_clf = self.lb_cls_linear1_2d(AOE_rep)

        # "AOPE",
        AOPE_rep = self.multiplex_kernel(AOE_rep) + self.private_aope(pol_clf)
        AOPE_clf = self.lb_cls_linear1_2d(AOPE_rep)

        # "ALSC",
        ALSC_rep = self.multiplex_kernel(AOPE_rep) + self.private_alsc(pol_clf)
        ALSC_clf = self.lb_cls_linear1_2d(ALSC_rep)

        # "AESC",
        AESC_rep = self.multiplex_kernel(ALSC_rep) + self.private_aesc(pol_clf)
        AESC_clf = self.lb_cls_linear2_2d(AESC_rep)
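        # Every step of the cascade above and below follows the same recurrence
        # (a sketch of the pattern, not additional computation):
        #   rep_k = multiplex_kernel(rep_{k-1}) + private_<task_k>(pol_clf)
        #   clf_k = lb_cls_linear*(rep_k)
        # so each subtask reuses the invariant kernel and differs only in its
        # private projection and tagging head.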
        # "TE"
        TE_rep = self.multiplex_kernel(AESC_rep) + self.private_te(pol_clf)
        TE_clf = self.lb_cls_linear2_2d(TE_rep)

        return AE_clf, OE_clf, AOE_clf, AOPE_clf, ALSC_clf, AESC_clf, TE_clf

    # def calcul_pol_log(self, ):
    #
    #     for i in range(0, relations.shape[1], self._max_pairs):
    #         chunk_rel_logits, chunk_rel_clf3, chunk_dep_score
    #         # classify candidates
    #         chunk_rel_logits3 = self.polarity_classifier(rel_repr)
    #
    #         chunk_rel_clf3 = chunk_rel_logits3.view(batch_size, p_num, p_tris, -1)
    #         chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)
    #
    #         chunk_rel_clf3 = torch.sum(chunk_rel_clf3, dim=2)
    #         chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)
    #
    #         # return chunk_rel_logits, chunk_rel_clf3, batch_dep_score
    #         # apply sigmoid
    #         chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
    #         chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
    #         pol_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
    #
    #     max_clf = torch.full_like(pol_clf, torch.max(pol_clf).item())
    #     min_clf = torch.full_like(pol_clf, torch.min(pol_clf).item())
    #     infinite = torch.full_like(pol_clf, 1e-18)
    #     pol_clf = torch.div(pol_clf - min_clf + infinite, max_clf - min_clf + infinite)

    def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
                      term_sizes: torch.tensor, term_spans: torch.tensor, term_sample_masks: torch.tensor,
                      simple_graph: torch.tensor, graph: torch.tensor, pos: torch.tensor = None):
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h, dep_output = self.syntaxencoder(input_ids=encodings, input_masks=context_masks, simple_graph=simple_graph,
                                           graph=graph, pos=pos)

        batch_size = encodings.shape[0]
        ctx_size = context_masks.shape[-1]

        # classify terms
        size_embeddings = self.size_embeddings(term_sizes)  # embed term candidate sizes
        term_clf, term_spans_pool = self._classify_terms(encodings, h, term_masks, size_embeddings)

        # ignore term candidates that do not constitute an actual term for relations (based on classifier)
        relations, rel_masks, rel_sample_masks, pol_claz, rel_masks3, \
            rel_sample_masks3, pair_mask, term_repr, mapping_list = self._filter_terms(term_clf, term_spans,
                                                                                       term_sample_masks,
                                                                                       ctx_size, dep_output)

        rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
        # h = self.rel_bert(input_ids=encodings, attention_mask=context_masks)[0]
        h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        pol_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
            self.pairing_classifier.weight.device)

        # # obtain pair logits
        # # chunk processing to reduce memory usage
        # for i in range(0, relations.shape[1], self._max_pairs):
        #     # classify relation candidates
        #     chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._syntax_guided_pairing(term_repr,
        #                                                                                     term_spans_pool,
        #                                                                                     size_embeddings,
        #                                                                                     relations, rel_masks,
        #                                                                                     h_large,
        #                                                                                     pol_claz, mapping_list)
        #     # apply sigmoid
        #     chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
        #     chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
        #     rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
        #
        # max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
        # min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
        # infinite = torch.full_like(rel_clf, 1e-18)
        # rel_clf = torch.div(rel_clf - min_clf + infinite, max_clf - min_clf + infinite)
        #
        # rel_clf = rel_clf * rel_sample_masks  # mask

        # # pairing
        # h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        # pol_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
        #     self.pairing_classifier.weight.device)

        # "AE",
        # "OE",
        # get term representation
        # use the predicted term types here: get_term_repr indexes its term_types
        # argument, so it cannot be called with None at evaluation time
        pred_term_types = term_clf.argmax(dim=-1) * term_sample_masks.long()
        a_term_repr, mapping_list = self.get_term_repr(term_spans, pred_term_types, dep_output)
        o_term_repr = self.AE2OE(a_term_repr)
        AE_clf = self.lb_cls_linear1_1d(a_term_repr)
        AE_out = torch.softmax(AE_clf, dim=2)
        OE_clf = self.lb_cls_linear1_1d(o_term_repr)
        OE_out = torch.softmax(OE_clf, dim=2)

        # obtain pairing rep
        rep_m = self._syntax_guided_pairing(a_term_repr, o_term_repr,
                                            term_spans_pool,
                                            size_embeddings,
                                            relations, rel_masks,
                                            h_large,
                                            pol_claz, mapping_list)

        # representation multiplexing
        # "AOE",
        AOE_rep = self.multiplex_kernel(rep_m) + self.private_aoe(pol_clf)
        AOE_clf = self.lb_cls_linear1_2d(AOE_rep)
        AOE_out = torch.softmax(AOE_clf, dim=2)

        # "AOPE",
        AOPE_rep = self.multiplex_kernel(AOE_rep) + self.private_aope(pol_clf)
        AOPE_clf = self.lb_cls_linear1_2d(AOPE_rep)
        AOPE_out = torch.softmax(AOPE_clf, dim=2)

        # "ALSC",
        ALSC_rep = self.multiplex_kernel(AOPE_rep) + self.private_alsc(pol_clf)
        ALSC_clf = self.lb_cls_linear1_2d(ALSC_rep)
        ALSC_out = torch.softmax(ALSC_clf, dim=2)

        # "AESC",
        AESC_rep = self.multiplex_kernel(ALSC_rep) + self.private_aesc(pol_clf)
        AESC_clf = self.lb_cls_linear2_2d(AESC_rep)
        AESC_out = torch.softmax(AESC_clf, dim=2)

        # "TE"
        TE_rep = self.multiplex_kernel(AESC_rep) + self.private_te(pol_clf)
        TE_clf = self.lb_cls_linear2_2d(TE_rep)
        TE_out = torch.softmax(TE_clf, dim=2)

        return AE_out, OE_out, AOE_out, AOPE_out, ALSC_out, AESC_out, TE_out

    def _classify_terms(self, encodings, h, term_masks, size_embeddings):
        # max pool term candidate terms
        m = (term_masks.unsqueeze(-1) == 0).float() * (-1e30)
        term_spans_pool = m + h.unsqueeze(1).repeat(1, term_masks.shape[1], 1, 1)
        term_spans_pool = term_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        term_ctx = get_token(h, encodings, self._cls_token)

        # recover each candidate's head and tail token indices from its span mask
        m = term_masks.to(dtype=torch.long)
        k = torch.tensor(np.arange(0, term_masks.size(-1)), dtype=torch.long)
        k = k.unsqueeze(0).unsqueeze(0).repeat(term_masks.size(0), term_masks.size(1), 1).to(m.device)
        mk = torch.mul(m, k)  # element-wise multiply
        mk_max = torch.argmax(mk, dim=-1, keepdim=True)
        mk_min = torch.argmin(mk, dim=-1, keepdim=True)
        mk = torch.cat([mk_min, mk_max], dim=-1)
        head_tail_rep = get_head_tail_rep(h, mk)  # [batch size, term_num, feat_dim * 2]

        # create candidate representations including context, max pooled span and size embedding
        term_repr = torch.cat([term_ctx.unsqueeze(1).repeat(1, term_spans_pool.shape[1], 1),
                               term_spans_pool, size_embeddings, head_tail_rep], dim=2)
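        # term_repr layout per candidate (a reading of the concatenation above):
        #   [CLS context | max-pooled span | size embedding | head token | tail token]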
        term_repr = self.dropout(term_repr)

        # classify term candidates
        term_clf = self.term_classifier(term_repr)

        return term_clf, term_spans_pool

    def _syntax_guided_pairing(self, a_term_repr, o_term_repr, term_spans_repr, size_embeddings, relations, rel_masks,
                               h, pol_claz, rel_to_span):
        # h carries the per-pair token context (h_large at both call sites)
        batch_size = relations.shape[0]
        feat_dim = a_term_repr.size(-1)

        spans_matrix = torch.cat([a_term_repr, o_term_repr], dim=2)

        # create chunks if necessary
        # if relations.shape[1] > self._max_pairs:
        #     # relations = relations[:, chunk_start:chunk_start + self._max_pairs]
        #     # rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
        #     h = h[:, :relations.shape[1], :]

        def get_span_idx(mapping_list, idx1, idx2):
            for x in mapping_list:
                if idx1 == x[0][0] and idx2 == x[0][1]:
                    return x[1][0], x[1][1]
            return 0, 0  # fallback for padded relations with no span mapping

        batch_dep_score = []
        for i in range(batch_size):
            rela = relations[i]
            dep_score_list = []
            r_2_s = rel_to_span[i]
            for r in rela:
                i1, i2 = r[0].item(), r[1].item()
                idx1, idx2 = get_span_idx(r_2_s, i1, i2)
                try:
                    feat = spans_matrix[i][idx1][idx2]
                except IndexError:
                    print('Index out of bounds', spans_matrix.size(), i, i1, i2)
                    feat = torch.zeros(feat_dim, device=spans_matrix.device)
                dep_score = self.dep_linear(feat).item()
                dep_score_list.append([dep_score])
            batch_dep_score.append(dep_score_list)

        batch_dep_score = torch.sigmoid(
            torch.tensor(batch_dep_score).to(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))

        # get pairs of term candidate representations
        term_pairs = util.batch_index(term_spans_repr, relations)
        term_pairs = term_pairs.view(batch_size, term_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between term candidate pair)
        # mask non term candidate tokens, then max-pool (mirrors _classify_terms)
        m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
        rel_ctx = m + h
        # max pooling
        rel_ctx = rel_ctx.max(dim=2)[0]
        # set the context vector of neighboring or adjacent term candidates to zero
        rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0

        # create relation candidate representations including context, max pooled term candidate pairs
        # and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, term_pairs, size_pair_embeddings], dim=2)
        rel_repr = self.dropout(rel_repr)
        # classify relation candidates
        chunk_rel_logits = self.pairing_classifier(rel_repr)

        # if pol_claz.shape[1] > self._max_pairs:
        #     pol_claz = pol_claz[:, chunk_start:chunk_start + self._max_pairs]
        #     rel_masks3 = rel_masks3[:, chunk_start:chunk_start + self._max_pairs]

        p_num = pol_claz.size(1)
        p_tris = pol_claz.size(2)

        pol_claz = pol_claz.view(batch_size, -1, 3)

        # get triple candidate representations
        term_pairs3 = util.batch_index(term_spans_repr, pol_claz)
        term_pairs3 = term_pairs3.view(batch_size, term_pairs3.shape[1], -1)

        size_pair_embeddings3 = util.batch_index(size_embeddings, pol_claz)
        size_pair_embeddings3 = size_pair_embeddings3.view(batch_size, size_pair_embeddings3.shape[1], -1)
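        # The pairing representation returned below is built from the triple
        # candidates only: per candidate triple, the max-pooled span features and
        # their size embeddings are concatenated. chunk_rel_logits and
        # batch_dep_score above appear to be retained for the chunked polarity
        # scoring path (see the commented-out code in _forward_eval and
        # calcul_pol_log) and do not enter the returned representation.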
        rel_repr = torch.cat([term_pairs3, size_pair_embeddings3], dim=2)
        rel_repr = self.dropout(rel_repr)
        return rel_repr

    def _filter_terms(self, term_clf, term_spans, term_sample_masks, ctx_size, token_repr):
        batch_size = term_clf.shape[0]
        feat_dim = token_repr.size(-1)
        term_logits_max = term_clf.argmax(dim=-1) * term_sample_masks.long()  # get term type (including none)
        batch_relations = []
        batch_rel_masks = []
        batch_rel_sample_masks = []

        batch_pol_claz = []
        batch_rel_masks3 = []
        batch_rel_sample_masks3 = []
        batch_pair_mask = []

        batch_span_repr = []
        batch_rel_to_span = []

        for i in range(batch_size):
            rels = []
            rel_masks = []
            sample_masks = []
            rels3 = []
            rel_masks3 = []
            sample_masks3 = []

            span_repr = []
            rel_to_span = []

            # get spans classified as terms
            non_zero_indices = (term_logits_max[i] != 0).nonzero().view(-1)
            non_zero_spans = term_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()

            # create relations and masks
            pair_mask = []
            for idx1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
                temp = []
                for idx2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)):
                    if i1 != i2:
                        rels.append((i1, i2))
                        rel_masks.append(sampling.create_rel_mask(s1, s2, ctx_size))
                        sample_masks.append(1)
                        p_rels3 = []
                        p_masks3 = []
                        for i3, s3 in zip(non_zero_indices, non_zero_spans):
                            if i1 != i2 and i1 != i3 and i2 != i3:
                                p_rels3.append((i1, i2, i3))
                                p_masks3.append(sampling.create_rel_mask3(s1, s2, s3, ctx_size))
                                sample_masks3.append(1)
                        if len(p_rels3) > 0:
                            rels3.append(p_rels3)
                            rel_masks3.append(p_masks3)
                            pair_mask.append(1)
                        else:
                            rels3.append([(i1, i2, 0)])
                            rel_masks3.append([sampling.create_rel_mask3(s1, s2, (0, 0), ctx_size)])
                            pair_mask.append(0)
                        rel_to_span.append([[i1, i2], [idx1, idx2]])
                    feat = \
                        torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view(-1, feat_dim),
                                  dim=0)[0]
                    temp.append(feat)
                span_repr.append(temp)

            if not rels:
                # case: fewer than two spans classified as terms (no candidate pairs)
                batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long))
                batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool))
                batch_span_repr.append(torch.tensor([[[0] * feat_dim]], dtype=torch.float))
                batch_rel_to_span.append([[[0, 0], [0, 0]]])
            else:
                # case: at least two spans classified as terms
                batch_relations.append(torch.tensor(rels, dtype=torch.long))
                batch_rel_masks.append(torch.stack(rel_masks))
                batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool))
                batch_span_repr.append(torch.stack([torch.stack(x) for x in span_repr]))
                batch_rel_to_span.append(rel_to_span)

            if not rels3:
                batch_pol_claz.append(torch.tensor([[[0, 0, 0]]], dtype=torch.long))
                batch_rel_masks3.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks3.append(torch.tensor([0], dtype=torch.bool))
                batch_pair_mask.append(torch.tensor([0], dtype=torch.bool))

            else:
                max_tri = max([len(x) for x in rels3])
                # print(max_tri)
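                # pad every candidate list to the batch-max triple count so the
                # lists stack into one rectangular tensor; the first entry is
                # repeated as padding (these duplicates are presumably screened
                # out downstream by the sample masks)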
                for idx, r in enumerate(rels3):
                    r_len = len(r)
                    if r_len < max_tri:
                        rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len))
                        rel_masks3[idx].extend(
                            [rel_masks3[idx][0]] * (max_tri - r_len))
                batch_pol_claz.append(torch.tensor(rels3, dtype=torch.long))
                batch_rel_masks3.append(torch.stack([torch.stack(x) for x in rel_masks3]))
                batch_rel_sample_masks3.append(torch.tensor(sample_masks3, dtype=torch.bool))
                batch_pair_mask.append(torch.tensor(pair_mask, dtype=torch.bool))

        # stack
        device = self.pairing_classifier.weight.device
        batch_relations = util.padded_stack(batch_relations).to(device)
        batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
        batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(device)
        batch_span_repr = util.padded_stack(batch_span_repr).to(device)

        batch_pol_claz = util.padded_stack(batch_pol_claz).to(device)
        batch_rel_masks3 = util.padded_stack(batch_rel_masks3).to(device)
        batch_rel_sample_masks3 = util.padded_stack(batch_rel_sample_masks3).to(device)
        batch_pair_mask = util.padded_stack(batch_pair_mask).to(device)

        return batch_relations, batch_rel_masks, batch_rel_sample_masks, \
            batch_pol_claz, batch_rel_masks3, batch_rel_sample_masks3, batch_pair_mask, batch_span_repr, batch_rel_to_span

    def get_term_repr(self, term_spans, term_types, token_repr):
        """
        Build pairwise span representations for the spans typed as terms.

        :param term_spans: [batch size, span_num, 2]
        :param term_types: [batch size, span_num]
        :param token_repr: [batch size, seq_len, seq_len, feat_dim]
        :return: [batch size, span_num, span_num, feat_dim]
        """
        batch_size = term_spans.size(0)
        feat_dim = token_repr.size(-1)
        batch_span_repr = []
        batch_mapping_list = []
        for i in range(batch_size):
            span_repr = []
            mapping_list = []
            # get target spans as aspect term or opinion term
            non_zero_indices = (term_types[i] != 0).nonzero().view(-1)
            non_zero_spans = term_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()
            for x1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
                temp = []
                for x2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)):
                    feat = \
                        torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view(-1, feat_dim),
                                  dim=0)[0]
                    temp.append(feat)
                    mapping_list.append([[i1, i2], [x1, x2]])

                span_repr.append(torch.stack(temp))
            batch_span_repr.append(torch.stack(span_repr))
            batch_mapping_list.append(mapping_list)

        device = self.pairing_classifier.weight.device
        batch_span_repr = util.padded_stack(batch_span_repr).to(device)

        return batch_span_repr, batch_mapping_list

    def forward(self, *args, evaluate=False, **kwargs):
        if not evaluate:
            return self._forward_train(*args, **kwargs)
        else:
            return self._forward_eval(*args, **kwargs)


# Model access

_MODELS = {
    'Engine': SyMuxRoBERTa,
}


def get_model(name):
    return _MODELS[name]
--------------------------------------------------------------------------------