├── SynFue
│   ├── __init__.py
│   ├── opt.py
│   ├── loss.py
│   ├── cross_attn.py
│   ├── Encoder.py
│   ├── base_trainer.py
│   ├── util.py
│   ├── templates
│   │   ├── term_examples.html
│   │   └── relation_examples.html
│   ├── input_reader.py
│   ├── terms.py
│   ├── sampling.py
│   ├── trainer.py
│   ├── evaluator.py
│   └── models.py
├── data
│   ├── datasets
│   │   ├── towe
│   │   │   └── 16res
│   │   │       ├── types.json
│   │   │       ├── pos_vocab.txt
│   │   │       ├── dep_type_vocab.txt
│   │   │       └── char_vocab.txt
│   │   ├── data_config.yaml
│   │   ├── str_utils.py
│   │   ├── io_utils.py
│   │   ├── vocab.py
│   │   └── preprocess.py
│   └── log
│       └── towe_16res
│           └── eval_valid.csv
├── configs
│   ├── 16res_eval.conf
│   └── 16res_train.conf
├── synfue.py
├── README.md
├── config_reader.py
└── args.py
--------------------------------------------------------------------------------
/SynFue/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/datasets/towe/16res/types.json:
--------------------------------------------------------------------------------
1 | {"terms": {"Asp": {"short": "Asp", "verbose": "Aspect"}, "Opi": {"short": "Opi", "verbose": "Opinion"}}, "relations": {"Pair": {"short": "Pair", "verbose": "Pair", "symmetric": true}}}
--------------------------------------------------------------------------------
/SynFue/opt.py:
--------------------------------------------------------------------------------
1 | # optional packages
2 | 
3 | try:
4 |     # the package installs its module as 'tensorboardX'; alias it to the
5 |     # lowercase name used throughout the codebase
6 |     import tensorboardX as tensorboardx
7 | except ImportError:
8 |     tensorboardx = None
9 | 
10 | 
11 | try:
12 |     import jinja2
13 | except ImportError:
14 |     jinja2 = None
--------------------------------------------------------------------------------
/data/datasets/data_config.yaml:
--------------------------------------------------------------------------------
1 | ori_data_dir: 'original/'
2 | new_data_dir: 'towe/'
3 | 
4 | vocab_dir: 'data/vocab'
5 | token_vocab_file: 'token_vocab.txt'
6 | char_vocab_file: 'char_vocab.txt'
7 | pos_vocab_file: 'pos_vocab.txt'
8 | dep_type_vocab_file: 'dep_type_vocab.txt'
9 | prd_type_vocab_file: 'prd_type_vocab.txt'
10 | role_type_vocab_file: 'role_type_vocab.txt'
11 | action_vocab_file: 'action_vocab.txt'
12 | 
13 | normalize_digits: false
14 | lower_case: true
15 | random_seed: 2942435
--------------------------------------------------------------------------------
/configs/16res_eval.conf:
--------------------------------------------------------------------------------
1 | [1]
2 | label = towe_16res
3 | model_type = synfue
4 | model_path = data/models/towe_16res
5 | tokenizer_path = data/models/towe_16res
6 | dataset_path = data/datasets/towe/16res/test.json
7 | types_path = data/datasets/towe/16res/types.json
8 | eval_batch_size = 1
9 | rel_filter_threshold = 0.4
10 | size_embedding = 25
11 | prop_drop = 0.1
12 | max_span_size = 10
13 | store_predictions = true
14 | store_examples = true
15 | sampling_processes = 4
16 | sampling_limit = 100
17 | max_pairs = 1000
18 | log_path = data/log/
19 | alpha = 1.0
20 | beta = 0.4
21 | sigma = 1.0
--------------------------------------------------------------------------------
/data/datasets/towe/16res/pos_vocab.txt:
--------------------------------------------------------------------------------
1 | special_token_size 1
2 | token_size 44
3 | singleton_size 0
4 | singleton_max_count 1
5 | 0 *PAD* 1
6 | 1 NN 3380
7 | 2 JJ 2407
8 | 3 DT 2164
9 | 4 IN 1507
10 | 5 . 
1390 11 | 6 RB 1390 12 | 7 PRP 1005 13 | 8 CC 947 14 | 9 VBD 924 15 | 10 , 872 16 | 11 VBZ 675 17 | 12 NNS 674 18 | 13 VB 541 19 | 14 NNP 532 20 | 15 VBP 491 21 | 16 VBN 303 22 | 17 PRP$ 276 23 | 18 TO 192 24 | 19 MD 177 25 | 20 VBG 162 26 | 21 CD 145 27 | 22 HYPH 140 28 | 23 JJS 125 29 | 24 WDT 113 30 | 25 : 102 31 | 26 -RRB- 77 32 | 27 -LRB- 71 33 | 28 RP 61 34 | 29 JJR 49 35 | 30 WP 47 36 | 31 WRB 45 37 | 32 POS 43 38 | 33 SYM 41 39 | 34 UH 33 40 | 35 '' 26 41 | 36 EX 26 42 | 37 $ 25 43 | 38 PDT 25 44 | 39 RBS 19 45 | 40 RBR 16 46 | 41 FW 15 47 | 42 NFP 10 48 | 43 NNPS 7 49 | 44 `` 5 50 | -------------------------------------------------------------------------------- /data/datasets/towe/16res/dep_type_vocab.txt: -------------------------------------------------------------------------------- 1 | special_token_size 2 2 | token_size 39 3 | singleton_size 0 4 | singleton_max_count 1 5 | 0 *PAD* 1 6 | 1 SELF_LOOP 1 7 | 2 punct 2685 8 | 3 nsubj 2107 9 | 4 det 2045 10 | 5 advmod 1452 11 | 6 ROOT 1408 12 | 7 amod 1393 13 | 8 case 1245 14 | 9 cop 1152 15 | 10 conj 1020 16 | 11 cc 933 17 | 12 compound 758 18 | 13 obj 734 19 | 14 obl 592 20 | 15 dep 547 21 | 16 nmod 485 22 | 17 mark 429 23 | 18 aux 392 24 | 19 nmod:poss 297 25 | 20 parataxis 207 26 | 21 advcl 207 27 | 22 ccomp 172 28 | 23 xcomp 168 29 | 24 nummod 104 30 | 25 aux:pass 103 31 | 26 acl:relcl 103 32 | 27 appos 91 33 | 28 nsubj:pass 84 34 | 29 acl 76 35 | 30 compound:prt 65 36 | 31 obl:npmod 35 37 | 32 obl:tmod 34 38 | 33 det:predet 27 39 | 34 fixed 27 40 | 35 discourse 25 41 | 36 csubj 23 42 | 37 expl 23 43 | 38 iobj 19 44 | 39 cc:preconj 6 45 | 40 csubj:pass 2 46 | -------------------------------------------------------------------------------- /data/datasets/towe/16res/char_vocab.txt: -------------------------------------------------------------------------------- 1 | special_token_size 1 2 | token_size 59 3 | singleton_size 3 4 | singleton_max_count 1 5 | 0 *UNK* 1 6 | 1 e 10416 7 | 2 t 7362 8 | 3 a 7017 9 | 4 i 5843 10 | 5 o 5708 11 | 6 s 5465 12 | 7 r 4772 13 | 8 n 4667 14 | 9 h 3974 15 | 10 d 3312 16 | 11 l 3235 17 | 12 c 2385 18 | 13 u 2153 19 | 14 f 1874 20 | 15 w 1744 21 | 16 y 1744 22 | 17 m 1632 23 | 18 g 1631 24 | 19 p 1624 25 | 20 . 1400 26 | 21 b 1265 27 | 22 v 1193 28 | 23 , 834 29 | 24 k 598 30 | 25 z 284 31 | 26 ' 249 32 | 27 - 218 33 | 28 x 189 34 | 29 ! 184 35 | 30 j 126 36 | 31 q 86 37 | 32 ) 76 38 | 33 ( 71 39 | 34 – 40 40 | 35 0 36 41 | 36 : 25 42 | 37 $ 25 43 | 38 1 23 44 | 39 2 20 45 | 40 5 20 46 | 41 / 16 47 | 42 ’ 15 48 | 43 ; 15 49 | 44 3 12 50 | 45 8 10 51 | 46 4 9 52 | 47 & 8 53 | 48 ? 
7 54 | 49 6 7 55 | 50 ` 6 56 | 51 7 5 57 | 52 * 5 58 | 53 9 4 59 | 54 + 3 60 | 55 % 2 61 | 56 é 2 62 | 57 @ 1 63 | 58 # 1 64 | 59 = 1 65 | -------------------------------------------------------------------------------- /configs/16res_train.conf: -------------------------------------------------------------------------------- 1 | [1] 2 | label = towe_16res 3 | model_type = synfue 4 | model_path = bert-base-cased 5 | tokenizer_path = bert-base-cased 6 | train_path = data/datasets/towe/16res/train.json 7 | valid_path = data/datasets/towe/16res/dev.json 8 | types_path = data/datasets/towe/16res/types.json 9 | train_batch_size = 16 10 | eval_batch_size = 1 11 | neg_term_count = 100 12 | neg_relation_count = 100 13 | epochs = 20 14 | lr = 4e-5 15 | lr_warmup = 0.1 16 | weight_decay = 0.01 17 | max_grad_norm = 1.0 18 | rel_filter_threshold = 0.4 19 | size_embedding = 25 20 | prop_drop = 0.4 21 | max_span_size = 8 22 | store_predictions = true 23 | store_examples = true 24 | sampling_processes = 4 25 | sampling_limit = 100 26 | max_pairs = 800 27 | final_eval = false 28 | log_path = data/log/ 29 | save_path = data/save/ 30 | bert_dim = 768 31 | dep_dim = 100 32 | dep_num = 42 33 | pos_num = 45 34 | pos_dim = 100 35 | w_size = 5 36 | bert_dropout = 0.1 37 | output_dropout = 0.1 38 | num_layer = 2 39 | alpha = 1.0 40 | beta = 0.4 41 | sigma = 1.0 42 | -------------------------------------------------------------------------------- /data/datasets/str_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import re 4 | 5 | 6 | def normalize_sent(sent): 7 | sent = str(sent).replace('``', '"') 8 | sent = str(sent).replace("''", '"') 9 | sent = str(sent).replace('-LRB-', '(') 10 | sent = str(sent).replace('-RRB-', ')') 11 | sent = str(sent).replace('-LSB-', '(') 12 | sent = str(sent).replace('-RSB-', ')') 13 | return sent 14 | 15 | 16 | def collapse_role_type(role_type): 17 | ''' 18 | collapse role types from 36 to 28 following Bishan Yang 2016 19 | we also have to handle types like 'Beneficiary#Recipient' 20 | :param role_type: 21 | :return: 22 | ''' 23 | if role_type.startswith('Time-'): 24 | return 'Time' 25 | idx = role_type.find('#') 26 | if idx != -1: 27 | role_type = role_type[:idx] 28 | 29 | return role_type 30 | 31 | 32 | def normalize_tok(tok, lower=False, normalize_digits=False): 33 | 34 | if lower: 35 | tok = tok.lower() 36 | if normalize_digits: 37 | tok = re.sub(r"\d", "0", tok) 38 | tok = re.sub(r"^(\d+[,])*(\d+)$", "0", tok) 39 | return tok 40 | 41 | 42 | def capitalize_first_char(sent): 43 | sent = str(sent[0]).upper() + sent[1:] 44 | return sent 45 | 46 | -------------------------------------------------------------------------------- /synfue.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from args import train_argparser, eval_argparser 4 | from config_reader import process_configs 5 | from SynFue import input_reader 6 | from SynFue.trainer import SynFueTrainer 7 | 8 | 9 | def __train(run_args): 10 | trainer = SynFueTrainer(run_args) 11 | trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path, 12 | types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader) 13 | 14 | 15 | def _train(): 16 | arg_parser = train_argparser() 17 | process_configs(target=__train, arg_parser=arg_parser) 18 | 19 | 20 | def __eval(run_args): 21 | trainer = SynFueTrainer(run_args) 22 | trainer.eval(dataset_path=run_args.dataset_path, 
types_path=run_args.types_path,
23 |                  input_reader_cls=input_reader.JsonInputReader)
24 | 
25 | 
26 | def _eval():
27 |     arg_parser = eval_argparser()
28 |     process_configs(target=__eval, arg_parser=arg_parser)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     arg_parser = argparse.ArgumentParser(add_help=False)
33 |     arg_parser.add_argument('train_mode', type=str, default='train', help="Mode: 'train' or 'eval'")
34 |     args, _ = arg_parser.parse_known_args()
35 | 
36 |     if args.train_mode == 'train':
37 |         _train()
38 |     elif args.train_mode == 'eval':
39 |         _eval()
40 |     else:
41 |         raise Exception("Mode not in ['train', 'eval'], e.g. 'python synfue.py train ...'")
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Syntax Fusion Encoder for the PAOTE task (SynFue)
2 | This repository implements the method described in the paper [Learn from Syntax: Improving Pair-wise Aspect and Opinion Terms Extraction with Rich Syntactic Knowledge](https://www.ijcai.org/proceedings/2021/0545.pdf).
3 | 
4 | 
5 | ## Prerequisites
6 | * [Python](https://www.python.org/) (3.8.0)
7 | * [transformers](https://huggingface.co/transformers/model_doc/bert.html) (4.5.1)
8 | * [CoreNLP](https://stanfordnlp.github.io/CoreNLP/) (4.2)
9 | * torch (1.7.1)
10 | * numpy (1.20.2)
11 | * gensim (4.0.1)
12 | * pandas (1.2.4)
13 | * scikit-learn (0.24.1)
14 | 
15 | ## Usage (by examples)
16 | ### Data
17 | The original data comes from [TOWE](https://github.com/NJUNLP/TOWE/tree/master/data).
18 | 
19 | 
20 | ### Preprocessing
21 | We need to obtain the dependency structure and POS tags for each sentence and save them in JSON format.
22 | Pay attention to the file paths and modify them as needed.
23 | 
24 | #### Get Dependency and POS
25 | To parse the dependency structure and POS tags, we employ [CoreNLP](https://stanfordnlp.github.io/CoreNLP/), provided by stanfordnlp.
26 | Please download the relevant files first and put them in `data/datasets/original`.
27 | We use the NLTK package to obtain the dependency and POS parses, so adapt the code at line 24 of `preprocess.py`:
28 | ```
29 | depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000')
30 | ```
31 | Set the URL according to your actual setup; a usage sketch is given after the Save step below.
32 | 
33 | #### Save
34 | ```
35 | python preprocess.py
36 | ```
37 | The processed data will be stored in the directory `data/datasets/towe/`.
38 | We also provide some preprocessed examples.
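
For reference, parsing through the NLTK CoreNLP client looks roughly like the minimal sketch below. It assumes a CoreNLP server is already listening on port 9000; the example sentence is arbitrary, and the exact JSON layout written by `preprocess.py` may differ from what is shown here.
```
from nltk.parse.corenlp import CoreNLPDependencyParser

# assumes: java ... edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000
depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000')
parse, = depparser.parse('The sushi was fresh and cheap .'.split())

# dependency arcs: ((head word, head POS), relation, (dependent word, dependent POS))
triples = list(parse.triples())

# per-token POS tags recovered from the same parse (node 0 is the artificial root)
pos_tags = [(node['word'], node['tag'])
            for _, node in sorted(parse.nodes.items()) if node['word']]
```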
39 | 
40 | ### Train
41 | We use [bert-base-cased](https://huggingface.co/bert-base-cased) (768d) as the pretrained BERT embedding.
42 | 
43 | ```
44 | python synfue.py train --config configs/16res_train.conf
45 | ```
46 | ### Test
47 | ```
48 | python synfue.py eval --config configs/16res_eval.conf
49 | ```
50 | 
51 | ## Note
52 | This code is based on [SpERT](https://github.com/lavis-nlp/spert).
--------------------------------------------------------------------------------
/SynFue/loss.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | import torch
3 | 
4 | 
5 | class Loss(ABC):
6 |     def compute(self, *args, **kwargs):
7 |         pass
8 | 
9 | 
10 | class SynFueLoss(Loss):
11 |     def __init__(self, rel_criterion, term_criterion, model, optimizer, scheduler, max_grad_norm):
12 |         self._rel_criterion = rel_criterion
13 |         self._term_criterion = term_criterion
14 |         self._model = model
15 |         self._optimizer = optimizer
16 |         self._scheduler = scheduler
17 |         self._max_grad_norm = max_grad_norm
18 | 
19 |     def compute(self, term_logits, rel_logits, term_types, rel_types, term_sample_masks, rel_sample_masks):
20 |         # term loss
21 |         term_logits = term_logits.view(-1, term_logits.shape[-1])
22 |         term_types = term_types.view(-1)
23 |         term_sample_masks = term_sample_masks.view(-1).float()
24 | 
25 |         term_loss = self._term_criterion(term_logits, term_types)
26 |         term_loss = (term_loss * term_sample_masks).sum() / term_sample_masks.sum()
27 | 
28 |         # relation loss
29 |         rel_sample_masks = rel_sample_masks.view(-1).float()
30 |         rel_count = rel_sample_masks.sum()
31 | 
32 |         if rel_count.item() != 0:
33 |             rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
34 |             rel_types = rel_types.view(-1, rel_types.shape[-1])
35 | 
36 |             rel_loss = self._rel_criterion(rel_logits, rel_types)
37 |             rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
38 |             rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count
39 | 
40 |             # joint loss
41 |             train_loss = term_loss + rel_loss
42 |         else:
43 |             # corner case: no positive/negative relation samples
44 |             train_loss = term_loss
45 | 
46 |         train_loss.backward()
47 |         torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._max_grad_norm)
48 |         self._optimizer.step()
49 |         self._scheduler.step()
50 |         self._model.zero_grad()
51 |         return train_loss.item()
52 | 
--------------------------------------------------------------------------------
/config_reader.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import multiprocessing as mp
3 | 
4 | 
5 | def process_configs(target, arg_parser):
6 |     args, _ = arg_parser.parse_known_args()
7 |     ctx = mp.get_context('spawn')
8 | 
9 |     for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args):
10 |         p = ctx.Process(target=target, args=(run_args,))
11 |         p.start()
12 |         p.join()
13 | 
14 | 
15 | def _read_config(path):
16 |     lines = open(path).readlines()
17 | 
18 |     runs = []
19 |     run = [1, dict()]
20 |     for line in lines:
21 |         stripped_line = line.strip()
22 | 
23 |         # continue in case of comment
24 |         if stripped_line.startswith('#'):
25 |             continue
26 | 
27 |         if not stripped_line:
28 |             if run[1]:
29 |                 runs.append(run)
30 | 
31 |             run = [1, dict()]
32 |             continue
33 | 
34 |         if stripped_line.startswith('[') and stripped_line.endswith(']'):
35 |             repeat = int(stripped_line[1:-1])
36 |             run[0] = repeat
37 |         else:
38 |             key, value = stripped_line.split('=', 1)  # split on the first '=' only, so values may contain '='
39 |             key, value = (key.strip(), value.strip())
40 |             run[1][key] = value
41 | 
42 |     if run[1]:
43 |         runs.append(run)
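    # a config file need not end with a blank separator line, so the last run is flushed here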
44 | 45 | return runs 46 | 47 | 48 | def _convert_config(config): 49 | config_list = [] 50 | for k, v in config.items(): 51 | if v.lower() == 'true': 52 | config_list.append('--' + k) 53 | elif v.lower() != 'false': 54 | config_list.extend(['--' + k] + v.split(' ')) 55 | 56 | return config_list 57 | 58 | 59 | def _yield_configs(arg_parser, args, verbose=True): 60 | _print = (lambda x: print(x)) if verbose else lambda x: x 61 | 62 | if args.config: 63 | config = _read_config(args.config) 64 | 65 | for run_repeat, run_config in config: 66 | print("-" * 50) 67 | print("Config:") 68 | print(run_config) 69 | 70 | args_copy = copy.deepcopy(args) 71 | config_list = _convert_config(run_config) 72 | run_args = arg_parser.parse_args(config_list, namespace=args_copy) 73 | run_args_dict = vars(run_args) 74 | 75 | # set boolean values 76 | for k, v in run_config.items(): 77 | if v.lower() == 'false': 78 | run_args_dict[k] = False 79 | 80 | print("Repeat %s times" % run_repeat) 81 | print("-" * 50) 82 | 83 | for iteration in range(run_repeat): 84 | _print("Iteration %s" % iteration) 85 | _print("-" * 50) 86 | 87 | yield run_args, run_config, run_repeat 88 | 89 | else: 90 | yield args, None, None 91 | -------------------------------------------------------------------------------- /SynFue/cross_attn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class MultiHeadAttentionLayer(nn.Module): 11 | def __init__(self, hid_dim, n_heads, dropout, device): 12 | super().__init__() 13 | 14 | assert hid_dim % n_heads == 0 15 | 16 | self.hid_dim = hid_dim 17 | self.n_heads = n_heads 18 | self.head_dim = hid_dim // n_heads 19 | 20 | self.fc_q = nn.Linear(hid_dim, hid_dim) 21 | self.fc_k = nn.Linear(hid_dim, hid_dim) 22 | self.fc_v = nn.Linear(hid_dim, hid_dim) 23 | 24 | self.fc_o = nn.Linear(hid_dim, hid_dim) 25 | 26 | self.dropout = nn.Dropout(dropout) 27 | 28 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 29 | 30 | def forward(self, query, key, value, mask=None): 31 | batch_size = query.shape[0] 32 | 33 | # query = [batch size, query len, hid dim] 34 | # key = [batch size, key len, hid dim] 35 | # value = [batch size, value len, hid dim] 36 | 37 | Q = self.fc_q(query) 38 | K = self.fc_k(key) 39 | V = self.fc_v(value) 40 | 41 | # Q = [batch size, query len, hid dim] 42 | # K = [batch size, key len, hid dim] 43 | # V = [batch size, value len, hid dim] 44 | 45 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 46 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 47 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 48 | 49 | # Q = [batch size, n heads, query len, head dim] 50 | # K = [batch size, n heads, key len, head dim] 51 | # V = [batch size, n heads, value len, head dim] 52 | 53 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 54 | 55 | # energy = [batch size, n heads, query len, key len] 56 | 57 | if mask is not None: 58 | energy = energy.masked_fill(mask == 0, -1e10) 59 | 60 | attention = torch.softmax(energy, dim=-1) 61 | 62 | # attention = [batch size, n heads, query len, key len] 63 | x = torch.matmul(self.dropout(attention), V) 64 | 65 | # x = [batch size, n heads, query len, head dim] 66 | 67 | x = x.permute(0, 2, 1, 3).contiguous() 68 | 69 | # x = [batch size, query 
len, n heads, head dim] 70 | 71 | x = x.view(batch_size, -1, self.hid_dim) 72 | 73 | # x = [batch size, query len, hid dim] 74 | 75 | x = self.fc_o(x) 76 | 77 | # x = [batch size, query len, hid dim] 78 | 79 | return x, attention 80 | 81 | 82 | class PositionwiseFeedforwardLayer(nn.Module): 83 | def __init__(self, hid_dim, pf_dim, dropout): 84 | super().__init__() 85 | 86 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 87 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 88 | 89 | self.dropout = nn.Dropout(dropout) 90 | 91 | def forward(self, x): 92 | # x = [batch size, seq len, hid dim] 93 | 94 | x = self.dropout(torch.relu(self.fc_1(x))) 95 | 96 | # x = [batch size, seq len, pf dim] 97 | 98 | x = self.fc_2(x) 99 | 100 | # x = [batch size, seq len, hid dim] 101 | 102 | return x 103 | 104 | 105 | class CA_module(nn.Module): 106 | def __init__(self, hid_dim, pf_dim, n_heads, dropout): 107 | super(CA_module, self).__init__() 108 | self.n_hidden = hid_dim 109 | self.multi_head_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 110 | self.FFN = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) 111 | self.layer_norm = nn.LayerNorm(hid_dim) 112 | 113 | def forward(self, inputs): 114 | 115 | z_hat = self.CR_2Datt(inputs) 116 | z = self.layer_norm(z_hat + inputs) 117 | outputs_hat = self.FFN(z) 118 | outputs = self.layer_norm(outputs_hat + z) 119 | return outputs 120 | 121 | def row_2Datt(self, inputs): 122 | max_len = inputs.size(1) 123 | z_row = inputs.reshape(-1, max_len, self.n_hidden) 124 | z_row, _ = self.multi_head_attention(z_row, z_row, z_row) 125 | return z_row.reshape(-1, max_len, max_len, self.n_hidden) 126 | 127 | def col_2Datt(self, inputs): 128 | z_col = inputs.permute(0, 2, 1, 3) 129 | z_col = self.row_2Datt(z_col) 130 | z_col = z_col.permute(0, 2, 1, 3) 131 | return z_col 132 | 133 | def CR_2Datt(self, inputs): 134 | z_row = self.row_2Datt(inputs) 135 | z_col = self.col_2Datt(inputs) 136 | outputs = (z_row + z_col) / 2. 
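        # each cell of the 2D span table now mixes its row-wise attention view
        # (cells sharing the first index) with its column-wise view (cells
        # sharing the second index); averaging keeps the residual scale stable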
137 |         return outputs
--------------------------------------------------------------------------------
/data/datasets/io_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import os
4 | import numpy as np
5 | 
6 | import pickle
7 | import json
8 | import yaml
9 | import sys
10 | import logging
11 | 
12 | 
13 | def read_json_lines(path):
14 |     res = []
15 |     for line in open(path, 'r'):
16 |         res.append(json.loads(line))
17 |     return res
18 | 
19 | 
20 | def read_yaml(path, encoding='utf-8'):
21 |     with open(path, 'r', encoding=encoding) as file:
22 |         return yaml.safe_load(file)  # yaml.load without a Loader fails on PyYAML >= 6
23 | 
24 | 
25 | def read_lines(path, encoding='utf-8', return_list=False):
26 |     # a 'yield' in the body would turn the whole function into a generator,
27 |     if return_list:
28 |         with open(path, 'r', encoding=encoding) as file:
29 |             return file.readlines()
30 |     return (line.strip() for line in open(path, 'r', encoding=encoding))
31 | 
32 | 
33 | def read_multi_line_sent(path, skip_line_start='-DOCSTART-'):
34 |     sent_lines = []
35 |     for line in read_lines(path):
36 |         line = line.strip()
37 |         if len(line) == 0 or line.startswith(skip_line_start):
38 |             if sent_lines:
39 |                 yield sent_lines
40 |             sent_lines = []
41 |             continue
42 |         sent_lines.append(line)
43 |     if sent_lines:
44 |         yield sent_lines
45 | 
46 | 
47 | def read_pickle(path):
48 |     with open(path, 'rb') as file:
49 |         return pickle.load(file)
50 | 
51 | 
52 | def save_pickle(path, obj):
53 |     with open(path, 'wb') as file:
54 |         pickle.dump(obj, file)
55 | 
56 | 
57 | def write_lines(txt_path, lines):
58 |     with open(txt_path, 'w', encoding='utf-8') as file:
59 |         for line in lines:
60 |             file.write(line + '\n')
61 | 
62 | 
63 | def write_texts(txt_path, lines):
64 |     with open(txt_path, 'w', encoding='utf-8') as file:
65 |         for line in lines:
66 |             file.write(line)
67 | 
68 | 
69 | def load_embed_txt(embed_file):
70 |     """Load embed_file into a python dictionary.
71 |     Note: the embed_file should be a GloVe formatted txt file. Assuming
72 |     embed_size=5, for example:
73 |     the -0.071549 0.093459 0.023738 -0.090339 0.056123
74 |     to 0.57346 0.5417 -0.23477 -0.3624 0.4037
75 |     and 0.20327 0.47348 0.050877 0.002103 0.060547
76 |     Args:
77 |         embed_file: file path to the embedding file.
78 |     Returns:
79 |         a dictionary that maps word to vector, and the size of embedding dimensions.
80 |     """
81 |     emb_dict = dict()
82 |     emb_size = None
83 |     for line in read_lines(embed_file):
84 |         tokens = line.strip().split(" ")
85 |         word = tokens[0]
86 |         vec = list(map(float, tokens[1:]))
87 |         emb_dict[word] = vec
88 |         if emb_size:
89 |             assert emb_size == len(vec), "All embeddings must have the same size."
90 |         else:
91 |             emb_size = len(vec)
92 |     return emb_dict, emb_size
93 | 
94 | 
95 | def txt_to_npy(dirname, fname, output_name):
96 |     emb_dict, emb_size = load_embed_txt(os.path.join(dirname, fname))
97 |     words = []
98 |     vec = np.empty((len(emb_dict), emb_size))
99 |     i = 0
100 |     for word, vec_list in emb_dict.items():
101 |         words.append(word)
102 |         vec[i] = vec_list
103 |         i += 1
104 |     with open(os.path.join(dirname, output_name + '.vocab.txt'), 'w') as file:
105 |         for word in words:
106 |             file.write(word + '\n')
107 |     np.save(os.path.join(dirname, output_name + '.npy'), vec)
108 | 
109 | 
110 | def get_logger(name, log_dir=None, log_name=None, file_model='a',
111 |                level=logging.INFO, handler=sys.stdout,
112 |                formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
113 |     logger = logging.getLogger(name)
114 |     logger.setLevel(level)  # respect the requested level instead of hard-coding INFO
115 |     formatter = logging.Formatter(formatter)
116 |     stream_handler = logging.StreamHandler(handler)
117 |     stream_handler.setLevel(level)
118 |     stream_handler.setFormatter(formatter)
119 |     logger.addHandler(stream_handler)
120 |     if log_dir and log_name:
121 |         filename = os.path.join(log_dir, log_name)
122 |         file_handler = logging.FileHandler(filename, encoding='utf-8', mode=file_model)
123 |         file_handler.setLevel(level)
124 |         file_handler.setFormatter(formatter)
125 |         logger.addHandler(file_handler)
126 | 
127 |     return logger
128 | 
129 | 
130 | def trigger_tag_to_list(trigger_tags, O_idx):
131 |     trigger_list = []
132 |     for i, trigger_type in enumerate(trigger_tags):
133 |         if trigger_type != O_idx:
134 |             trigger_list.append([i, trigger_type])
135 |     return trigger_list
136 | 
137 | 
138 | def relative_position(ent_start, ent_end, tok_idx, max_position_len=150):
139 |     if ent_start <= tok_idx <= ent_end:
140 |         return 0
141 |     elif tok_idx < ent_start:
142 |         return ent_start - tok_idx
143 |     elif tok_idx > ent_end:
144 |         return tok_idx - ent_end + max_position_len
145 | 
146 |     return None
147 | 
148 | 
149 | def to_set(*list_vars):
150 |     sets = []
151 |     for list_var in list_vars:
152 |         sets.append({tuple(sub_list) for sub_list in list_var})
153 |     return sets
--------------------------------------------------------------------------------
/SynFue/Encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | 
8 | 
9 | class LabelAwareGCN(nn.Module):
10 |     """
11 |     Label-aware GCN layer
12 |     """
13 |     def __init__(self, dep_dim, in_features, out_features, pos_dim=None, bias=True):
14 |         super(LabelAwareGCN, self).__init__()
15 |         self.dep_dim = dep_dim
16 |         self.pos_dim = pos_dim
17 |         self.in_features = in_features
18 |         self.out_features = out_features
19 | 
20 |         self.dep_attn = nn.Linear(dep_dim + pos_dim + in_features, out_features)
21 |         self.dep_fc = nn.Linear(dep_dim, out_features)
22 |         self.pos_fc = nn.Linear(pos_dim, out_features)
23 | 
24 |     def forward(self, text, adj, dep_embed, pos_embed=None):
25 |         """
26 | 
27 |         :param text: [batch size, seq_len, feat_dim]
28 |         :param adj: [batch size, seq_len, seq_len]
29 |         :param dep_embed: [batch size, seq_len, seq_len, dep_type_dim]
30 |         :param pos_embed: [batch size, seq_len, pos_dim]
31 |         :return: tuple (r, output_sum, attn); output_sum is [batch size, seq_len, feat_dim]
32 |         """
33 |         batch_size, seq_len, feat_dim = text.shape
34 | 
35 |         val_us = text.unsqueeze(dim=2)
36 |         val_us = val_us.repeat(1, 1, seq_len, 1)
37 |         pos_us = pos_embed.unsqueeze(dim=2).repeat(1, 1, seq_len, 1)
38 |         # [batch size,
seq_len, seq_len, feat_dim+pos_dim+dep_dim]
39 |         val_sum = torch.cat([val_us, pos_us, dep_embed], dim=-1)
40 | 
41 |         r = self.dep_attn(val_sum)
42 | 
43 |         p = torch.sum(r, dim=-1)
44 |         mask = (adj == 0).float() * (-1e30)
45 |         p = p + mask
46 |         p = torch.softmax(p, dim=2)
47 |         p_us = p.unsqueeze(3).repeat(1, 1, 1, feat_dim)
48 | 
49 |         output = val_us + self.pos_fc(pos_us) + self.dep_fc(dep_embed)
50 |         output = torch.mul(p_us, output)
51 | 
52 |         output_sum = torch.sum(output, dim=2)
53 | 
54 |         return r, output_sum, p
55 | 
56 | 
57 | class nLaGCN(nn.Module):
58 |     def __init__(self, opt):
59 |         super(nLaGCN, self).__init__()
60 |         self.opt = opt
61 |         self.model = nn.ModuleList([LabelAwareGCN(opt.dep_dim, opt.bert_dim,
62 |                                                   opt.bert_dim, opt.pos_dim)
63 |                                     for i in range(self.opt.num_layer)])
64 |         self.dep_embedding = nn.Embedding(opt.dep_num, opt.dep_dim, padding_idx=0)
65 | 
66 |     def forward(self, x, simple_graph, graph, pos_embed=None, output_attention=False):
67 | 
68 |         dep_embed = self.dep_embedding(graph)
69 | 
70 |         attn_list = []
71 |         for lagcn in self.model:
72 |             r, x, attn = lagcn(x, simple_graph, dep_embed, pos_embed=pos_embed)
73 |             attn_list.append(attn)
74 | 
75 |         if output_attention is True:
76 |             return x, r, attn_list
77 |         else:
78 |             return x, r
79 | 
80 | 
81 | class SynFueEncoder(nn.Module):
82 |     def __init__(self, opt):
83 |         super(SynFueEncoder, self).__init__()
84 |         self.opt = opt
85 |         self.lagcn = nLaGCN(opt)
86 | 
87 |         self.fc = nn.Linear(opt.bert_dim*2 + opt.pos_dim, opt.bert_dim*2)
88 |         self.output_dropout = nn.Dropout(opt.output_dropout)
89 | 
90 |         self.pod_embedding = nn.Embedding(opt.pos_num, opt.pos_dim, padding_idx=0)
91 | 
92 |     def forward(self, word_reps, simple_graph, graph, pos=None, output_attention=False):
93 |         """
94 | 
95 |         :param word_reps: [B, L, H]
96 |         :param simple_graph: [B, L, L]
97 |         :param graph: [B, L, L]
98 |         :param pos: [B, L]
99 |         :param output_attention: bool
100 |         :return:
101 |             output: [B, L, 2H] (syntax-fused token representations)
102 |             dep_reps: [B, L, L, H] (label-aware features from the last LaGCN layer)
103 | 
104 |         """
105 | 
106 |         pos_embed = self.pod_embedding(pos)
107 | 
108 |         lagcn_output = self.lagcn(word_reps, simple_graph, graph, pos_embed, output_attention)
109 | 
110 |         pos_output = self.local_attn(word_reps, pos_embed, self.opt.num_layer, self.opt.w_size)
111 | 
112 |         output = torch.cat((lagcn_output[0], pos_output, word_reps), dim=-1)
113 |         output = self.fc(output)
114 |         output = self.output_dropout(output)
115 |         return output, lagcn_output[1]
116 | 
117 |     def local_attn(self, x, pos_embed, num_layer, w_size):
118 |         """
119 | 
120 |         :param x: [batch size, seq_len, feat_dim]
121 |         :param pos_embed: [batch size, seq_len, pos_dim]
122 |         :return: [batch size, seq_len, pos_dim]
123 |         """
124 |         batch_size, seq_len, feat_dim = x.shape
125 |         pos_dim = pos_embed.size(-1)
126 |         output = pos_embed
127 |         for i in range(num_layer):
128 |             val_sum = torch.cat([x, output], dim=-1)  # [batch size, seq_len, feat_dim+pos_dim]
129 |             attn = torch.matmul(val_sum, val_sum.transpose(1, 2))  # [batch size, seq_len, seq_len]
130 |             # pad the key axis by w_size on each side
131 |             pad_size = seq_len + w_size * 2
132 |             mask = torch.zeros((batch_size, seq_len, pad_size), dtype=torch.float).to(
133 |                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
134 |             for j in range(seq_len):
135 |                 mask[:, j, j:j + w_size] = 1.0
136 |             pad_attn = torch.full((batch_size, seq_len, w_size), -1e18).to(
137 |                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
138 |             attn = torch.cat([pad_attn, attn, pad_attn], dim=-1)
139 |             local_attn = torch.softmax(torch.mul(attn, mask), dim=-1)
140 |             local_attn = local_attn[:, :,
w_size:pad_size - w_size] # [batch size, seq_len, seq_len] 141 | local_attn = local_attn.unsqueeze(dim=3).repeat(1, 1, 1, pos_dim) 142 | output = output.unsqueeze(dim=2).repeat(1, 1, seq_len, 1) 143 | output = torch.sum(torch.mul(output, local_attn), dim=2) # [batch size, seq_len, pos_dim] 144 | return output 145 | -------------------------------------------------------------------------------- /SynFue/base_trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | import sys 6 | from typing import List, Dict, Tuple 7 | 8 | import torch 9 | from torch.nn import DataParallel 10 | from torch.optim import optimizer 11 | from transformers import PreTrainedModel 12 | from transformers import PreTrainedTokenizer 13 | 14 | from SynFue import util 15 | from SynFue.opt import tensorboardx 16 | 17 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | 20 | class BaseTrainer: 21 | """ Trainer base class with common methods """ 22 | 23 | def __init__(self, args: argparse.Namespace): 24 | self.args = args 25 | self._debug = self.args.debug 26 | 27 | # logging 28 | name = str(datetime.datetime.now()).replace(' ', '_') 29 | self._log_path = os.path.join(self.args.log_path, self.args.label, name) 30 | util.create_directories_dir(self._log_path) 31 | 32 | if hasattr(args, 'save_path'): 33 | self._save_path = os.path.join(self.args.save_path, self.args.label) 34 | util.create_directories_dir(self._save_path) 35 | 36 | self._log_paths = dict() 37 | 38 | # file + console logging 39 | log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") 40 | self._logger = logging.getLogger() 41 | util.reset_logger(self._logger) 42 | 43 | file_handler = logging.FileHandler(os.path.join(self._log_path, 'all.log')) 44 | file_handler.setFormatter(log_formatter) 45 | self._logger.addHandler(file_handler) 46 | 47 | console_handler = logging.StreamHandler(sys.stdout) 48 | console_handler.setFormatter(log_formatter) 49 | self._logger.addHandler(console_handler) 50 | 51 | if self._debug: 52 | self._logger.setLevel(logging.DEBUG) 53 | else: 54 | self._logger.setLevel(logging.INFO) 55 | 56 | # tensorboard summary 57 | self._summary_writer = tensorboardx.SummaryWriter(self._log_path) if tensorboardx is not None else None 58 | 59 | self._best_results = dict() 60 | self._log_arguments() 61 | 62 | # CUDA devices 63 | self._device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu") 64 | self._gpu_count = torch.cuda.device_count() 65 | 66 | # set seed 67 | if args.seed is not None: 68 | util.set_seed(args.seed) 69 | 70 | def _add_dataset_logging(self, *labels, data: Dict[str, List[str]]): 71 | for label in labels: 72 | dic = dict() 73 | 74 | for key, columns in data.items(): 75 | path = os.path.join(self._log_path, '%s_%s.csv' % (key, label)) 76 | util.create_csv(path, *columns) 77 | dic[key] = path 78 | 79 | self._log_paths[label] = dic 80 | self._best_results[label] = 0 81 | 82 | def _log_arguments(self): 83 | util.save_dict(self._log_path, self.args, 'args') 84 | if self._summary_writer is not None: 85 | util.summarize_dict(self._summary_writer, self.args, 'args') 86 | 87 | def _log_tensorboard(self, dataset_label: str, data_label: str, data: object, iteration: int): 88 | if self._summary_writer is not None: 89 | self._summary_writer.add_scalar('data/%s/%s' % (dataset_label, data_label), data, iteration) 90 | 91 | def _log_csv(self, 
dataset_label: str, data_label: str, *data: Tuple[object]): 92 | logs = self._log_paths[dataset_label] 93 | util.append_csv(logs[data_label], *data) 94 | 95 | def _save_best(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, optimizer: optimizer, 96 | accuracy: float, iteration: int, label: str, extra=None): 97 | if accuracy > self._best_results[label]: 98 | self._logger.info("[%s] Best model in iteration %s: %s%% accuracy" % (label, iteration, accuracy)) 99 | self._save_model(self._save_path, model, tokenizer, iteration, 100 | optimizer=optimizer if self.args.save_optimizer else None, 101 | save_as_best=True, name='model_%s' % label, extra=extra) 102 | self._best_results[label] = accuracy 103 | 104 | def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, 105 | iteration: int, optimizer: optimizer = None, save_as_best: bool = False, 106 | extra: dict = None, include_iteration: int = True, name: str = 'model'): 107 | extra_state = dict(iteration=iteration) 108 | 109 | if optimizer: 110 | extra_state['optimizer'] = optimizer.state_dict() 111 | 112 | if extra: 113 | extra_state.update(extra) 114 | 115 | # if save_as_best: 116 | # dir_path = os.path.join(save_path, '%s_best' % name) 117 | # else: 118 | # dir_name = '%s_%s' % (name, iteration) if include_iteration else name 119 | # dir_path = os.path.join(save_path, dir_name) 120 | 121 | util.create_directories_dir(save_path) 122 | 123 | # save model 124 | if isinstance(model, DataParallel): 125 | model.module.save_pretrained(save_path) 126 | else: 127 | model.save_pretrained(save_path) 128 | 129 | # save vocabulary 130 | tokenizer.save_pretrained(save_path) 131 | 132 | # save extra 133 | state_path = os.path.join(save_path, 'extra.state') 134 | torch.save(extra_state, state_path) 135 | 136 | # def _load_model(self, save_path: str): 137 | # 138 | # model = torch.load() 139 | 140 | def _get_lr(self, optimizer): 141 | lrs = [] 142 | for group in optimizer.param_groups: 143 | lr_scheduled = group['lr'] 144 | lrs.append(lr_scheduled) 145 | return lrs 146 | 147 | def _close_summary_writer(self): 148 | if self._summary_writer is not None: 149 | self._summary_writer.close() 150 | -------------------------------------------------------------------------------- /data/log/towe_16res/eval_valid.csv: -------------------------------------------------------------------------------- 1 | ner_prec_micro;ner_rec_micro;ner_f1_micro;ner_prec_macro;ner_rec_macro;ner_f1_macro;rel_prec_micro;rel_rec_micro;rel_f1_micro;rel_prec_macro;rel_rec_macro;rel_f1_macro;rel_nec_prec_micro;rel_nec_rec_micro;rel_nec_f1_micro;rel_nec_prec_macro;rel_nec_rec_macro;rel_nec_f1_macro;epoch;iteration;global_iteration 2 | 64.01326699834162;82.30277185501066;72.01492537313432;64.17537970504073;82.15470141071893;72.02142739668096;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;1;0;269 3 | 82.13891951488424;79.42430703624733;80.75880758807588;82.65960576640188;79.52273936956651;80.78248656685518;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;2;0;538 4 | 
82.32804232804233;82.94243070362474;82.63409453000531;82.28590291750504;82.86574741716974;82.56053056751291;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;3;0;807 5 | 81.32530120481928;86.35394456289978;83.76421923474663;81.4069844856457;86.38321876833912;83.76537823178138;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;4;0;1076 6 | 73.52185089974293;91.47121535181236;81.52019002375297;73.49442269770728;91.40557827647544;81.46932857501848;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;5;0;1345 7 | 79.41747572815534;87.20682302771856;83.13008130081302;79.40115154463084;87.1875241678305;83.10578512396695;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;6;0;1614 8 | 82.12512413108243;88.16631130063965;85.03856041131107;82.18998350577297;88.17220688117844;85.03409992069786;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;7;0;1883 9 | 78.9622641509434;89.23240938166312;83.78378378378379;79.18097806121175;89.2444624392108;83.81268288772196;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;8;0;2152 10 | 79.73484848484848;89.76545842217483;84.45336008024073;79.83453870035713;89.76421295896132;84.45565830456094;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;9;0;2421 11 | 82.72357723577237;86.78038379530916;84.70343392299688;82.89277220556302;86.78264192487387;84.71655535053753;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;10;0;2690 12 | 80.3073967339097;89.12579957356077;84.48711470439616;80.37625078462504;89.11867598957315;84.48152596659804;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;11;0;2959 13 | 81.21974830590513;89.4456289978678;85.13444951801115;81.2567949554251;89.43052630142346;85.12008221028701;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;12;0;3228 14 | 
78.87453874538745;91.15138592750533;84.56973293768546;78.98414482535465;91.13194156957833;84.57265291746566;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;13;0;3497 15 | 82.14990138067061;88.80597014925374;85.34836065573771;82.21549895908322;88.7849893320353;85.33761652809271;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;14;0;3766 16 | 80.3639846743295;89.4456289978678;84.6619576185671;80.4291802953712;89.43598538784534;84.65509657133082;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;15;0;4035 17 | 81.16504854368932;89.12579957356077;84.95934959349594;81.16504854368931;89.09138055746371;84.92984282810254;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;16;0;4304 18 | 82.66129032258065;87.42004264392324;84.97409326424871;82.6448027901645;87.38996528930883;84.94307196562835;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;17;0;4573 19 | 82.73453093812375;88.37953091684435;85.4639175257732;82.80607364897179;88.36918891623486;85.45676385468668;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;18;0;4842 20 | 82.7037773359841;88.69936034115139;85.59670781893004;82.73430490171665;88.67012105524141;85.57540858667117;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;19;0;5111 21 | 82.76892430278885;88.59275053304904;85.58187435633367;82.7915236413421;88.56071186486942;85.5577241291527;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;20;0;5380 22 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def _add_common_args(arg_parser): 5 | arg_parser.add_argument('--config', type=str) 6 | 7 | # Input 8 | arg_parser.add_argument('--types_path', type=str, help="Path to type specifications") 9 | 10 | # Preprocessing 11 | arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer") 12 | arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans") 13 | arg_parser.add_argument('--lowercase', action='store_true', default=False, 14 | help="If true, input is lowercased during preprocessing") 15 | arg_parser.add_argument('--sampling_processes', type=int, default=4, 16 | help="Number of sampling 
processes. 0 = no multiprocessing for sampling")
17 |     arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue")
18 | 
19 |     # Logging
20 |     arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
21 |     arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
22 |     arg_parser.add_argument('--store_predictions', action='store_true', default=False,
23 |                             help="If true, store predictions on disc (in log directory)")
24 |     arg_parser.add_argument('--store_examples', action='store_true', default=False,
25 |                             help="If true, store evaluation examples on disc (in log directory)")
26 |     arg_parser.add_argument('--example_count', type=int, default=None,
27 |                             help="Count of evaluation examples to store (if store_examples == True)")
28 |     arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")
29 | 
30 |     # Model / Training / Evaluation
31 |     arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints")
32 |     arg_parser.add_argument('--model_type', type=str, default="SynFue", help="Type of model")
33 |     arg_parser.add_argument('--cpu', action='store_true', default=False,
34 |                             help="If true, train/evaluate on CPU even if a CUDA device is available")
35 |     arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size")
36 |     arg_parser.add_argument('--max_pairs', type=int, default=1000,
37 |                             help="Maximum number of term pairs to process during training/evaluation")
38 |     arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations")
39 |     arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding")
40 |     arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SynFue")
41 |     arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights")
42 |     arg_parser.add_argument('--no_overlapping', action='store_true', default=False,
43 |                             help="If true, do not evaluate on overlapping terms "
44 |                                  "and relations with overlapping terms")
45 | 
46 |     # LaGCN
47 |     arg_parser.add_argument('--bert_dim', type=int, default=768)
48 |     arg_parser.add_argument('--dep_dim', type=int, default=100)
49 |     arg_parser.add_argument('--bert_dropout', type=float, default=0.5)
50 |     arg_parser.add_argument('--output_dropout', type=float, default=0.5)
51 |     arg_parser.add_argument('--dep_num', type=int, default=42)  # 40 + 1(None) + 1(self-loop)
52 |     arg_parser.add_argument('--num_layer', type=int, default=3)
53 | 
54 |     # Misc
55 |     arg_parser.add_argument('--seed', type=int, default=58986, help="Seed")
56 |     arg_parser.add_argument('--cache_path', type=str, default=None,
57 |                             help="Path to cache transformer models (for HuggingFace transformers library)")
58 | 
59 |     arg_parser.add_argument('--beta', type=float, default=0.3, help='weight for triaffine')
60 |     arg_parser.add_argument('--alpha', type=float, default=0.3, help='weight for biaffine')
61 |     arg_parser.add_argument('--sigma', type=float, default=0.3, help='weight for syntactic-aware score')
62 | 
63 |     # pos
64 |     # arg_parser.add_argument('--use_pos', type=bool, default=True)
65 |     arg_parser.add_argument('--pos_num', type=int, default=45)
66 |     arg_parser.add_argument('--pos_dim', type=int, default=100)
67 |     arg_parser.add_argument('--w_size', type=int,
default=1, help='the window size of local attention') 68 | 69 | 70 | def train_argparser(): 71 | arg_parser = argparse.ArgumentParser() 72 | 73 | # Input 74 | arg_parser.add_argument('--train_path', type=str, help="Path to train dataset") 75 | arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset") 76 | 77 | # Logging 78 | arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored") 79 | arg_parser.add_argument('--init_eval', action='store_true', default=False, 80 | help="If true, evaluate validation set before training") 81 | arg_parser.add_argument('--save_optimizer', action='store_true', default=False, 82 | help="Save optimizer alongside model") 83 | arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations") 84 | arg_parser.add_argument('--final_eval', action='store_true', default=False, 85 | help="Evaluate the model only after training, not at every epoch") 86 | 87 | # Model / Training 88 | arg_parser.add_argument('--train_batch_size', type=int, default=2, help="Training batch size") 89 | arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs") 90 | arg_parser.add_argument('--neg_term_count', type=int, default=100, 91 | help="Number of negative term samples per document (sentence)") 92 | arg_parser.add_argument('--neg_relation_count', type=int, default=100, 93 | help="Number of negative relation samples per document (sentence)") 94 | arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate") 95 | arg_parser.add_argument('--lr_warmup', type=float, default=0.1, 96 | help="Proportion of total train iterations to warmup in linear increase/decrease schedule") 97 | arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply") 98 | arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm") 99 | 100 | # arg_parser.add_argument('--beta', type=float, default=0.3, help='weight for triaffine') 101 | 102 | _add_common_args(arg_parser) 103 | 104 | return arg_parser 105 | 106 | 107 | def eval_argparser(): 108 | arg_parser = argparse.ArgumentParser() 109 | 110 | # Input 111 | arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset") 112 | 113 | _add_common_args(arg_parser) 114 | 115 | return arg_parser 116 | -------------------------------------------------------------------------------- /SynFue/util.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import random 5 | import shutil 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from SynFue.terms import TokenSpan 11 | 12 | CSV_DELIMETER = ';' 13 | 14 | 15 | def create_directories_file(f): 16 | d = os.path.dirname(f) 17 | 18 | if d and not os.path.exists(d): 19 | os.makedirs(d) 20 | 21 | return f 22 | 23 | 24 | def create_directories_dir(d): 25 | if d and not os.path.exists(d): 26 | os.makedirs(d) 27 | 28 | return d 29 | 30 | 31 | def create_csv(file_path, *column_names): 32 | if not os.path.exists(file_path): 33 | with open(file_path, 'w', newline='') as csv_file: 34 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL) 35 | 36 | if column_names: 37 | writer.writerow(column_names) 38 | 39 | 40 | def append_csv(file_path, *row): 41 | if not os.path.exists(file_path): 42 | raise Exception("The given file doesn't exist") 43 | 44 | with open(file_path, 'a', 
newline='') as csv_file:
45 |         writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
46 |         writer.writerow(row)
47 | 
48 | 
49 | def append_csv_multiple(file_path, *rows):
50 |     if not os.path.exists(file_path):
51 |         raise Exception("The given file doesn't exist")
52 | 
53 |     with open(file_path, 'a', newline='') as csv_file:
54 |         writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
55 |         for row in rows:
56 |             writer.writerow(row)
57 | 
58 | 
59 | def read_csv(file_path):
60 |     lines = []
61 |     with open(file_path, 'r') as csv_file:
62 |         reader = csv.reader(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
63 |         for row in reader:
64 |             lines.append(row)
65 | 
66 |     return lines[0], lines[1:]
67 | 
68 | 
69 | def copy_python_directory(source, dest, ignore_dirs=None):
70 |     source = source if source.endswith('/') else source + '/'
71 |     for (dir_path, dir_names, file_names) in os.walk(source):
72 |         tail = '/'.join(dir_path.split(source)[1:])
73 |         new_dir = os.path.join(dest, tail)
74 | 
75 |         if ignore_dirs and True in [(ignore_dir in tail) for ignore_dir in ignore_dirs]:
76 |             continue
77 | 
78 |         create_directories_dir(new_dir)
79 | 
80 |         for file_name in file_names:
81 |             if file_name.endswith('.py'):
82 |                 file_path = os.path.join(dir_path, file_name)
83 |                 shutil.copy2(file_path, new_dir)
84 | 
85 | 
86 | def save_dict(log_path, dic, name):
87 |     # save arguments
88 |     # 1. as json
89 |     path = os.path.join(log_path, '%s.json' % name)
90 |     f = open(path, 'w')
91 |     json.dump(vars(dic), f)
92 |     f.close()
93 | 
94 |     # 2. as string
95 |     path = os.path.join(log_path, '%s.txt' % name)
96 |     f = open(path, 'w')
97 |     args_str = ["%s = %s" % (key, value) for key, value in vars(dic).items()]
98 |     f.write('\n'.join(args_str))
99 |     f.close()
100 | 
101 | 
102 | def summarize_dict(summary_writer, dic, name):
103 |     table = 'Argument|Value\n-|-'
104 | 
105 |     for k, v in vars(dic).items():
106 |         row = '\n%s|%s' % (k, v)
107 |         table += row
108 |     summary_writer.add_text(name, table)
109 | 
110 | 
111 | def set_seed(seed):
112 |     random.seed(seed)
113 |     np.random.seed(seed)
114 |     torch.manual_seed(seed)
115 |     torch.cuda.manual_seed_all(seed)
116 | 
117 | 
118 | def reset_logger(logger):
119 |     for handler in logger.handlers[:]:
120 |         logger.removeHandler(handler)
121 | 
122 |     for f in logger.filters[:]:
123 |         logger.removeFilter(f)
124 | 
125 | 
126 | def flatten(l):
127 |     return [i for p in l for i in p]
128 | 
129 | 
130 | def get_as_list(dic, key):
131 |     if key in dic:
132 |         return [dic[key]]
133 |     else:
134 |         return []
135 | 
136 | 
137 | def extend_tensor(tensor, extended_shape, fill=0):
138 |     tensor_shape = tensor.shape
139 | 
140 |     extended_tensor = torch.zeros(extended_shape, dtype=tensor.dtype).to(tensor.device)
141 |     extended_tensor = extended_tensor.fill_(fill)
142 | 
143 |     if len(tensor_shape) == 1:
144 |         extended_tensor[:tensor_shape[0]] = tensor
145 |     elif len(tensor_shape) == 2:
146 |         extended_tensor[:tensor_shape[0], :tensor_shape[1]] = tensor
147 |     elif len(tensor_shape) == 3:
148 |         extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2]] = tensor
149 |     elif len(tensor_shape) == 4:
150 |         extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2], :tensor_shape[3]] = tensor
151 | 
152 |     return extended_tensor
153 | 
154 | 
155 | def padded_stack(tensors, padding=0):
156 |     dim_count = len(tensors[0].shape)
157 | 
158 |     max_shape = [max([t.shape[d] for t in tensors]) for d in range(dim_count)]
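    # pad each tensor up to the element-wise maximum shape, then stack them
    # along a new leading (batch) dimension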
padded_tensors = [] 160 | 161 | for t in tensors: 162 | e = extend_tensor(t, max_shape, fill=padding) 163 | padded_tensors.append(e) 164 | 165 | stacked = torch.stack(padded_tensors) 166 | return stacked 167 | 168 | 169 | def batch_index(tensor, index, pad=False): 170 | if tensor.shape[0] != index.shape[0]: 171 | raise Exception() 172 | 173 | if not pad: 174 | return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])]) 175 | else: 176 | return padded_stack([tensor[i][index[i]] for i in range(index.shape[0])]) 177 | 178 | 179 | def padded_nonzero(tensor, padding=0): 180 | indices = padded_stack([tensor[i].nonzero().view(-1) for i in range(tensor.shape[0])], padding) 181 | return indices 182 | 183 | 184 | def swap(v1, v2): 185 | return v2, v1 186 | 187 | 188 | def get_span_tokens(tokens, span): 189 | inside = False 190 | span_tokens = [] 191 | 192 | for t in tokens: 193 | if t.span[0] == span[0]: 194 | inside = True 195 | 196 | if inside: 197 | span_tokens.append(t) 198 | 199 | if inside and t.span[1] == span[1]: 200 | return TokenSpan(span_tokens) 201 | 202 | return None 203 | 204 | 205 | def to_device(batch, device): 206 | converted_batch = dict() 207 | for key in batch.keys(): 208 | converted_batch[key] = batch[key].to(device) 209 | 210 | return converted_batch 211 | 212 | 213 | def check_version(config, model_class, model_path): 214 | if os.path.exists(model_path): 215 | model_path = model_path if model_path.endswith('.bin') else os.path.join(model_path, 'pytorch_model.bin') 216 | state_dict = torch.load(model_path, map_location=torch.device('cpu')) 217 | config_dict = config.to_dict() 218 | 219 | # version check 220 | loaded_version = config_dict.get('spert_version', '1.2') 221 | if 'rel_classifier.weight' in state_dict and loaded_version != model_class.VERSION: 222 | msg = ("Current SpERT version (%s) does not match the version of the loaded model (%s). " 223 | % (model_class.VERSION, loaded_version)) 224 | msg += "Use the code matching your version or train a new model." 225 | raise Exception(msg) 226 | 227 | 228 | -------------------------------------------------------------------------------- /SynFue/templates/term_examples.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Entity Extraction Examples 6 | 7 | 9 | 10 | 11 | 12 | 104 | 105 | 106 | 107 |

Entity Extraction Examples ({{ examples|length }})

108 | 109 |
110 |
111 | check_circle_outline   F1 = 100.000
112 | check_circle_outline   F1 >= 50.00
113 | highlight_off   F1 < 50.00
114 |
115 | 116 |
117 |
118 |   True Positive  
119 |   False Positive  
120 |   False Negative
121 |
122 |
123 |

124 | 125 |
126 |
127 | {% for example in examples %} 128 | {% set outer_loop = loop %} 129 | 130 |
131 | 166 | 167 |
168 |
169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | {% for tp in example["tp"] %} 181 | 182 | 185 | 186 | 187 | {% endfor %} 188 | 189 | {% for fp in example["fp"] %} 190 | 191 | 194 | 195 | 196 | {% endfor %} 197 | 198 | {% for fn in example["fn"] %} 199 | 200 | 201 | 202 | 203 | {% endfor %} 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 |
ScoreText
183 | {{ "%.4f"|format(tp[2]) }} 184 | {{ tp[0] | safe }}
192 | {{ "%.4f"|format(fp[2]) }} 193 | {{ fp[0] | safe }}
{{ fn[0] | safe }}
212 |
213 |
214 |
215 | {% endfor %} 216 |
217 |
218 | 219 | 222 | 225 | 228 | 229 | -------------------------------------------------------------------------------- /data/datasets/vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import Counter 3 | 4 | 5 | class Vocab(object): 6 | PAD = '*PAD*' 7 | UNK = '*UNK*' 8 | NULL = '*NULL*' 9 | START = '*START*' 10 | END = '*END*' 11 | ROOT = '*ROOT*' 12 | SELF_LOOP = 'SELF_LOOP' 13 | 14 | def __init__(self): 15 | 16 | self.tok2idx = {} 17 | self.idx2count = {} 18 | self.idx2tok = {} 19 | self.special_token_size = 0 20 | self.singleton_size = 0 21 | self.singleton_max_count = 0 22 | self.min_count = 0 23 | self.max_count = 0 24 | 25 | def __len__(self): 26 | return len(self.tok2idx) 27 | 28 | def items(self): 29 | for k, v in self.tok2idx.items(): 30 | yield k, v 31 | 32 | def __getitem__(self, item): 33 | return self.tok2idx[item] 34 | 35 | def keys(self): 36 | return self.tok2idx.keys() 37 | 38 | def vals(self): 39 | return self.tok2idx.values() 40 | 41 | def add_counter(self, 42 | counter, 43 | min_count=1, 44 | max_count=1e7, 45 | singleton_max_count=1, 46 | update_count=False 47 | ): 48 | ''' 49 | 50 | :param counter: 51 | :param min_count: 52 | :param max_count: 53 | :param singleton_max_count: int, we treat a token as a singleton 54 | when 0 < token count <= singleton_max_count, 55 | this is in favor of some UNK replace strategies 56 | :return: 57 | ''' 58 | self.min_count = min_count 59 | self.max_count = max_count 60 | self.singleton_max_count = singleton_max_count 61 | 62 | for tok, count in counter.most_common(n=int(max_count)): 63 | if count >= self.min_count: 64 | self.add_token(tok, count, update_count) 65 | 66 | def add_spec_toks(self, 67 | pad_tok=True, 68 | unk_tok=True, 69 | start_tok=False, 70 | end_tok=False, 71 | root_tok=False, 72 | null_tok=False, 73 | self_loop_tok=False): 74 | if pad_tok: 75 | self.add_token(Vocab.PAD) 76 | self.special_token_size += 1 77 | 78 | if unk_tok: 79 | self.add_token(Vocab.UNK) 80 | self.special_token_size += 1 81 | 82 | if start_tok: 83 | self.add_token(Vocab.START) 84 | self.special_token_size += 1 85 | 86 | if end_tok: 87 | self.add_token(Vocab.END) 88 | self.special_token_size += 1 89 | 90 | if root_tok: 91 | self.add_token(Vocab.ROOT) 92 | self.special_token_size += 1 93 | 94 | if null_tok: 95 | self.add_token(Vocab.NULL) 96 | self.special_token_size += 1 97 | 98 | if self_loop_tok: 99 | self.add_token(Vocab.SELF_LOOP) 100 | self.special_token_size += 1 101 | 102 | def add_token(self, token, count=1, update_count=False): 103 | idx = self.tok2idx.get(token, None) 104 | if idx is None: 105 | idx = len(self.tok2idx) 106 | self.tok2idx[token] = idx 107 | self.idx2count[idx] = count 108 | if count <= self.singleton_max_count: 109 | self.singleton_size += 1 110 | elif update_count: 111 | new_count = self.idx2count[idx] + count 112 | self.idx2count[idx] = new_count 113 | if new_count > self.singleton_max_count: 114 | self.singleton_size -= 1 115 | 116 | return idx 117 | 118 | def get_vocab_size(self): 119 | return len(self.tok2idx) 120 | 121 | def get_vocab_size_without_spec(self): 122 | return len(self.tok2idx) - self.special_token_size 123 | 124 | def get_index(self, token, default_value='*UNK*'): 125 | idx = self.tok2idx.get(token, None) 126 | if idx is None: 127 | if default_value: 128 | try: 129 | return self.tok2idx[default_value] 130 | except KeyError: 131 | print('Token %s not found' % token) 132 | exit(1) 133 | else: 134 | raise RuntimeError('Token %s not 
found' % token) 135 | else: 136 | return idx 137 | 138 | def __iter__(self): 139 | for tok, idx in self.tok2idx.items(): 140 | yield idx, tok 141 | 142 | def get_token(self, index): 143 | if len(self.idx2tok) == 0: 144 | for tok, idx in self.tok2idx.items(): 145 | self.idx2tok[idx] = tok 146 | 147 | return self.idx2tok[index] 148 | 149 | def get_token_set(self): 150 | return self.tok2idx.keys() 151 | 152 | def recount_singleton_size(self): 153 | singleton_size = 0 154 | for count in self.idx2count.values(): 155 | if count <= self.singleton_max_count: 156 | singleton_size += 1 157 | self.singleton_size = singleton_size  # update the size, not the threshold 158 | 159 | def get_singleton_size(self, re_count=False): 160 | if re_count: 161 | self.recount_singleton_size() 162 | return self.singleton_size 163 | 164 | def is_singleton(self, token_or_index): 165 | if isinstance(token_or_index, str): 166 | idx = self.tok2idx.get(token_or_index, None) 167 | if idx is None: 168 | # we treat OOV as singletons 169 | return True 170 | elif isinstance(token_or_index, int): 171 | idx = token_or_index 172 | else: 173 | raise TypeError('Unknown type %s' % (type(token_or_index))) 174 | 175 | count = self.idx2count[idx] 176 | return count <= self.singleton_max_count 177 | 178 | def __contains__(self, token): 179 | return token in self.tok2idx 180 | 181 | def __str__(self): 182 | spec_tok_size_str = 'special_token_size\t' + str(self.special_token_size) + '\n' 183 | tok_size_str = 'token_size\t' + str(self.get_vocab_size_without_spec()) + '\n' 184 | singleton_size_str = 'singleton_size\t' + str(self.singleton_size) + '\n' 185 | singleton_max_count_str = 'singleton_max_count\t' + str(self.singleton_max_count) + '\n' 186 | return spec_tok_size_str + tok_size_str + singleton_size_str + singleton_max_count_str 187 | 188 | def save(self, file_path, format='text'): 189 | ''' 190 | 191 | :param format: 'pickle' or 'text' 192 | :return: 193 | ''' 194 | 195 | if format == 'pickle': 196 | with open(file_path, 'wb') as file:  # binary mode takes no encoding argument 197 | pickle.dump(self, file) 198 | 199 | elif format == 'text': 200 | with open(file_path, 'w', encoding='utf-8') as file: 201 | file.write(str(self)) 202 | # write tokens in increasing index order 203 | tok2idx_list = sorted(list(self.tok2idx.items()), key=lambda x: x[1]) 204 | for tok, idx in tok2idx_list: 205 | count = self.idx2count[idx] 206 | file.write(str(idx) + '\t' + tok + '\t' + str(count) + '\n') 207 | else: 208 | raise RuntimeError('Unknown save format') 209 | 210 | @staticmethod 211 | def load(file_path, format='text'): 212 | 213 | if format == 'pickle': 214 | with open(file_path, 'rb') as file: 215 | return pickle.load(file) 216 | 217 | elif format == 'text': 218 | with open(file_path, 'r', encoding='utf-8') as file: 219 | vocab = Vocab() 220 | lines = list(file.readlines()) 221 | spec_tok_size = int(lines[0].split('\t')[1]) 222 | tok_size = int(lines[1].split('\t')[1]) 223 | singleton_size = int(lines[2].split('\t')[1]) 224 | singleton_max_count = int(lines[3].split('\t')[1]) 225 | 226 | vocab.special_token_size = spec_tok_size 227 | vocab.singleton_size = singleton_size 228 | vocab.singleton_max_count = singleton_max_count 229 | 230 | offset = 4 # skip head information 231 | for i in range(offset, offset + tok_size + spec_tok_size): 232 | line_arr = lines[i].strip().split('\t') 233 | idx, tok, count = int(line_arr[0]), line_arr[1], int(line_arr[2]) 234 | vocab.tok2idx[tok] = idx 235 | vocab.idx2count[idx] = count 236 | return vocab 237 | 238 | else: 239 | raise RuntimeError('Unknown load format') 240 | 
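# Illustrative sketch (not part of the original pipeline): a quick round trip of the text
# format written by save() / read by load(), which mirrors the "header + idx<TAB>token<TAB>count"
# layout of the *_vocab.txt files in data/datasets/towe/16res.
if __name__ == '__main__':
    demo = Vocab()
    demo.add_spec_toks(pad_tok=True, unk_tok=True)
    demo.add_counter(Counter(['the', 'the', 'food', 'was', 'great']))
    demo.save('demo_vocab.txt', format='text')
    loaded = Vocab.load('demo_vocab.txt', format='text')
    assert loaded.get_index('food') == demo.get_index('food')
    assert loaded.get_index('unseen') == loaded.get_index(Vocab.UNK)  # OOV falls back to *UNK*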
-------------------------------------------------------------------------------- /SynFue/input_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from abc import abstractmethod, ABC 3 | from collections import OrderedDict 4 | from logging import Logger 5 | from typing import Iterable, List 6 | 7 | from tqdm import tqdm 8 | from transformers import BertTokenizer 9 | 10 | from SynFue import util 11 | from SynFue.terms import Dataset, TermType, RelationType, Term, Relation, Document 12 | 13 | 14 | class BaseInputReader(ABC): 15 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None, 16 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None): 17 | types = json.load(open(types_path), object_pairs_hook=OrderedDict) # term + relation types 18 | 19 | self._term_types = OrderedDict() 20 | self._idx2term_type = OrderedDict() 21 | self._relation_types = OrderedDict() 22 | self._idx2relation_type = OrderedDict() 23 | 24 | # terms 25 | # add 'None' term type 26 | none_term_type = TermType('None', 0, 'None', 'No Term') 27 | self._term_types['None'] = none_term_type 28 | self._idx2term_type[0] = none_term_type 29 | 30 | # specified term types 31 | for i, (key, v) in enumerate(types['terms'].items()): 32 | term_type = TermType(key, i + 1, v['short'], v['verbose']) 33 | self._term_types[key] = term_type 34 | self._idx2term_type[i + 1] = term_type 35 | 36 | # relations 37 | # add 'None' relation type 38 | none_relation_type = RelationType('None', 0, 'None', 'No Relation') 39 | self._relation_types['None'] = none_relation_type 40 | self._idx2relation_type[0] = none_relation_type 41 | 42 | # specified relation types 43 | for i, (key, v) in enumerate(types['relations'].items()): 44 | relation_type = RelationType(key, i + 1, v['short'], v['verbose'], v['symmetric']) 45 | self._relation_types[key] = relation_type 46 | self._idx2relation_type[i + 1] = relation_type 47 | 48 | self._neg_term_count = neg_term_count 49 | self._neg_rel_count = neg_rel_count 50 | self._max_span_size = max_span_size 51 | 52 | self._datasets = dict() 53 | 54 | self._tokenizer = tokenizer 55 | self._logger = logger 56 | 57 | self._vocabulary_size = tokenizer.vocab_size 58 | self._context_size = -1 59 | 60 | @abstractmethod 61 | def read(self, datasets): 62 | pass 63 | 64 | def get_dataset(self, label) -> Dataset: 65 | return self._datasets[label] 66 | 67 | def get_term_type(self, idx) -> TermType: 68 | term = self._idx2term_type[idx] 69 | return term 70 | 71 | def get_relation_type(self, idx) -> RelationType: 72 | relation = self._idx2relation_type[idx] 73 | return relation 74 | 75 | def _calc_context_size(self, datasets: Iterable[Dataset]): 76 | sizes = [] 77 | 78 | for dataset in datasets: 79 | for doc in dataset.documents: 80 | sizes.append(len(doc.encoding)) 81 | 82 | context_size = max(sizes) 83 | return context_size 84 | 85 | def _log(self, text): 86 | if self._logger is not None: 87 | self._logger.info(text) 88 | 89 | @property 90 | def datasets(self): 91 | return self._datasets 92 | 93 | @property 94 | def term_types(self): 95 | return self._term_types 96 | 97 | @property 98 | def relation_types(self): 99 | return self._relation_types 100 | 101 | @property 102 | def relation_type_count(self): 103 | return len(self._relation_types) 104 | 105 | @property 106 | def term_type_count(self): 107 | return len(self._term_types) 108 | 109 | @property 110 | def vocabulary_size(self): 111 | return self._vocabulary_size 112 | 
113 | @property 114 | def context_size(self): 115 | return self._context_size 116 | 117 | def __str__(self): 118 | string = "" 119 | for dataset in self._datasets.values(): 120 | string += "Dataset: %s\n" % dataset 121 | string += str(dataset) 122 | 123 | return string 124 | 125 | def __repr__(self): 126 | return self.__str__() 127 | 128 | 129 | class JsonInputReader(BaseInputReader): 130 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None, 131 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None): 132 | super().__init__(types_path, tokenizer, neg_term_count, neg_rel_count, max_span_size, logger) 133 | 134 | def read(self, dataset_paths): 135 | for dataset_label, dataset_path in dataset_paths.items(): 136 | dataset = Dataset(dataset_label, self._relation_types, self._term_types, self._neg_term_count, 137 | self._neg_rel_count, self._max_span_size) 138 | self._parse_dataset(dataset_path, dataset) 139 | self._datasets[dataset_label] = dataset 140 | 141 | self._context_size = self._calc_context_size(self._datasets.values()) 142 | 143 | def _parse_dataset(self, dataset_path, dataset): 144 | documents = json.load(open(dataset_path)) 145 | for document in tqdm(documents, desc="Parse dataset '%s'" % dataset.label): 146 | self._parse_document(document, dataset) 147 | 148 | def _parse_document(self, doc, dataset) -> Document: 149 | jtokens = doc['tokens'] 150 | jrelations = doc['relations'] 151 | jterms = doc['entities'] 152 | jdep_label = doc['dep_label'] 153 | jdep_label_indices = doc['dep_label_indices'] 154 | jdep = doc['dep'] 155 | jpos = doc['pos'] 156 | jpos_indices = doc['pos_indices'] 157 | 158 | # parse tokens 159 | doc_tokens, doc_encoding = self._parse_tokens(jtokens, dataset) 160 | 161 | # parse term mentions 162 | terms = self._parse_terms(jterms, doc_tokens, dataset) 163 | 164 | # parse relations 165 | relations = self._parse_relations(jrelations, terms, dataset) 166 | 167 | # create document 168 | document = dataset.create_document(doc_tokens, terms, relations, doc_encoding, jdep_label, 169 | jdep_label_indices, jdep, jpos, jpos_indices) 170 | 171 | return document 172 | 173 | def _parse_tokens(self, jtokens, dataset): 174 | doc_tokens = [] 175 | 176 | # full document encoding including special tokens ([CLS] and [SEP]) and byte-pair encodings of original tokens 177 | doc_encoding = [self._tokenizer.convert_tokens_to_ids('[CLS]')] 178 | 179 | # parse tokens 180 | for i, token_phrase in enumerate(jtokens): 181 | token_encoding = self._tokenizer.encode(token_phrase, add_special_tokens=False) 182 | span_start, span_end = i, i + 1 183 | sub_token_start, sub_token_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding)) 184 | 185 | token = dataset.create_token(i, span_start, span_end, token_phrase, sub_token_start, sub_token_end) 186 | 187 | doc_tokens.append(token) 188 | doc_encoding += token_encoding 189 | 190 | doc_encoding += [self._tokenizer.convert_tokens_to_ids('[SEP]')] 191 | 192 | return doc_tokens, doc_encoding 193 | 194 | def _parse_terms(self, jterms, doc_tokens, dataset) -> List[Term]: 195 | terms = [] 196 | 197 | for term_idx, jterm in enumerate(jterms): 198 | term_type = self._term_types[jterm['type']] 199 | start, end = jterm['start'], jterm['end'] 200 | 201 | # create term mention 202 | tokens = doc_tokens[start:end+1] 203 | phrase = " ".join([t.phrase for t in tokens]) 204 | term = dataset.create_term(term_type, tokens, phrase) 205 | terms.append(term) 206 | 207 | return terms 208 | 209 | def 
_parse_relations(self, jrelations, terms, dataset) -> List[Relation]: 210 | relations = [] 211 | 212 | for jrelation in jrelations: 213 | relation_type = self._relation_types[jrelation['type']] 214 | 215 | head_idx = jrelation['head'] 216 | tail_idx = jrelation['tail'] 217 | 218 | # create relation 219 | head = terms[head_idx] 220 | tail = terms[tail_idx] 221 | 222 | reverse = int(tail.tokens[0].index) < int(head.tokens[0].index) 223 | 224 | # for symmetric relations: head occurs before tail in sentence 225 | if relation_type.symmetric and reverse: 226 | head, tail = util.swap(head, tail) 227 | 228 | relation = dataset.create_relation(relation_type, head_term=head, tail_term=tail, reverse=reverse) 229 | relations.append(relation) 230 | 231 | return relations 232 | -------------------------------------------------------------------------------- /SynFue/templates/relation_examples.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Relation Extraction Examples 6 | 7 | 9 | 10 | 11 | 12 | 112 | 113 | 114 | 115 |

Relation Extraction Examples ({{ examples|length }})

116 | 117 |
118 |
119 | check_circle_outline   F1 = 100.000
120 | check_circle_outline   F1 >= 50.00
121 | highlight_off   F1 < 50.00
122 |
123 | 124 |
125 |
126 |   True Positive  
127 |   False Positive  
128 |   False Negative
129 |
130 |
131 |

132 | 133 |
134 |
135 | {% for example in examples %} 136 | {% set outer_loop = loop %} 137 | 138 |
139 | 174 | 175 |
176 |
177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | {% for tp in example["tp"] %} 190 | 191 | 194 | 195 | 196 | 197 | 198 | {% endfor %} 199 | 200 | {% for fp in example["fp"] %} 201 | 202 | 205 | 206 | 207 | 208 | {% endfor %} 209 | 210 | {% for fn in example["fn"] %} 211 | 212 | 213 | 214 | 215 | 216 | {% endfor %} 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 |
ScoreRelationText (Head - Tail)
192 | {{ "%.4f"|format(tp[2]) }} 193 | {{ tp[1] }}{{ tp[0] | safe }}
203 | {{ "%.4f"|format(fp[2]) }} 204 | {{ fp[1] }}{{ fp[0] | safe }}
{{ fn[1] }}{{ fn[0] | safe }}
226 |
227 |
228 |
229 | {% endfor %} 230 |
231 |
232 | 233 | 236 | 239 | 242 | 243 | -------------------------------------------------------------------------------- /SynFue/terms.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import List 3 | from torch.utils.data import Dataset as TorchDataset 4 | 5 | from SynFue import sampling 6 | 7 | 8 | class RelationType: 9 | def __init__(self, identifier, index, short_name, verbose_name, symmetric=False): 10 | self._identifier = identifier 11 | self._index = index 12 | self._short_name = short_name 13 | self._verbose_name = verbose_name 14 | self._symmetric = symmetric 15 | 16 | @property 17 | def identifier(self): 18 | return self._identifier 19 | 20 | @property 21 | def index(self): 22 | return self._index 23 | 24 | @property 25 | def short_name(self): 26 | return self._short_name 27 | 28 | @property 29 | def verbose_name(self): 30 | return self._verbose_name 31 | 32 | @property 33 | def symmetric(self): 34 | return self._symmetric 35 | 36 | def __int__(self): 37 | return self._index 38 | 39 | def __eq__(self, other): 40 | if isinstance(other, RelationType): 41 | return self._identifier == other._identifier 42 | return False 43 | 44 | def __hash__(self): 45 | return hash(self._identifier) 46 | 47 | 48 | class TermType: 49 | def __init__(self, identifier, index, short_name, verbose_name): 50 | self._identifier = identifier 51 | self._index = index 52 | self._short_name = short_name 53 | self._verbose_name = verbose_name 54 | 55 | @property 56 | def identifier(self): 57 | return self._identifier 58 | 59 | @property 60 | def index(self): 61 | return self._index 62 | 63 | @property 64 | def short_name(self): 65 | return self._short_name 66 | 67 | @property 68 | def verbose_name(self): 69 | return self._verbose_name 70 | 71 | def __int__(self): 72 | return self._index 73 | 74 | def __eq__(self, other): 75 | if isinstance(other, TermType): 76 | return self._identifier == other._identifier 77 | return False 78 | 79 | def __hash__(self): 80 | return hash(self._identifier) 81 | 82 | 83 | class Token: 84 | def __init__(self, tid: int, index: int, span_start: int, span_end: int, phrase: str, 85 | sub_token_start: int, sub_token_end: int): 86 | self._tid = tid # ID within the corresponding dataset 87 | self._index = index # original token index in document 88 | 89 | self._span_start = span_start # start of token in document (inclusive) 90 | self._span_end = span_end # end of token in document (exclusive) 91 | self._phrase = phrase 92 | self._sub_token_start = sub_token_start # start of sub word in a document after a WordPiece Tokenizer 93 | self._sub_token_end = sub_token_end # end of sub word in a document after a WordPiece Tokenizer 94 | 95 | @property 96 | def index(self): 97 | return self._index 98 | 99 | @property 100 | def span_start(self): 101 | return self._span_start 102 | 103 | @property 104 | def span_end(self): 105 | return self._span_end 106 | 107 | @property 108 | def span(self): 109 | return self._span_start, self._span_end 110 | 111 | @property 112 | def sub_token_start(self): 113 | return self._sub_token_start 114 | 115 | @property 116 | def sub_token_end(self): 117 | return self._sub_token_end 118 | 119 | @property 120 | def sub_token(self): 121 | return self._sub_token_start, self._sub_token_end 122 | 123 | @property 124 | def phrase(self): 125 | return self._phrase 126 | 127 | def __eq__(self, other): 128 | if isinstance(other, Token): 129 | return self._tid == other._tid 130 | return False 131 
| 132 | def __hash__(self): 133 | return hash(self._tid) 134 | 135 | def __str__(self): 136 | return self._phrase 137 | 138 | def __repr__(self): 139 | return self._phrase 140 | 141 | 142 | class TokenSpan: 143 | def __init__(self, tokens): 144 | self._tokens = tokens 145 | 146 | @property 147 | def span_start(self): 148 | return self._tokens[0].span_start 149 | 150 | @property 151 | def span_end(self): 152 | return self._tokens[-1].span_end 153 | 154 | @property 155 | def span(self): 156 | return self.span_start, self.span_end 157 | 158 | @property 159 | def sub_token_start(self): 160 | return self._tokens[0].sub_token_start 161 | 162 | @property 163 | def sub_token_end(self): 164 | return self._tokens[-1].sub_token_end  # last token of the span, mirroring Term.sub_token_end 165 | 166 | @property 167 | def sub_token(self): 168 | return self.sub_token_start, self.sub_token_end 169 | 170 | def __getitem__(self, s): 171 | if isinstance(s, slice): 172 | return TokenSpan(self._tokens[s.start:s.stop:s.step]) 173 | else: 174 | try: 175 | return self._tokens[s] 176 | except Exception: 177 | print(self._tokens) 178 | print(len(self._tokens)) 179 | print(s) 180 | 181 | def __iter__(self): 182 | return iter(self._tokens) 183 | 184 | def __len__(self): 185 | return len(self._tokens) 186 | 187 | 188 | class Term: 189 | def __init__(self, eid: int, term_type: TermType, tokens: List[Token], phrase: str): 190 | self._eid = eid # ID within the corresponding dataset 191 | 192 | self._term_type = term_type 193 | 194 | self._tokens = tokens 195 | self._phrase = phrase 196 | 197 | def as_tuple(self): 198 | return self.span_start, self.span_end, self._term_type 199 | 200 | @property 201 | def term_type(self): 202 | return self._term_type 203 | 204 | @property 205 | def tokens(self): 206 | return TokenSpan(self._tokens) 207 | 208 | @property 209 | def span_start(self): 210 | return self._tokens[0].span_start 211 | 212 | @property 213 | def span_end(self): 214 | return self._tokens[-1].span_end 215 | 216 | @property 217 | def span(self): 218 | return self.span_start, self.span_end 219 | 220 | @property 221 | def sub_token_start(self): 222 | return self._tokens[0].sub_token_start 223 | 224 | @property 225 | def sub_token_end(self): 226 | return self._tokens[-1].sub_token_end 227 | 228 | @property 229 | def sub_token(self): 230 | return self.sub_token_start, self.sub_token_end 231 | 232 | 233 | @property 234 | def phrase(self): 235 | return self._phrase 236 | 237 | def __eq__(self, other): 238 | if isinstance(other, Term): 239 | return self._eid == other._eid 240 | return False 241 | 242 | def __hash__(self): 243 | return hash(self._eid) 244 | 245 | def __str__(self): 246 | return self._phrase 247 | 248 | 249 | class Relation: 250 | def __init__(self, rid: int, relation_type: RelationType, head_term: Term, 251 | tail_term: Term, reverse: bool = False): 252 | self._rid = rid # ID within the corresponding dataset 253 | self._relation_type = relation_type 254 | 255 | self._head_term = head_term 256 | self._tail_term = tail_term 257 | 258 | self._reverse = reverse 259 | 260 | self._first_term = head_term if not reverse else tail_term 261 | self._second_term = tail_term if not reverse else head_term 262 | 263 | def as_tuple(self): 264 | head = self._head_term 265 | tail = self._tail_term 266 | head_start, head_end = (head.span_start, head.span_end) 267 | tail_start, tail_end = (tail.span_start, tail.span_end) 268 | 269 | t = ((head_start, head_end, head.term_type), 270 | (tail_start, tail_end, tail.term_type), self._relation_type) 271 | return t 272 | 273 | @property 274 | def 
relation_type(self): 275 | return self._relation_type 276 | 277 | @property 278 | def head_term(self): 279 | return self._head_term 280 | 281 | @property 282 | def tail_term(self): 283 | return self._tail_term 284 | 285 | @property 286 | def first_term(self): 287 | return self._first_term 288 | 289 | @property 290 | def second_term(self): 291 | return self._second_term 292 | 293 | @property 294 | def reverse(self): 295 | return self._reverse 296 | 297 | def __eq__(self, other): 298 | if isinstance(other, Relation): 299 | return self._rid == other._rid 300 | return False 301 | 302 | def __hash__(self): 303 | return hash(self._rid) 304 | 305 | 306 | class Document: 307 | def __init__(self, doc_id: int, tokens: List[Token], terms: List[Term], relations: List[Relation], 308 | encoding: List[int], dep_label: List[int], dep_label_indices: List[int], dep: List[int], 309 | pos: List[str], pos_indices: List[int]): 310 | self._doc_id = doc_id # ID within the corresponding dataset 311 | 312 | self._tokens = tokens 313 | self._terms = terms 314 | self._relations = relations 315 | 316 | # byte-pair document encoding including special tokens ([CLS] and [SEP]) 317 | self._encoding = encoding 318 | 319 | self._dep_label = dep_label 320 | self._dep_label_indices = dep_label_indices 321 | self._dep = dep 322 | 323 | self._pos = pos 324 | self._pos_indices = pos_indices 325 | 326 | @property 327 | def doc_id(self): 328 | return self._doc_id 329 | 330 | @property 331 | def terms(self): 332 | return self._terms 333 | 334 | @property 335 | def relations(self): 336 | return self._relations 337 | 338 | @property 339 | def tokens(self): 340 | return TokenSpan(self._tokens) 341 | 342 | @property 343 | def encoding(self): 344 | return self._encoding 345 | 346 | @property 347 | def dep_label(self): 348 | return self._dep_label 349 | 350 | @property 351 | def dep_label_indices(self): 352 | return self._dep_label_indices 353 | 354 | @property 355 | def dep(self): 356 | return self._dep 357 | 358 | @property 359 | def pos_indices(self): 360 | return self._pos_indices 361 | 362 | @property 363 | def pos(self): 364 | return self._pos 365 | 366 | @encoding.setter 367 | def encoding(self, value): 368 | self._encoding = value 369 | 370 | def __eq__(self, other): 371 | if isinstance(other, Document): 372 | return self._doc_id == other._doc_id 373 | return False 374 | 375 | def __hash__(self): 376 | return hash(self._doc_id) 377 | 378 | 379 | class BatchIterator: 380 | def __init__(self, terms, batch_size, order=None, truncate=False): 381 | self._terms = terms 382 | self._batch_size = batch_size 383 | self._truncate = truncate 384 | self._length = len(self._terms) 385 | self._order = order 386 | 387 | if order is None: 388 | self._order = list(range(len(self._terms))) 389 | 390 | self._i = 0 391 | 392 | def __iter__(self): 393 | return self 394 | 395 | def __next__(self): 396 | if self._truncate and self._i + self._batch_size > self._length: 397 | raise StopIteration 398 | elif not self._truncate and self._i >= self._length: 399 | raise StopIteration 400 | else: 401 | terms = [self._terms[n] for n in self._order[self._i:self._i + self._batch_size]] 402 | self._i += self._batch_size 403 | return terms 404 | 405 | 406 | class Dataset(TorchDataset): 407 | TRAIN_MODE = 'train' 408 | EVAL_MODE = 'eval' 409 | 410 | def __init__(self, label, rel_types, term_types, neg_term_count, 411 | neg_rel_count, max_span_size): 412 | self._label = label 413 | self._rel_types = rel_types 414 | self._term_types = term_types 415 | 
self._neg_term_count = neg_term_count 416 | self._neg_rel_count = neg_rel_count 417 | self._max_span_size = max_span_size 418 | self._mode = Dataset.TRAIN_MODE 419 | 420 | self._documents = OrderedDict() 421 | self._terms = OrderedDict() 422 | self._relations = OrderedDict() 423 | 424 | # current ids 425 | self._doc_id = 0 426 | self._rid = 0 427 | self._eid = 0 428 | self._tid = 0 429 | 430 | def iterate_documents(self, batch_size, order=None, truncate=False): 431 | return BatchIterator(self.documents, batch_size, order=order, truncate=truncate) 432 | 433 | def iterate_relations(self, batch_size, order=None, truncate=False): 434 | return BatchIterator(self.relations, batch_size, order=order, truncate=truncate) 435 | 436 | def create_token(self, idx, span_start, span_end, phrase, sub_token_start, sub_token_end) -> Token: 437 | token = Token(self._tid, idx, span_start, span_end, phrase, sub_token_start, sub_token_end) 438 | self._tid += 1 439 | return token 440 | 441 | def create_document(self, tokens, term_mentions, relations, doc_encoding, dep_label, dep_label_indices, dep, 442 | pos, pos_indices) -> Document: 443 | document = Document(self._doc_id, tokens, term_mentions, relations, doc_encoding, dep_label, 444 | dep_label_indices, dep, pos, pos_indices) 445 | self._documents[self._doc_id] = document 446 | self._doc_id += 1 447 | 448 | return document 449 | 450 | def create_term(self, term_type, tokens, phrase) -> Term: 451 | mention = Term(self._eid, term_type, tokens, phrase) 452 | self._terms[self._eid] = mention 453 | self._eid += 1 454 | return mention 455 | 456 | def create_relation(self, relation_type, head_term, tail_term, reverse=False) -> Relation: 457 | relation = Relation(self._rid, relation_type, head_term, tail_term, reverse) 458 | self._relations[self._rid] = relation 459 | self._rid += 1 460 | return relation 461 | 462 | def __len__(self): 463 | return len(self._documents) 464 | 465 | def __getitem__(self, index: int): 466 | doc = self._documents[index] 467 | 468 | if self._mode == Dataset.TRAIN_MODE: 469 | return sampling.create_train_sample(doc, self._neg_term_count, self._neg_rel_count, 470 | self._max_span_size, len(self._rel_types)) 471 | else: 472 | return sampling.create_eval_sample(doc, self._max_span_size) 473 | 474 | def switch_mode(self, mode): 475 | self._mode = mode 476 | 477 | @property 478 | def label(self): 479 | return self._label 480 | 481 | @property 482 | def input_reader(self): 483 | return getattr(self, '_input_reader', None)  # never assigned by Dataset itself; avoid AttributeError 484 | 485 | @property 486 | def documents(self): 487 | return list(self._documents.values()) 488 | 489 | @property 490 | def terms(self): 491 | return list(self._terms.values()) 492 | 493 | @property 494 | def relations(self): 495 | return list(self._relations.values()) 496 | 497 | @property 498 | def document_count(self): 499 | return len(self._documents) 500 | 501 | @property 502 | def term_count(self): 503 | return len(self._terms) 504 | 505 | @property 506 | def relation_count(self): 507 | return len(self._relations) 508 | -------------------------------------------------------------------------------- /data/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | Read the original TSV data and convert it to JSON datasets; 5 | along the way, tokens are preprocessed (e.g. lower-casing and digit normalization) 6 | ''' 7 | import os 8 | import json 9 | from collections import Counter 10 | from nltk.parse import CoreNLPDependencyParser 11 | import numpy as np 12 | 
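# NOTE: CoreNLPDependencyParser below talks to a Stanford CoreNLP server that must already be
# running at http://127.0.0.1:9000, e.g. started with
#   java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000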
import argparse 13 | from tqdm import tqdm 14 | from io_utils import read_yaml 15 | from str_utils import normalize_tok 16 | from vocab import Vocab 17 | from sklearn.model_selection import train_test_split 18 | 19 | config = read_yaml('data_config.yaml') 20 | print('seed:', config['random_seed']) 21 | normalize_digits = config['normalize_digits'] 22 | lower_case = config['lower_case'] 23 | 24 | depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000') 25 | 26 | 27 | def build_vocab(data_list, file_path): 28 | token_list = [] 29 | char_list = [] 30 | pos_list = [] 31 | dep_list = [] 32 | for inst in tqdm(data_list, total=len(data_list)): 33 | words = inst['words'] 34 | 35 | temp_parser_res = depparser.parse(words) 36 | parser_res = [] 37 | for i in temp_parser_res: 38 | temp = i.to_conll(4).strip().split('\n') 39 | for t in temp: 40 | parser_res.append(t.split('\t')) 41 | pos_list.extend([a[1] for a in parser_res]) 42 | dep_list.extend([a[3] for a in parser_res]) 43 | 44 | for word in words: 45 | word = normalize_tok(word, lower_case, normalize_digits) 46 | token_list.append(word) 47 | char_list.extend(list(word)) 48 | 49 | token_vocab_file = os.path.join(file_path, config['token_vocab_file']) 50 | char_vocab_file = os.path.join(file_path, config['char_vocab_file']) 51 | pos_vocab_file = os.path.join(file_path, config['pos_vocab_file']) 52 | dep_type_vocab_file = os.path.join(file_path, config['dep_type_vocab_file']) 53 | 54 | print('--------token_vocab---------------') 55 | token_vocab = Vocab() 56 | token_vocab.add_spec_toks(unk_tok=True, pad_tok=False) 57 | token_vocab.add_counter(Counter(token_list)) 58 | token_vocab.save(token_vocab_file) 59 | print(token_vocab) 60 | 61 | print('--------char_vocab---------------') 62 | char_vocab = Vocab() 63 | char_vocab.add_spec_toks(unk_tok=True, pad_tok=False) 64 | char_vocab.add_counter(Counter(char_list)) 65 | char_vocab.save(char_vocab_file) 66 | print(char_vocab) 67 | 68 | print('--------pos_vocab---------------') 69 | pos_vocab = Vocab() 70 | pos_vocab.add_spec_toks(pad_tok=True, unk_tok=False) 71 | pos_vocab.add_counter(Counter(pos_list)) 72 | pos_vocab.save(pos_vocab_file) 73 | print(pos_vocab) 74 | 75 | print('--------dep_vocab---------------') 76 | dep_vocab = Vocab() 77 | dep_vocab.add_spec_toks(pad_tok=True, unk_tok=False, self_loop_tok=True) 78 | dep_vocab.add_counter(Counter(dep_list)) 79 | dep_vocab.save(dep_type_vocab_file) 80 | print(dep_vocab) 81 | 82 | return dep_vocab, pos_vocab 83 | 84 | 85 | def load_data_from_ori(file_path): 86 | pre_id = '' 87 | data_list = [] 88 | data_dic = {} 89 | token = [] 90 | opinion = [] 91 | opinion_idx = [] 92 | aspect = [] 93 | aspect_idx = [] 94 | pair = [] 95 | pair_idx = [] 96 | with open(file_path, 'r', encoding='utf-8') as f: 97 | f.readline() 98 | while True: 99 | line = f.readline() 100 | if line == '': 101 | token = token_temp 102 | aspect.append(a_temp) 103 | aspect_idx.append(a_temp_idx) 104 | opinion.append(o_temp) 105 | opinion_idx.append(o_temp_idx) 106 | pair.extend(pair_temp) 107 | pair_idx.extend(pair_idx_temp) 108 | data_dic['words'] = token 109 | data_dic['aspects'] = aspect 110 | data_dic['aspects_idx'] = aspect_idx 111 | data_dic['opinions'] = opinion 112 | data_dic['opinions_idx'] = opinion_idx 113 | data_dic['pair'] = pair 114 | data_dic['pair_idx'] = pair_idx 115 | data_list.append(data_dic) 116 | break 117 | line = line.split('\t') 118 | if pre_id == line[0]: 119 | token = token_temp 120 | aspect.append(a_temp) 121 | aspect_idx.append(a_temp_idx) 122 | 
opinion.append(o_temp) 123 | opinion_idx.append(o_temp_idx) 124 | pair.extend(pair_temp) 125 | pair_idx.extend(pair_idx_temp) 126 | elif pre_id != '' and pre_id != line[0]: 127 | token = token_temp 128 | aspect.append(a_temp) 129 | aspect_idx.append(a_temp_idx) 130 | opinion.append(o_temp) 131 | opinion_idx.append(o_temp_idx) 132 | pair.extend(pair_temp) 133 | pair_idx.extend(pair_idx_temp) 134 | data_dic['words'] = token 135 | data_dic['aspects'] = aspect 136 | data_dic['aspects_idx'] = aspect_idx 137 | data_dic['opinions'] = opinion 138 | data_dic['opinions_idx'] = opinion_idx 139 | data_dic['pair'] = pair 140 | data_dic['pair_idx'] = pair_idx 141 | 142 | data_list.append(data_dic) 143 | data_dic = {} 144 | token = [] 145 | opinion = [] 146 | opinion_idx = [] 147 | aspect = [] 148 | aspect_idx = [] 149 | pair = [] 150 | pair_idx = [] 151 | try: 152 | token_temp = line[1].strip().split() 153 | except: 154 | print(line) 155 | # aspect term 156 | a_temp = [] 157 | a_temp_idx = [] 158 | for idx, a in enumerate(line[2].strip().split()): 159 | if a.strip().split('\\')[1] != 'O': 160 | a_temp.append(a.strip().split('\\')[0]) 161 | a_temp_idx.append(idx) 162 | # opinion term 163 | o_temp = [] 164 | o_temp_idx = [] 165 | for idx, o in enumerate(line[3].strip().split()): 166 | if o.strip().split('\\')[1] != 'O': 167 | o_temp.append(o.strip().split('\\')[0]) 168 | o_temp_idx.append(idx) 169 | pair_temp = [] 170 | pair_idx_temp = [] 171 | pair_temp.append((a_temp, o_temp)) 172 | pair_idx_temp.append((a_temp_idx, o_temp_idx)) 173 | pre_id = line[0] 174 | return data_list 175 | 176 | 177 | def sep_discontinuous_term(term_idx): 178 | """ 179 | according to the source term idx, get the new idx, term rep = [start_idx, end_idx] and it is consequent 180 | :param term_idx: 181 | :return: 182 | """ 183 | bert_idx = [] 184 | flag = 0 # Whether to include discontinuous terms 185 | for idx in term_idx: 186 | if len(idx) < 2: 187 | if [idx[0], idx[-1]] not in bert_idx: 188 | bert_idx.append([idx[0], idx[-1]]) 189 | else: 190 | if (idx[0] + len(idx) - 1) == idx[-1]: 191 | if [idx[0], idx[-1]] not in bert_idx: 192 | bert_idx.append([idx[0], idx[-1]]) 193 | else: 194 | temp_flag = 0 195 | s_idx = e_idx = idx[0] # start_idx = end_idx 196 | for i in idx[1:]: 197 | if i == e_idx + 1: 198 | e_idx = i 199 | else: 200 | temp_flag = 1 201 | if [s_idx, e_idx] not in bert_idx: 202 | bert_idx.append([s_idx, e_idx]) 203 | s_idx = e_idx = i 204 | if [s_idx, e_idx] not in bert_idx: 205 | bert_idx.append([s_idx, e_idx]) 206 | 207 | if temp_flag == 1: 208 | flag += 1 209 | temp_flag = 0 210 | 211 | return bert_idx, flag 212 | 213 | 214 | def sep_pair_idx(pair_idx): 215 | """ 216 | according to the source pair idx, get the new idx, term rep = [[start_idx, end_idx], [start_idx, end_idx]] 217 | and it is consequent 218 | :param pair_idx: 219 | :return: 220 | """ 221 | bert_idx = [] 222 | flag = a_flag = o_flag = 0 223 | for p in pair_idx: 224 | bert_a_idx, a_flag = sep_discontinuous_term([p[0]]) 225 | bert_o_idx, o_flag = sep_discontinuous_term([p[1]]) 226 | for a in bert_a_idx: 227 | for b in bert_o_idx: 228 | bert_idx.append((a, b)) 229 | 230 | if a_flag != 0 or o_flag != 0: 231 | flag = 1 232 | return bert_idx, flag 233 | 234 | 235 | def construct_instance(inst_list, dep_vocab, pos_vocab): 236 | data = [] 237 | idx = 0 238 | for inst in tqdm(inst_list, total=len(inst_list)): 239 | inst_dict = {} 240 | words = inst['words'] 241 | words_processed = [normalize_tok(w, lower_case, normalize_digits) for w in words] 242 | 
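# CoreNLP may re-tokenize the normalized tokens (e.g. brackets become -LRB-/-RRB-); the parse
# below re-derives the token list, and the s_to_t map further down is meant to reconcile the
# pre- and post-parse tokenizations before the term and pair indices are rewritten.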
temp_parser_res = depparser.parse(words_processed) 243 | parser_res = [] 244 | for i in temp_parser_res: 245 | temp = i.to_conll(4).strip().split('\n') 246 | for t in temp: 247 | parser_res.append(t.split('\t')) 248 | words = [a[0] for a in parser_res] 249 | inst['words'] = words 250 | inst_dict['tokens'] = words 251 | aspects_idx = inst['aspects_idx'] 252 | opinions_idx = inst['opinions_idx'] 253 | pair_idx = inst['pair_idx'] 254 | 255 | new_aspects_idx, _ = sep_discontinuous_term(aspects_idx) 256 | new_opinions_idx, _ = sep_discontinuous_term(opinions_idx) 257 | new_pair_idx, _ = sep_pair_idx(pair_idx) 258 | 259 | s_to_t = {} 260 | i = j = 0 261 | while i < len(inst['words']): 262 | if inst['words'][i] == words[j]: 263 | s_to_t[i] = [j] 264 | i += 1 265 | j += 1 266 | else: 267 | s_to_t[i] = [] 268 | if i + 1 > len(inst['words']) - 1: 269 | s_to_t[i] = [x for x in range(j, len(words))] 270 | else: 271 | next_token = inst['words'][i + 1] 272 | while (j <= len(words) - 1 and words[j] != '-RRB-' 273 | and words[j] != next_token and words[j] not in next_token):  # bound check first to avoid IndexError 274 | s_to_t[i].append(j) 275 | j += 1 276 | i += 1 277 | 278 | def get_new_term(old_term): 279 | new_term = [] 280 | for i in old_term: 281 | temp = [] 282 | for j in i: 283 | temp.extend(s_to_t[j]) 284 | new_term.append(temp) 285 | return new_term 286 | 287 | new_aspects = get_new_term(new_aspects_idx) 288 | new_opinions = get_new_term(new_opinions_idx) 289 | inst['aspects_idx'] = new_aspects 290 | inst['opinions_idx'] = new_opinions 291 | 292 | new_pairs = [] 293 | for p in new_pair_idx: 294 | new_p_a = [] 295 | for a in p[0]: 296 | new_p_a.extend(s_to_t[a]) 297 | 298 | new_p_o = [] 299 | for a in p[1]: 300 | new_p_o.extend(s_to_t[a]) 301 | new_pairs.append((new_p_a, new_p_o)) 302 | inst['pair_idx'] = new_pairs 303 | 304 | entities = [] 305 | for a in new_aspects_idx: 306 | entities.append({'type': "Asp", "start": a[0], "end": a[-1]}) 307 | for o in new_opinions_idx: 308 | entities.append({'type': "Opi", "start": o[0], "end": o[-1]}) 309 | inst_dict['entities'] = entities.copy() 310 | 311 | relations = [] 312 | for p in new_pair_idx: 313 | s1, s2 = p[0], p[-1] 314 | head = new_aspects_idx.index(s1) # index of the aspect within the entities list 315 | tail = new_opinions_idx.index(s2) 316 | relations.append({"type": "Pair", "head": head, "tail": tail+len(new_aspects_idx)}) 317 | inst_dict['relations'] = relations 318 | inst_dict['orig_id'] = idx 319 | inst_dict['dep'] = [a[2] for a in parser_res] 320 | inst_dict['dep_label'] = [a[3] for a in parser_res] 321 | inst_dict['dep_label_indices'] = [dep_vocab.get_index(a[3], default_value='dep') for a in parser_res] 322 | inst_dict['pos'] = [a[1] for a in parser_res] 323 | inst_dict['pos_indices'] = [pos_vocab.get_index(a[1]) for a in parser_res] 324 | assert len(inst_dict['tokens']) == len(inst_dict['pos']) == len(inst_dict['dep']) 325 | # try: 326 | # assert len(inst['pos']) == len(inst['pos_indices']) == len(inst['words']) 327 | # except: 328 | # print(inst['pos']) 329 | # print(len(inst['pos'])) 330 | # print(inst['words']) 331 | # print(len(inst['words'])) 332 | # print(len(inst['dep_label'])) 333 | # print(inst['dep_label']) 334 | # print(len(inst['dep'])) 335 | # idx = inst['words'].index('1/2') 336 | # inst_dict['tokens'] = inst['words'][:idx] + inst['words'][idx+1:] 337 | # assert len(inst['tag_type']) == len(inst['tag_type_indices']) == len(inst_dict['tokens']) 338 | # print(inst_dict['tokens']) 339 | data.append(inst_dict.copy()) 340 | idx += 1 341 | return data 342 | 343 | 344 | def 
save_data_to_json(train_list, dev_list, test_list, target_dir, dep_vocab, pos_vocab): 345 | train_path = os.path.join(target_dir, 'train.json') 346 | dev_path = os.path.join(target_dir, 'dev.json') 347 | test_path = os.path.join(target_dir, 'test.json') 348 | 349 | train_data = construct_instance(train_list, dep_vocab, pos_vocab) 350 | dev_data = construct_instance(dev_list, dep_vocab, pos_vocab) 351 | test_data = construct_instance(test_list, dep_vocab, pos_vocab) 352 | 353 | json.dump(train_data, open(train_path, 'w', encoding='utf-8')) 354 | json.dump(dev_data, open(dev_path, 'w', encoding='utf-8')) 355 | json.dump(test_data, open(test_path, 'w', encoding='utf-8')) 356 | print("preprocessing finished successfully") 357 | 358 | 359 | def count_men_len_and_rel_dis(inst_lists): 360 | """ 361 | Count the average mention length and the average distance between the head tokens of pairs 362 | Args: 363 | inst_lists: 364 | 365 | Returns: 366 | None: the two statistics (avg_mention_len, avg_rel_distance) 367 | are printed rather than returned 368 | """ 369 | 370 | mention_lens = [] 371 | rel_distances = [] 372 | for inst in tqdm(inst_lists, total=len(inst_lists)): 373 | for p in inst["pair_idx"]: 374 | mention_lens.append(p[0][-1]-p[0][0]+1) 375 | mention_lens.append(p[1][-1]-p[1][0]+1) 376 | rel_distances.append(abs(p[0][0]-p[1][0])+1) 377 | 378 | # orl = inst['orl'] 379 | # for x in orl: 380 | # if x[-1] == 'DSE': 381 | # mention_lens.append(x[3] - x[2] + 1) 382 | # elif x[-1] == 'AGENT' or x[-1] == 'TARGET': 383 | # mention_lens.append(x[3] - x[2] + 1) 384 | # rel_distances.append(abs(x[0]-x[2])+1) 385 | # else: 386 | # raise KeyError('annotation error, check {}'.format(' '.join(inst['sentences']))) 387 | avg_mention_len = sum(mention_lens) / len(mention_lens) 388 | avg_rel_distance = sum(rel_distances) / len(rel_distances) 389 | print("The average length of mentions: ", avg_mention_len) 390 | print("The average distance between the head tokens of pairs: ", avg_rel_distance) 391 | 392 | 393 | 394 | if __name__ == '__main__': 395 | 396 | train_list = [] 397 | dev_list = [] 398 | test_list = [] 399 | dataset = ['14lap', '14res', '15res', '16res'] 400 | for d_name in dataset: 401 | print(d_name) 402 | file_path = os.path.join(config['ori_data_dir'], d_name) 403 | train = load_data_from_ori(os.path.join(file_path, 'train.tsv')) 404 | train, dev = train_test_split(train, test_size=0.2, shuffle=True) 405 | test = load_data_from_ori(os.path.join(file_path, 'test.tsv')) 406 | # count_men_len_and_rel_dis(train+dev+test) 407 | dep_vocab, pos_vocab = build_vocab(train+dev+test, os.path.join(config['new_data_dir'], d_name)) 408 | save_data_to_json(train, dev, test, os.path.join(config['new_data_dir'], d_name), 409 | dep_vocab, pos_vocab) 410 | 411 | -------------------------------------------------------------------------------- /SynFue/sampling.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from SynFue import util 6 | 7 | 8 | def create_train_sample(doc, neg_term_count: int, neg_rel_count: int, max_span_size: int, rel_type_count: int): 9 | encodings = doc.encoding 10 | token_count = len(doc.tokens) # the number of tokens in a document, NOT INCLUDING [CLS] AND [SEP] 11 | context_size = len( 12 | encodings) # the number of sub-tokens after the WordPiece tokenizer, INCLUDING [CLS] AND [SEP] 13 | 14 | pieces2word = torch.zeros((token_count, context_size), dtype=torch.bool) 15 | start = 0 16 | for i, token in enumerate(doc.tokens): 17 | pieces = list(range(token.sub_token_start, 
token.sub_token_end)) 18 | pieces2word[i, pieces[0]: pieces[-1] + 1] = 1 19 | start += len(pieces) 20 | 21 | # positive terms 22 | pos_term_spans, pos_term_types, pos_term_masks, pos_term_sizes = [], [], [], [] 23 | for e in doc.terms: 24 | pos_term_spans.append(e.span) 25 | pos_term_types.append(e.term_type.index) 26 | pos_term_masks.append(create_term_mask(*e.span, token_count)) 27 | pos_term_sizes.append(len(e.tokens)) 28 | 29 | # positive relations 30 | pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], [] 31 | pos_rels3, pos_rel_masks3, pos_rel_spans3 = [], [], [] 32 | pos_pair_mask = [] # marks which positive pair has at least one related third term 33 | for rel in doc.relations: 34 | s1, s2 = rel.head_term.span, rel.tail_term.span 35 | pos_rels.append((pos_term_spans.index(s1), pos_term_spans.index(s2))) 36 | pos_rel_spans.append((s1, s2)) 37 | pos_rel_types.append(rel.relation_type) 38 | pos_rel_masks.append(create_rel_mask(s1, s2, token_count)) 39 | 40 | def is_in_relation(head, tail, relations): 41 | for _rel in relations: 42 | _s1, _s2 = _rel.head_term.span, _rel.tail_term.span 43 | if (_s1 == head and _s2 == tail) or (_s1 == tail and _s2 == head): 44 | return 1 45 | return 0 46 | 47 | for x in range(len(doc.relations)): 48 | s1, s2 = doc.relations[x].head_term, doc.relations[x].tail_term 49 | x1, x2 = pos_term_spans.index(s1.span), pos_term_spans.index(s2.span) 50 | t_p_rels3 = [] 51 | t_p_rels3_mask = [] 52 | t_p_rel_span3 = [] 53 | for idx, _term_span in enumerate(pos_term_spans): 54 | if idx != x1 and idx != x2: 55 | if is_in_relation(s1.span, _term_span, doc.relations) or is_in_relation(s2.span, _term_span, doc.relations):  # pass span tuples; a Term never compares equal to a tuple 56 | t_p_rels3.append((x1, x2, idx)) 57 | t_p_rels3_mask.append(create_rel_mask3(s1.span, s2.span, _term_span, token_count)) 58 | t_p_rel_span3.append((s1.span, s2.span, _term_span)) 59 | # t_p_rel_types3.append(1) 60 | if len(t_p_rels3) > 0: 61 | pos_rels3.append(t_p_rels3) 62 | pos_rel_masks3.append(t_p_rels3_mask) 63 | pos_pair_mask.append(1) 64 | pos_rel_spans3.append(t_p_rel_span3) 65 | # pos_rel_types3.append(t_p_rel_types3) 66 | else: 67 | pos_rels3.append([(x1, x2, 0)]) 68 | pos_rel_masks3.append([(create_rel_mask3(s1.span, s2.span, (0, 0), token_count))]) 69 | pos_pair_mask.append(0) 70 | 71 | assert len(pos_rels) == len(pos_rels3) == len(pos_pair_mask) 72 | 73 | # negative terms 74 | neg_term_spans, neg_term_sizes = [], [] 75 | for size in range(1, max_span_size + 1): 76 | for i in range(0, (token_count - size) + 1): 77 | span = doc.tokens[i:i + size].span 78 | if span not in pos_term_spans: 79 | neg_term_spans.append(span) 80 | neg_term_sizes.append(size) 81 | 82 | # sample negative terms 83 | neg_term_samples = random.sample(list(zip(neg_term_spans, neg_term_sizes)), 84 | min(len(neg_term_spans), neg_term_count)) 85 | neg_term_spans, neg_term_sizes = zip(*neg_term_samples) if neg_term_samples else ([], []) 86 | 87 | neg_term_masks = [create_term_mask(*span, token_count) for span in neg_term_spans] 88 | neg_term_types = [0] * len(neg_term_spans) 89 | 90 | # negative relations 91 | # use only strong negative relations, i.e. pairs of actual (labeled) terms that are not related 92 | # neg_rels3 = [] 93 | neg_rel_spans = [] 94 | neg_rel_spans3 = [] 95 | neg_pair_mask = [] 96 | 97 | for i1, s1 in enumerate(pos_term_spans): 98 | for i2, s2 in enumerate(pos_term_spans): 99 | rev = (s2, s1) 100 | rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric 101 | 102 | # do not add as negative relation sample: 103 | # neg. 
relations from a term to itself 104 | # term pairs that are related according to the ground truth 105 | # term pairs whose reverse exists as a symmetric relation in the ground truth 106 | if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric: 107 | neg_rel_spans.append((s1, s2)) 108 | 109 | p_rel_span3 = [] 110 | for i3, s3 in enumerate(pos_term_spans): 111 | # the three spans must be pairwise distinct and not already in pos_rel_spans3 112 | if s1 != s2 and s1 != s3 and s2 != s3 and (s1, s2, s3) not in pos_rel_spans3: 113 | p_rel_span3.append((s1, s2, s3)) 114 | if len(p_rel_span3) > 0: 115 | neg_rel_spans3.append(p_rel_span3) 116 | neg_pair_mask.append(1) 117 | else: 118 | neg_rel_spans3.append([(s1, s2, (0, 0))]) 119 | neg_pair_mask.append(0) 120 | 121 | # sample negative relations 122 | 123 | assert len(neg_rel_spans) == len(neg_rel_spans3) == len(neg_pair_mask) 124 | 125 | neg_rel_spans_samples = random.sample(list(zip(neg_rel_spans, neg_rel_spans3, neg_pair_mask)), 126 | min(len(neg_rel_spans), neg_rel_count)) 127 | neg_rel_spans, neg_rel_spans3, neg_pair_mask = zip(*neg_rel_spans_samples) if neg_rel_spans_samples else ( 128 | [], [], []) 129 | 130 | neg_rels = [(pos_term_spans.index(s1), pos_term_spans.index(s2)) for s1, s2 in neg_rel_spans] 131 | neg_rels3 = [[(pos_term_spans.index(s1), pos_term_spans.index(s2), pos_term_spans.index(s3) if s3 in pos_term_spans else 0) for s1, s2, s3 in x]  # the (0, 0) placeholder has no term index; map it to 0 132 | for x in neg_rel_spans3] 133 | 134 | assert len(neg_rels3) == len(neg_rel_spans3) == len(neg_pair_mask) 135 | 136 | neg_rel_masks = [create_rel_mask(*spans, token_count) for spans in neg_rel_spans] 137 | neg_rel_masks3 = [[create_rel_mask3(*sps, token_count) for sps in spans] for spans in neg_rel_spans3] 138 | neg_rel_types = [0] * len(neg_rel_spans) 139 | # neg_rel_types3 = [0] * len(neg_rel_spans3) 140 | 141 | # merge 142 | term_types = pos_term_types + neg_term_types 143 | term_masks = pos_term_masks + neg_term_masks 144 | term_sizes = pos_term_sizes + list(neg_term_sizes) 145 | term_spans = pos_term_spans + list(neg_term_spans) 146 | 147 | rels = pos_rels + neg_rels 148 | rel_types = [r.index for r in pos_rel_types] + neg_rel_types 149 | rel_masks = pos_rel_masks + neg_rel_masks 150 | 151 | rels3 = pos_rels3 + neg_rels3 152 | # rel_types3 = pos_rel_types3 + neg_rel_types3 153 | rel_masks3 = pos_rel_masks3 + neg_rel_masks3 154 | pair_mask = pos_pair_mask + list(neg_pair_mask) 155 | 156 | assert len(term_masks) == len(term_sizes) == len(term_types) 157 | try: 158 | assert len(rels) == len(rel_masks) == len(rel_types) == len(rels3) == len(pair_mask) 159 | except: 160 | print(len(rels)) 161 | print(len(rels3)) 162 | print(len(pair_mask)) 163 | 164 | encodings = torch.tensor(encodings, dtype=torch.long) 165 | # masking of tokens 166 | context_masks = torch.ones(context_size, dtype=torch.bool) 167 | 168 | # also create samples_masks: 169 | # tensors to mask term/relation samples of batch 170 | # since samples are stacked into batches, "padding" terms/relations possibly must be created 171 | # these are later masked during loss computation 172 | if term_masks: 173 | term_types = torch.tensor(term_types, dtype=torch.long) 174 | term_masks = torch.stack(term_masks) 175 | term_sizes = torch.tensor(term_sizes, dtype=torch.long) 176 | term_sample_masks = torch.ones([term_masks.shape[0]], dtype=torch.bool) 177 | term_spans = torch.tensor(term_spans, dtype=torch.long) 178 | else: 179 | # corner case handling (no pos/neg terms) 180 | term_types = torch.zeros([1], dtype=torch.long) 181 | term_masks = torch.zeros([1, token_count], dtype=torch.bool) 
182 | term_sizes = torch.zeros([1], dtype=torch.long) 183 | term_sample_masks = torch.zeros([1], dtype=torch.bool) 184 | term_spans = torch.zeros([1, 2], dtype=torch.long)  # dummy (1, 2)-shaped span tensor, matching the eval corner case 185 | 186 | if rels: 187 | rels = torch.tensor(rels, dtype=torch.long) 188 | rel_masks = torch.stack(rel_masks) 189 | rel_types = torch.tensor(rel_types, dtype=torch.long) 190 | rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool) 191 | else: 192 | # corner case handling (no pos/neg relations) 193 | rels = torch.zeros([1, 2], dtype=torch.long) 194 | rel_types = torch.zeros([1], dtype=torch.long) 195 | rel_masks = torch.zeros([1, token_count], dtype=torch.bool) 196 | rel_sample_masks = torch.zeros([1], dtype=torch.bool) 197 | 198 | if rels3: 199 | max_tri = max([len(x) for x in rels3]) 200 | for idx, r in enumerate(rels3): 201 | r_len = len(r) 202 | if r_len < max_tri: 203 | rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len)) 204 | rel_masks3[idx].extend([rel_masks3[idx][0]] * (max_tri - r_len)) 205 | rels3 = torch.tensor(rels3, dtype=torch.long) 206 | try: 207 | rel_masks3 = torch.stack([torch.stack(x) for x in rel_masks3]) 208 | except: 209 | print(rel_masks3) 210 | rel_sample_masks3 = torch.ones([rels3.shape[0]], dtype=torch.bool) 211 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool) 212 | else: 213 | rels3 = torch.zeros([1, 3], dtype=torch.long) 214 | rel_masks3 = torch.zeros([1, token_count], dtype=torch.bool) 215 | rel_sample_masks3 = torch.zeros([1], dtype=torch.bool) 216 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool) 217 | 218 | # relation types to one-hot encoding 219 | rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32) 220 | rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1) 221 | rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation 222 | 223 | simple_graph = None 224 | graph = None 225 | try: 226 | simple_graph = torch.tensor(get_simple_graph(token_count, doc.dep), dtype=torch.long) # only the relation 227 | except: 228 | print(context_size) 229 | print(token_count) 230 | print(encodings) 231 | print(doc.dep) 232 | print(doc.dep_label_indices) 233 | try: 234 | graph = torch.tensor(get_graph(token_count, doc.dep, doc.dep_label_indices), 235 | dtype=torch.long) # relation and the type of relation 236 | except: 237 | print(context_size) 238 | print(token_count) 239 | print(encodings) 240 | print(doc.dep) 241 | print(doc.dep_label_indices) 242 | 243 | pos = torch.tensor(get_pos(token_count, doc.pos_indices), dtype=torch.long) 244 | 245 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks, 246 | term_sizes=term_sizes, term_types=term_types, term_spans=term_spans, 247 | rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot, 248 | rels3=rels3, rel_sample_masks3=rel_sample_masks3, rel_masks3=rel_masks3, 249 | pair_mask=pair_mask, 250 | term_sample_masks=term_sample_masks, rel_sample_masks=rel_sample_masks, 251 | simple_graph=simple_graph, graph=graph, pos=pos, pieces2word=pieces2word) 252 | 253 | 254 | def create_eval_sample(doc, max_span_size: int): 255 | encodings = doc.encoding 256 | token_count = len(doc.tokens) 257 | context_size = len(encodings) 258 | 259 | pieces2word = torch.zeros((token_count, context_size), dtype=torch.bool) 260 | start = 0 261 | for i, token in enumerate(doc.tokens): 262 | pieces = list(range(token.sub_token_start, token.sub_token_end)) 263 | pieces2word[i, pieces[0]: pieces[-1] + 1] = 1 264 | start += len(pieces) 265 | 266 | # create term candidates 267 | 
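# e.g. with token_count = 4 and max_span_size = 2 this enumerates every contiguous candidate
# span, shorter spans first: (0, 1), (1, 2), (2, 3), (3, 4), then (0, 2), (1, 3), (2, 4)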
term_spans = [] 268 | term_masks = [] 269 | term_sizes = [] 270 | 271 | for size in range(1, max_span_size + 1): 272 | for i in range(0, (token_count - size) + 1): 273 | span = doc.tokens[i:i + size].span 274 | term_spans.append(span) 275 | term_masks.append(create_term_mask(*span, token_count)) 276 | term_sizes.append(size) 277 | 278 | # create tensors 279 | # token indices 280 | # _encoding = encodings 281 | # encodings = torch.zeros(context_size, dtype=torch.long) 282 | # encodings[:len(_encoding)] = torch.tensor(_encoding, dtype=torch.long) 283 | encodings = torch.tensor(encodings, dtype=torch.long) 284 | 285 | # masking of tokens 286 | context_masks = torch.ones(context_size, dtype=torch.bool) 287 | # context_masks[:len(_encoding)] = 1 288 | 289 | # terms 290 | if term_masks: 291 | term_masks = torch.stack(term_masks) 292 | term_sizes = torch.tensor(term_sizes, dtype=torch.long) 293 | term_spans = torch.tensor(term_spans, dtype=torch.long) 294 | 295 | # tensors to mask term samples of batch 296 | # since samples are stacked into batches, "padding" terms possibly must be created 297 | # these are later masked during evaluation 298 | term_sample_masks = torch.tensor([1] * term_masks.shape[0], dtype=torch.bool) 299 | else: 300 | # corner case handling (no terms) 301 | term_masks = torch.zeros([1, token_count], dtype=torch.bool) 302 | term_sizes = torch.zeros([1], dtype=torch.long) 303 | term_spans = torch.zeros([1, 2], dtype=torch.long) 304 | term_sample_masks = torch.zeros([1], dtype=torch.bool) 305 | 306 | simple_graph = torch.tensor(get_simple_graph(token_count, doc.dep), dtype=torch.long) # only the relation 307 | graph = torch.tensor(get_graph(token_count, doc.dep, doc.dep_label_indices), 308 | dtype=torch.long) # relation and the type of relation 309 | pos = torch.tensor(get_pos(token_count, doc.pos_indices), dtype=torch.long) 310 | 311 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks, 312 | term_sizes=term_sizes, term_spans=term_spans, term_sample_masks=term_sample_masks, 313 | simple_graph=simple_graph, graph=graph, pos=pos, pieces2word=pieces2word) 314 | 315 | 316 | def create_term_mask(start, end, context_size): 317 | mask = torch.zeros(context_size, dtype=torch.bool) 318 | mask[start:end] = 1 319 | return mask 320 | 321 | 322 | def create_rel_mask(s1, s2, context_size): 323 | start = s1[1] if s1[1] < s2[0] else s2[1] 324 | end = s2[0] if s1[1] < s2[0] else s1[0] 325 | mask = create_term_mask(start, end, context_size) 326 | return mask 327 | 328 | 329 | def create_rel_mask3(s1, s2, s3, context_size): 330 | mask = torch.zeros(context_size, dtype=torch.bool) 331 | start = min(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1]) 332 | end = max(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1]) 333 | mask[start:end] = 1 334 | return mask 335 | 336 | 337 | def collate_fn_padding(batch): 338 | padded_batch = dict() 339 | keys = batch[0].keys() 340 | 341 | for key in keys: 342 | samples = [s[key] for s in batch] 343 | if not batch[0][key].shape: 344 | padded_batch[key] = torch.stack(samples) 345 | else: 346 | padded_batch[key] = util.padded_stack([s[key] for s in batch]) 347 | 348 | return padded_batch 349 | 350 | 351 | def get_graph(seq_len, feature_data, feature2id): 352 | """ 353 | To create a table t_{i,j} in T. t_{i,j} = r, r is the dependency relation label between word i and word j. 354 | :param seq_len: token 355 | :param feature_data: dependency head. 
359 |     assert len(feature2id) == len(feature_data) == seq_len
360 |     ret = [[0] * seq_len for _ in range(seq_len)]
361 |     for i, item in enumerate(feature_data):
362 |         # the head word is ROOT, so this token only has a self-loop edge
363 |         if int(item) == 0:
364 |             ret[i][i] = 1
365 |             continue
366 |         ret[i][int(item) - 1] = feature2id[i]
367 |         ret[int(item) - 1][i] = feature2id[i]
368 |         ret[i][i] = 1
369 |     return ret
370 | 
371 | 
372 | def get_simple_graph(seq_len, feature_data):
373 |     """
374 |     Create a table T where t_{i,j} = 1 if there is a dependency edge between word i and word j.
375 |     :param seq_len: number of tokens in the sentence
376 |     :param feature_data: dependency head index of each token; '0' means the token's head is ROOT
377 |     :return:
378 |     """
379 |     assert len(feature_data) == seq_len
380 |     ret = [[0] * seq_len for _ in range(seq_len)]
381 |     for i, item in enumerate(feature_data):
382 |         if int(item) == 0:
383 |             ret[i][i] = 1
384 |             continue
385 |         ret[i][int(item) - 1] = 1
386 |         ret[int(item) - 1][i] = 1
387 |         ret[i][i] = 1
388 |     return ret
389 | 
390 | 
391 | def get_pos(seq_len, pos_indices):
392 |     assert len(pos_indices) == seq_len
393 |     ret = [0] * seq_len
394 |     for i, item in enumerate(pos_indices):
395 |         ret[i] = item
396 |     return ret
397 | 
--------------------------------------------------------------------------------
/SynFue/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | 
5 | import torch
6 | from torch.nn import DataParallel
7 | from torch.optim import Optimizer
8 | import transformers
9 | from torch.utils.data import DataLoader
10 | from transformers import AdamW, BertConfig
11 | from transformers import BertTokenizer
12 | 
13 | from SynFue import models
14 | from SynFue import sampling
15 | from SynFue import util
16 | from SynFue.terms import Dataset
17 | from SynFue.evaluator import Evaluator
18 | from SynFue.input_reader import JsonInputReader, BaseInputReader
19 | from SynFue.loss import SynFueLoss, Loss
20 | from tqdm import tqdm
21 | from SynFue.base_trainer import BaseTrainer
22 | 
23 | from torchsummary import summary
24 | 
25 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
26 | 
27 | 
28 | class SynFueTrainer(BaseTrainer):
29 |     """ Joint term and relation extraction training and evaluation """
30 | 
31 |     def __init__(self, args: argparse.Namespace):
32 |         super().__init__(args)
33 | 
34 |         # byte-pair encoding
35 |         self._tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
36 |                                                         do_lower_case=args.lowercase,
37 |                                                         cache_dir=args.cache_path)
38 | 
39 |         # path to export predictions to
40 |         self._predictions_path = os.path.join(self._log_path, 'predictions_%s_epoch_%s.json')
41 | 
42 |         # path to export relation extraction examples to
43 |         self._examples_path = os.path.join(self._log_path, 'examples_%s_%s_epoch_%s.html')
44 | 
45 |     def train(self, train_path: str, valid_path: str, types_path: str, input_reader_cls: BaseInputReader):
46 |         args = self.args
47 |         train_label, valid_label = 'train', 'valid'
48 | 
49 |         self._logger.info("Datasets: %s, %s" % (train_path, valid_path))
50 |         self._logger.info("Model type: %s" % args.model_type)
51 | 
52 |         # create log csv files
53 |         self._init_train_logging(train_label)
54 |         self._init_eval_logging(valid_label)
55 | 
56 |         # read datasets
57 |         input_reader = input_reader_cls(types_path, self._tokenizer,
args.neg_term_count, 58 | args.neg_relation_count, args.max_span_size, self._logger) 59 | input_reader.read({train_label: train_path, valid_label: valid_path}) 60 | self._log_datasets(input_reader) 61 | 62 | train_dataset = input_reader.get_dataset(train_label) 63 | train_sample_count = train_dataset.document_count 64 | updates_epoch = train_sample_count // args.train_batch_size 65 | updates_total = updates_epoch * args.epochs 66 | 67 | validation_dataset = input_reader.get_dataset(valid_label) 68 | 69 | self._logger.info("Updates per epoch: %s" % updates_epoch) 70 | self._logger.info("Updates total: %s" % updates_total) 71 | 72 | # create model 73 | model_class = models.get_model(self.args.model_type) 74 | 75 | # print('Training:') 76 | # print('model_path: ', self.args.model_path) 77 | # print('cache_path: ', self.args.cache_path) 78 | 79 | # load model 80 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path) 81 | util.check_version(config, model_class, self.args.model_path) 82 | 83 | config.model_version = model_class.VERSION 84 | model = model_class.from_pretrained(self.args.model_path, 85 | config=config, 86 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'), 87 | relation_types=input_reader.relation_type_count - 1, 88 | term_types=input_reader.term_type_count, 89 | max_pairs=self.args.max_pairs, 90 | prop_drop=self.args.prop_drop, 91 | size_embedding=self.args.size_embedding, 92 | freeze_transformer=self.args.freeze_transformer, 93 | args=self.args, 94 | beta=self.args.beta, 95 | alpha=self.args.alpha, 96 | sigma=self.args.sigma) 97 | 98 | model.to(self._device) 99 | # summary(model) 100 | # create optimizer 101 | optimizer_params = self._get_optimizer_params(model) 102 | optimizer = AdamW(optimizer_params, lr=args.lr, weight_decay=args.weight_decay, correct_bias=False) 103 | # create scheduler 104 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 105 | num_warmup_steps=args.lr_warmup * updates_total, 106 | num_training_steps=updates_total) 107 | # create loss function 108 | rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 109 | term_criterion = torch.nn.CrossEntropyLoss(reduction='none') 110 | compute_loss = SynFueLoss(rel_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm) 111 | 112 | # eval validation set 113 | if args.init_eval: 114 | self._eval(model, validation_dataset, input_reader, 0, updates_epoch) 115 | 116 | # train 117 | best_f1 = 0.0 118 | for epoch in range(args.epochs): 119 | # train epoch 120 | self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch) 121 | # rel_nec_eval = self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch) 122 | # eval validation sets 123 | if not args.final_eval or (epoch == args.epochs - 1): 124 | rel_nec_eval = self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch) 125 | if best_f1 < rel_nec_eval[-1]: 126 | # save final model 127 | best_f1 = rel_nec_eval[-1] 128 | extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0) 129 | global_iteration = args.epochs * updates_epoch 130 | self._save_model(self._save_path, model, self._tokenizer, global_iteration, 131 | optimizer=optimizer if self.args.save_optimizer else None, save_as_best=True, 132 | extra=extra, include_iteration=False) 133 | 134 | self._logger.info("Logged in: %s" % self._log_path) 135 | self._logger.info("Saved in: %s" % self._save_path) 136 | self._close_summary_writer() 137 | 138 | def 
eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader): 139 | args = self.args 140 | dataset_label = 'test' 141 | 142 | self._logger.info("Dataset: %s" % dataset_path) 143 | self._logger.info("Model: %s" % args.model_type) 144 | 145 | # create log csv files 146 | self._init_eval_logging(dataset_label) 147 | 148 | # read datasets 149 | input_reader = input_reader_cls(types_path, self._tokenizer, 150 | max_span_size=args.max_span_size, logger=self._logger) 151 | input_reader.read({dataset_label: dataset_path}) 152 | self._log_datasets(input_reader) 153 | 154 | # create model 155 | model_class = models.get_model(self.args.model_type) 156 | 157 | # print('Eval:') 158 | # print('model_path: ', self.args.model_path) 159 | # print('cache_path: ', self.args.cache_path) 160 | 161 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path) 162 | 163 | util.check_version(config, model_class, self.args.model_path) 164 | 165 | model = model_class.from_pretrained(self.args.model_path, 166 | config=config, 167 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'), 168 | relation_types=input_reader.relation_type_count - 1, 169 | term_types=input_reader.term_type_count, 170 | max_pairs=self.args.max_pairs, 171 | prop_drop=self.args.prop_drop, 172 | size_embedding=self.args.size_embedding, 173 | freeze_transformer=self.args.freeze_transformer, 174 | args=self.args, 175 | beta=self.args.beta, 176 | alpha=self.args.alpha, 177 | sigma=self.args.sigma) 178 | # model = torch.load(os.path.join(self.args.model_path, 'pytorch_model.bin')) 179 | 180 | model.to(self._device) 181 | # summary(model) 182 | # evaluate 183 | self._eval(model, input_reader.get_dataset(dataset_label), input_reader) 184 | 185 | self._logger.info("Logged in: %s" % self._log_path) 186 | self._close_summary_writer() 187 | 188 | def _train_epoch(self, model: torch.nn.Module, compute_loss: Loss, optimizer: Optimizer, dataset: Dataset, 189 | updates_epoch: int, epoch: int): 190 | self._logger.info("Train epoch: %s" % epoch) 191 | 192 | # create data loader 193 | dataset.switch_mode(Dataset.TRAIN_MODE) 194 | data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, drop_last=True, 195 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding) 196 | 197 | model.zero_grad() 198 | 199 | iteration = 0 200 | total = dataset.document_count // self.args.train_batch_size 201 | for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch): 202 | model.train() 203 | batch = util.to_device(batch, self._device) 204 | 205 | # forward step 206 | term_logits, rel_logits = model(encodings=batch['encodings'], context_masks=batch['context_masks'], 207 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'], 208 | term_spans=batch['term_spans'], term_types=batch['term_types'], 209 | relations=batch['rels'], rel_masks=batch['rel_masks'], 210 | simple_graph=batch['simple_graph'], graph=batch['graph'], 211 | relations3=batch['rels3'], rel_masks3=batch['rel_masks3'], 212 | pair_mask=batch['pair_mask'], pos=batch['pos'], 213 | pieces2word=batch['pieces2word']) 214 | 215 | # compute loss and optimize parameters 216 | batch_loss = compute_loss.compute(term_logits=term_logits, rel_logits=rel_logits, 217 | rel_types=batch['rel_types'], term_types=batch['term_types'], 218 | term_sample_masks=batch['term_sample_masks'], 219 | rel_sample_masks=batch['rel_sample_masks']) 220 | 221 | # logging 222 | iteration += 1 223 | global_iteration = 
epoch * updates_epoch + iteration 224 | 225 | if global_iteration % self.args.train_log_iter == 0: 226 | self._log_train(optimizer, batch_loss, epoch, iteration, global_iteration, dataset.label) 227 | 228 | return iteration 229 | 230 | def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader, 231 | epoch: int = 0, updates_epoch: int = 0, iteration: int = 0): 232 | self._logger.info("Evaluate: %s" % dataset.label) 233 | 234 | if isinstance(model, DataParallel): 235 | # currently no multi GPU support during evaluation 236 | model = model.module 237 | 238 | # create evaluator 239 | evaluator = Evaluator(dataset, input_reader, self._tokenizer, 240 | self.args.rel_filter_threshold, self.args.no_overlapping, self._predictions_path, 241 | self._examples_path, self.args.example_count, epoch, dataset.label) 242 | 243 | # create data loader 244 | dataset.switch_mode(Dataset.EVAL_MODE) 245 | data_loader = DataLoader(dataset, batch_size=self.args.eval_batch_size, shuffle=False, drop_last=False, 246 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding) 247 | 248 | with torch.no_grad(): 249 | model.eval() 250 | 251 | # iterate batches 252 | total = math.ceil(dataset.document_count / self.args.eval_batch_size) 253 | for batch in tqdm(data_loader, total=total, desc='Evaluate epoch %s' % epoch): 254 | # move batch to selected device 255 | batch = util.to_device(batch, self._device) 256 | 257 | # run model (forward pass) 258 | result = model(encodings=batch['encodings'], context_masks=batch['context_masks'], 259 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'], 260 | term_spans=batch['term_spans'], term_sample_masks=batch['term_sample_masks'], 261 | evaluate=True, simple_graph=batch['simple_graph'], graph=batch['graph'], 262 | pos=batch['pos'], pieces2word=batch['pieces2word']) # pos=batch['pos'] 263 | term_clf, rel_clf, rels = result 264 | 265 | # evaluate batch 266 | evaluator.eval_batch(term_clf, rel_clf, rels, batch) 267 | 268 | global_iteration = epoch * updates_epoch + iteration 269 | ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores() 270 | self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval, 271 | epoch, iteration, global_iteration, dataset.label) 272 | 273 | if self.args.store_predictions and not self.args.no_overlapping: 274 | evaluator.store_predictions() 275 | 276 | if self.args.store_examples: 277 | evaluator.store_examples() 278 | return rel_nec_eval 279 | 280 | def _get_optimizer_params(self, model): 281 | param_optimizer = list(model.named_parameters()) 282 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 283 | optimizer_params = [ 284 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 285 | 'weight_decay': self.args.weight_decay}, 286 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}] 287 | 288 | return optimizer_params 289 | 290 | def _log_train(self, optimizer: Optimizer, loss: float, epoch: int, 291 | iteration: int, global_iteration: int, label: str): 292 | # average loss 293 | avg_loss = loss / self.args.train_batch_size 294 | # get current learning rate 295 | lr = self._get_lr(optimizer)[0] 296 | 297 | # log to tensorboard 298 | self._log_tensorboard(label, 'loss', loss, global_iteration) 299 | self._log_tensorboard(label, 'loss_avg', avg_loss, global_iteration) 300 | self._log_tensorboard(label, 'lr', lr, global_iteration) 301 | 302 | # log to csv 303 | self._log_csv(label, 'loss', loss, epoch, 
iteration, global_iteration) 304 | self._log_csv(label, 'loss_avg', avg_loss, epoch, iteration, global_iteration) 305 | self._log_csv(label, 'lr', lr, epoch, iteration, global_iteration) 306 | 307 | def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: float, 308 | ner_prec_macro: float, ner_rec_macro: float, ner_f1_macro: float, 309 | 310 | rel_prec_micro: float, rel_rec_micro: float, rel_f1_micro: float, 311 | rel_prec_macro: float, rel_rec_macro: float, rel_f1_macro: float, 312 | 313 | rel_nec_prec_micro: float, rel_nec_rec_micro: float, rel_nec_f1_micro: float, 314 | rel_nec_prec_macro: float, rel_nec_rec_macro: float, rel_nec_f1_macro: float, 315 | epoch: int, iteration: int, global_iteration: int, label: str): 316 | 317 | # log to tensorboard 318 | self._log_tensorboard(label, 'eval/ner_prec_micro', ner_prec_micro, global_iteration) 319 | self._log_tensorboard(label, 'eval/ner_recall_micro', ner_rec_micro, global_iteration) 320 | self._log_tensorboard(label, 'eval/ner_f1_micro', ner_f1_micro, global_iteration) 321 | self._log_tensorboard(label, 'eval/ner_prec_macro', ner_prec_macro, global_iteration) 322 | self._log_tensorboard(label, 'eval/ner_recall_macro', ner_rec_macro, global_iteration) 323 | self._log_tensorboard(label, 'eval/ner_f1_macro', ner_f1_macro, global_iteration) 324 | 325 | self._log_tensorboard(label, 'eval/rel_prec_micro', rel_prec_micro, global_iteration) 326 | self._log_tensorboard(label, 'eval/rel_recall_micro', rel_rec_micro, global_iteration) 327 | self._log_tensorboard(label, 'eval/rel_f1_micro', rel_f1_micro, global_iteration) 328 | self._log_tensorboard(label, 'eval/rel_prec_macro', rel_prec_macro, global_iteration) 329 | self._log_tensorboard(label, 'eval/rel_recall_macro', rel_rec_macro, global_iteration) 330 | self._log_tensorboard(label, 'eval/rel_f1_macro', rel_f1_macro, global_iteration) 331 | 332 | self._log_tensorboard(label, 'eval/rel_nec_prec_micro', rel_nec_prec_micro, global_iteration) 333 | self._log_tensorboard(label, 'eval/rel_nec_recall_micro', rel_nec_rec_micro, global_iteration) 334 | self._log_tensorboard(label, 'eval/rel_nec_f1_micro', rel_nec_f1_micro, global_iteration) 335 | self._log_tensorboard(label, 'eval/rel_nec_prec_macro', rel_nec_prec_macro, global_iteration) 336 | self._log_tensorboard(label, 'eval/rel_nec_recall_macro', rel_nec_rec_macro, global_iteration) 337 | self._log_tensorboard(label, 'eval/rel_nec_f1_macro', rel_nec_f1_macro, global_iteration) 338 | 339 | # log to csv 340 | self._log_csv(label, 'eval', ner_prec_micro, ner_rec_micro, ner_f1_micro, 341 | ner_prec_macro, ner_rec_macro, ner_f1_macro, 342 | 343 | rel_prec_micro, rel_rec_micro, rel_f1_micro, 344 | rel_prec_macro, rel_rec_macro, rel_f1_macro, 345 | 346 | rel_nec_prec_micro, rel_nec_rec_micro, rel_nec_f1_micro, 347 | rel_nec_prec_macro, rel_nec_rec_macro, rel_nec_f1_macro, 348 | epoch, iteration, global_iteration) 349 | 350 | def _log_datasets(self, input_reader): 351 | self._logger.info("Relation type count: %s" % input_reader.relation_type_count) 352 | self._logger.info("Term type count: %s" % input_reader.term_type_count) 353 | 354 | self._logger.info("Terms:") 355 | for e in input_reader.term_types.values(): 356 | self._logger.info(e.verbose_name + '=' + str(e.index)) 357 | 358 | self._logger.info("Relations:") 359 | for r in input_reader.relation_types.values(): 360 | self._logger.info(r.verbose_name + '=' + str(r.index)) 361 | 362 | for k, d in input_reader.datasets.items(): 363 | self._logger.info('Dataset: %s' % k) 364 | 
self._logger.info("Document count: %s" % d.document_count) 365 | self._logger.info("Relation count: %s" % d.relation_count) 366 | self._logger.info("Term count: %s" % d.term_count) 367 | 368 | self._logger.info("Context size: %s" % input_reader.context_size) 369 | 370 | def _init_train_logging(self, label): 371 | self._add_dataset_logging(label, 372 | data={'lr': ['lr', 'epoch', 'iteration', 'global_iteration'], 373 | 'loss': ['loss', 'epoch', 'iteration', 'global_iteration'], 374 | 'loss_avg': ['loss_avg', 'epoch', 'iteration', 'global_iteration']}) 375 | 376 | def _init_eval_logging(self, label): 377 | self._add_dataset_logging(label, 378 | data={'eval': ['ner_prec_micro', 'ner_rec_micro', 'ner_f1_micro', 379 | 'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro', 380 | 'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro', 381 | 'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro', 382 | 'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro', 383 | 'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro', 384 | 'epoch', 'iteration', 'global_iteration']}) 385 | -------------------------------------------------------------------------------- /SynFue/evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | from typing import List, Tuple, Dict 5 | 6 | import torch 7 | from sklearn.metrics import precision_recall_fscore_support as prfs 8 | from transformers import BertTokenizer 9 | 10 | from SynFue import util 11 | from SynFue.terms import Document, Dataset, TermType 12 | from SynFue.input_reader import JsonInputReader 13 | from SynFue.opt import jinja2 14 | 15 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, dataset: Dataset, input_reader: JsonInputReader, text_encoder: BertTokenizer, 20 | rel_filter_threshold: float, no_overlapping: bool, 21 | predictions_path: str, examples_path: str, example_count: int, epoch: int, dataset_label: str): 22 | self._text_encoder = text_encoder 23 | self._input_reader = input_reader 24 | self._dataset = dataset 25 | self._rel_filter_threshold = rel_filter_threshold 26 | self._no_overlapping = no_overlapping 27 | 28 | self._epoch = epoch 29 | self._dataset_label = dataset_label 30 | 31 | self._predictions_path = predictions_path 32 | 33 | self._examples_path = examples_path 34 | self._example_count = example_count 35 | 36 | # relations 37 | self._gt_relations = [] # ground truth 38 | self._pred_relations = [] # prediction 39 | 40 | # terms 41 | self._gt_terms = [] # ground truth 42 | self._pred_terms = [] # prediction 43 | 44 | self._pseudo_term_type = TermType('Term', 1, 'Term', 'Term') # for span only evaluation 45 | 46 | self._convert_gt(self._dataset.documents) 47 | 48 | def eval_batch(self, batch_term_clf: torch.tensor, batch_rel_clf: torch.tensor, 49 | batch_rels: torch.tensor, batch: dict): 50 | batch_size = batch_rel_clf.shape[0] 51 | rel_class_count = batch_rel_clf.shape[2] 52 | # get maximum activation (index of predicted term type) 53 | batch_term_types = batch_term_clf.argmax(dim=-1) 54 | # apply term sample mask 55 | batch_term_types *= batch['term_sample_masks'].long() 56 | 57 | batch_rel_clf = batch_rel_clf.view(batch_size, -1) 58 | 59 | # apply threshold to relations 60 | if self._rel_filter_threshold > 0: 61 | batch_rel_clf[batch_rel_clf < self._rel_filter_threshold] = 0 62 | 63 | for i in range(batch_size): 64 | # get model predictions for sample 65 | rel_clf = batch_rel_clf[i] 66 | 
67 | 
68 |             # get predicted relation labels and corresponding term pairs
69 |             rel_nonzero = rel_clf.nonzero().view(-1)
70 |             rel_scores = rel_clf[rel_nonzero]
71 | 
72 |             rel_types = (rel_nonzero % rel_class_count) + 1  # model does not predict None class (+1)
73 |             rel_indices = rel_nonzero // rel_class_count
74 | 
75 |             rels = batch_rels[i][rel_indices]
76 | 
77 |             # get masks of terms in relation
78 |             rel_term_spans = batch['term_spans'][i][rels].long()
79 | 
80 |             # get predicted term types
81 |             rel_term_types = torch.zeros([rels.shape[0], 2])
82 |             if rels.shape[0] != 0:
83 |                 rel_term_types = torch.stack([term_types[rels[j]] for j in range(rels.shape[0])])
84 | 
85 |             # convert predicted relations for evaluation
86 |             sample_pred_relations = self._convert_pred_relations(rel_types, rel_term_spans,
87 |                                                                  rel_term_types, rel_scores)
88 | 
89 |             # get terms that are not classified as 'None'
90 |             valid_term_indices = term_types.nonzero().view(-1)
91 |             valid_term_types = term_types[valid_term_indices]
92 |             valid_term_spans = batch['term_spans'][i][valid_term_indices]
93 |             valid_term_scores = torch.gather(batch_term_clf[i][valid_term_indices], 1,
94 |                                              valid_term_types.unsqueeze(1)).view(-1)
95 |             sample_pred_terms = self._convert_pred_terms(valid_term_types, valid_term_spans,
96 |                                                          valid_term_scores)
97 | 
98 |             if self._no_overlapping:
99 |                 sample_pred_terms, sample_pred_relations = self._remove_overlapping(sample_pred_terms,
100 |                                                                                     sample_pred_relations)
101 | 
102 |             self._pred_terms.append(sample_pred_terms)
103 |             self._pred_relations.append(sample_pred_relations)
104 | 
105 |     def compute_scores(self):
106 |         print("Evaluation")
107 | 
108 |         print("")
109 |         print("--- Terms (named term recognition (NTR)) ---")
110 |         print("A term is considered correct if both the term type and the span are predicted correctly")
111 |         print("")
112 |         gt, pred = self._convert_by_setting(self._gt_terms, self._pred_terms, include_term_types=True)
113 |         ner_eval = self._score(gt, pred, print_results=True)
114 | 
115 |         print("")
116 |         print("--- Relations ---")
117 |         print("")
118 |         print("Without named term classification (NTC)")
119 |         print("A relation is considered correct if the relation type and the spans of the two "
120 |               "related terms are predicted correctly (term type is not considered)")
121 |         print("")
122 |         gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=False)
123 |         rel_eval = self._score(gt, pred, print_results=True)
124 | 
125 |         print("")
126 |         print("With named term classification (NTC)")
127 |         print("A relation is considered correct if the relation type and the two "
128 |               "related terms are predicted correctly (in span and term type)")
129 |         print("")
130 |         gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=True)
131 |         rel_nec_eval = self._score(gt, pred, print_results=True)
132 | 
133 |         return ner_eval, rel_eval, rel_nec_eval
134 | 
135 |     def store_predictions(self):
136 |         predictions = []
137 | 
138 |         for i, doc in enumerate(self._dataset.documents):
139 |             tokens = doc.tokens
140 |             pred_terms = self._pred_terms[i]
141 |             pred_relations = self._pred_relations[i]
142 | 
143 |             # convert terms
144 |             converted_terms = []
145 |             for term in pred_terms:
146 |                 term_span = term[:2]
147 |                 span_tokens = util.get_span_tokens(tokens, term_span)
148 |                 term_type = term[2].identifier
149 |                 # if term_type == 'None':
150 |                 #     continue
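                # NOTE (illustrative sketch): each predicted term is serialized as a
                # dict of type and token offsets, e.g. {"type": "Asp", "start": 2,
                # "end": 4} for the 16res label set; relations below then reference
                # terms by their index in the sorted converted_terms list.
151 |                 converted_term = dict(type=term_type, start=span_tokens[0].index,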
end=span_tokens[-1].index + 1) 152 | converted_terms.append(converted_term) 153 | converted_terms = sorted(converted_terms, key=lambda e: e['start']) 154 | 155 | # print('converted_terms: ', converted_terms) 156 | 157 | # convert relations 158 | converted_relations = [] 159 | for relation in pred_relations: 160 | head, tail = relation[:2] 161 | head_span, head_type = head[:2], head[2].identifier 162 | tail_span, tail_type = tail[:2], tail[2].identifier 163 | head_span_tokens = util.get_span_tokens(tokens, head_span) 164 | tail_span_tokens = util.get_span_tokens(tokens, tail_span) 165 | relation_type = relation[2].identifier 166 | 167 | converted_head = dict(type=head_type, start=head_span_tokens[0].index, 168 | end=head_span_tokens[-1].index + 1) 169 | converted_tail = dict(type=tail_type, start=tail_span_tokens[0].index, 170 | end=tail_span_tokens[-1].index + 1) 171 | 172 | # print(converted_tail) 173 | head_idx = converted_terms.index(converted_head) 174 | tail_idx = converted_terms.index(converted_tail) 175 | 176 | converted_relation = dict(type=relation_type, head=head_idx, tail=tail_idx) 177 | converted_relations.append(converted_relation) 178 | converted_relations = sorted(converted_relations, key=lambda r: r['head']) 179 | 180 | doc_predictions = dict(tokens=[t.phrase for t in tokens], terms=converted_terms, 181 | relations=converted_relations) 182 | predictions.append(doc_predictions) 183 | 184 | # store as json 185 | label, epoch = self._dataset_label, self._epoch 186 | with open(self._predictions_path % (label, epoch), 'w') as predictions_file: 187 | json.dump(predictions, predictions_file) 188 | 189 | def store_examples(self): 190 | if jinja2 is None: 191 | warnings.warn("Examples cannot be stored since Jinja2 is not installed.") 192 | return 193 | 194 | term_examples = [] 195 | rel_examples = [] 196 | rel_examples_nec = [] 197 | 198 | for i, doc in enumerate(self._dataset.documents): 199 | # terms 200 | term_example = self._convert_example(doc, self._gt_terms[i], self._pred_terms[i], 201 | include_term_types=True, to_html=self._term_to_html) 202 | term_examples.append(term_example) 203 | 204 | # relations 205 | # without term types 206 | rel_example = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i], 207 | include_term_types=False, to_html=self._rel_to_html) 208 | rel_examples.append(rel_example) 209 | 210 | # with term types 211 | rel_example_nec = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i], 212 | include_term_types=True, to_html=self._rel_to_html) 213 | rel_examples_nec.append(rel_example_nec) 214 | 215 | label, epoch = self._dataset_label, self._epoch 216 | 217 | # terms 218 | self._store_examples(term_examples[:self._example_count], 219 | file_path=self._examples_path % ('terms', label, epoch), 220 | template='term_examples.html') 221 | 222 | self._store_examples(sorted(term_examples[:self._example_count], 223 | key=lambda k: k['length']), 224 | file_path=self._examples_path % ('terms_sorted', label, epoch), 225 | template='term_examples.html') 226 | 227 | # relations 228 | # without term types 229 | self._store_examples(rel_examples[:self._example_count], 230 | file_path=self._examples_path % ('rel', label, epoch), 231 | template='relation_examples.html') 232 | 233 | self._store_examples(sorted(rel_examples[:self._example_count], 234 | key=lambda k: k['length']), 235 | file_path=self._examples_path % ('rel_sorted', label, epoch), 236 | template='relation_examples.html') 237 | 238 | # with term types 239 | 
self._store_examples(rel_examples_nec[:self._example_count], 240 | file_path=self._examples_path % ('rel_nec', label, epoch), 241 | template='relation_examples.html') 242 | 243 | self._store_examples(sorted(rel_examples_nec[:self._example_count], 244 | key=lambda k: k['length']), 245 | file_path=self._examples_path % ('rel_nec_sorted', label, epoch), 246 | template='relation_examples.html') 247 | 248 | def _convert_gt(self, docs: List[Document]): 249 | for doc in docs: 250 | gt_relations = doc.relations 251 | gt_terms = doc.terms 252 | 253 | # convert ground truth relations and terms for precision/recall/f1 evaluation 254 | sample_gt_terms = [term.as_tuple() for term in gt_terms] 255 | sample_gt_relations = [rel.as_tuple() for rel in gt_relations] 256 | 257 | if self._no_overlapping: 258 | sample_gt_terms, sample_gt_relations = self._remove_overlapping(sample_gt_terms, 259 | sample_gt_relations) 260 | 261 | self._gt_terms.append(sample_gt_terms) 262 | self._gt_relations.append(sample_gt_relations) 263 | 264 | def _convert_pred_terms(self, pred_types: torch.tensor, pred_spans: torch.tensor, pred_scores: torch.tensor): 265 | converted_preds = [] 266 | 267 | for i in range(pred_types.shape[0]): 268 | label_idx = pred_types[i].item() 269 | term_type = self._input_reader.get_term_type(label_idx) 270 | 271 | start, end = pred_spans[i].tolist() 272 | score = pred_scores[i].item() 273 | 274 | converted_pred = (start, end, term_type, score) 275 | converted_preds.append(converted_pred) 276 | 277 | return converted_preds 278 | 279 | def _convert_pred_relations(self, pred_rel_types: torch.tensor, pred_term_spans: torch.tensor, 280 | pred_term_types: torch.tensor, pred_scores: torch.tensor): 281 | converted_rels = [] 282 | check = set() 283 | 284 | for i in range(pred_rel_types.shape[0]): 285 | label_idx = pred_rel_types[i].item() 286 | pred_rel_type = self._input_reader.get_relation_type(label_idx) 287 | pred_head_type_idx, pred_tail_type_idx = pred_term_types[i][0].item(), pred_term_types[i][1].item() 288 | pred_head_type = self._input_reader.get_term_type(pred_head_type_idx) 289 | pred_tail_type = self._input_reader.get_term_type(pred_tail_type_idx) 290 | score = pred_scores[i].item() 291 | 292 | spans = pred_term_spans[i] 293 | head_start, head_end = spans[0].tolist() 294 | tail_start, tail_end = spans[1].tolist() 295 | 296 | converted_rel = ((head_start, head_end, pred_head_type), 297 | (tail_start, tail_end, pred_tail_type), pred_rel_type) 298 | converted_rel = self._adjust_rel(converted_rel) 299 | 300 | if converted_rel not in check: 301 | check.add(converted_rel) 302 | converted_rels.append(tuple(list(converted_rel) + [score])) 303 | 304 | return converted_rels 305 | 306 | def _remove_overlapping(self, terms, relations): 307 | non_overlapping_terms = [] 308 | non_overlapping_relations = [] 309 | 310 | for term in terms: 311 | if not self._is_overlapping(term, terms): 312 | non_overlapping_terms.append(term) 313 | 314 | for rel in relations: 315 | e1, e2 = rel[0], rel[1] 316 | if not self._check_overlap(e1, e2): 317 | non_overlapping_relations.append(rel) 318 | 319 | return non_overlapping_terms, non_overlapping_relations 320 | 321 | def _is_overlapping(self, e1, terms): 322 | for e2 in terms: 323 | if self._check_overlap(e1, e2): 324 | return True 325 | 326 | return False 327 | 328 | def _check_overlap(self, e1, e2): 329 | if e1 == e2 or e1[1] <= e2[0] or e2[1] <= e1[0]: 330 | return False 331 | else: 332 | return True 333 | 334 | def _adjust_rel(self, rel: Tuple): 335 | adjusted_rel = rel 336 
| if rel[-1].symmetric: 337 | head, tail = rel[:2] 338 | if tail[0] < head[0]: 339 | adjusted_rel = tail, head, rel[-1] 340 | 341 | return adjusted_rel 342 | 343 | def _convert_by_setting(self, gt: List[List[Tuple]], pred: List[List[Tuple]], 344 | include_term_types: bool = True, include_score: bool = False): 345 | assert len(gt) == len(pred) 346 | 347 | # either include or remove term types based on setting 348 | def convert(t): 349 | if not include_term_types: 350 | # remove term type and score for evaluation 351 | if type(t[0]) == int: # term 352 | c = [t[0], t[1], self._pseudo_term_type] 353 | else: # relation 354 | c = [(t[0][0], t[0][1], self._pseudo_term_type), 355 | (t[1][0], t[1][1], self._pseudo_term_type), t[2]] 356 | else: 357 | c = list(t[:3]) 358 | 359 | if include_score and len(t) > 3: 360 | # include prediction scores 361 | c.append(t[3]) 362 | 363 | return tuple(c) 364 | 365 | converted_gt, converted_pred = [], [] 366 | 367 | for sample_gt, sample_pred in zip(gt, pred): 368 | converted_gt.append([convert(t) for t in sample_gt]) 369 | converted_pred.append([convert(t) for t in sample_pred]) 370 | 371 | return converted_gt, converted_pred 372 | 373 | def _score(self, gt: List[List[Tuple]], pred: List[List[Tuple]], print_results: bool = False): 374 | assert len(gt) == len(pred) 375 | 376 | gt_flat = [] 377 | pred_flat = [] 378 | types = set() 379 | 380 | for (sample_gt, sample_pred) in zip(gt, pred): 381 | union = set() 382 | union.update(sample_gt) 383 | union.update(sample_pred) 384 | 385 | for s in union: 386 | if s in sample_gt: 387 | t = s[2] 388 | gt_flat.append(t.index) 389 | types.add(t) 390 | else: 391 | gt_flat.append(0) 392 | 393 | if s in sample_pred: 394 | t = s[2] 395 | pred_flat.append(t.index) 396 | types.add(t) 397 | else: 398 | pred_flat.append(0) 399 | 400 | metrics = self._compute_metrics(gt_flat, pred_flat, types, print_results) 401 | return metrics 402 | 403 | def _compute_metrics(self, gt_all, pred_all, types, print_results: bool = False): 404 | labels = [t.index for t in types] 405 | per_type = prfs(gt_all, pred_all, labels=labels, average=None) 406 | micro = prfs(gt_all, pred_all, labels=labels, average='micro')[:-1] 407 | macro = prfs(gt_all, pred_all, labels=labels, average='macro')[:-1] 408 | total_support = sum(per_type[-1]) 409 | 410 | if print_results: 411 | self._print_results(per_type, list(micro) + [total_support], list(macro) + [total_support], types) 412 | 413 | return [m * 100 for m in micro + macro] 414 | 415 | def _print_results(self, per_type: List, micro: List, macro: List, types: List): 416 | columns = ('type', 'precision', 'recall', 'f1-score', 'support') 417 | 418 | row_fmt = "%20s" + (" %12s" * (len(columns) - 1)) 419 | results = [row_fmt % columns, '\n'] 420 | 421 | metrics_per_type = [] 422 | for i, t in enumerate(types): 423 | metrics = [] 424 | for j in range(len(per_type)): 425 | metrics.append(per_type[j][i]) 426 | metrics_per_type.append(metrics) 427 | 428 | for m, t in zip(metrics_per_type, types): 429 | results.append(row_fmt % self._get_row(m, t.short_name)) 430 | results.append('\n') 431 | 432 | results.append('\n') 433 | 434 | # micro 435 | results.append(row_fmt % self._get_row(micro, 'micro')) 436 | results.append('\n') 437 | 438 | # macro 439 | results.append(row_fmt % self._get_row(macro, 'macro')) 440 | 441 | results_str = ''.join(results) 442 | print(results_str) 443 | 444 | def _get_row(self, data, label): 445 | row = [label] 446 | for i in range(len(data) - 1): 447 | row.append("%.2f" % (data[i] * 100)) 448 | 
row.append(data[3]) 449 | return tuple(row) 450 | 451 | def _convert_example(self, doc: Document, gt: List[Tuple], pred: List[Tuple], 452 | include_term_types: bool, to_html): 453 | encoding = doc.encoding 454 | 455 | gt, pred = self._convert_by_setting([gt], [pred], include_term_types=include_term_types, include_score=True) 456 | gt, pred = gt[0], pred[0] 457 | 458 | # get micro precision/recall/f1 scores 459 | if gt or pred: 460 | pred_s = [p[:3] for p in pred] # remove score 461 | precision, recall, f1 = self._score([gt], [pred_s])[:3] 462 | else: 463 | # corner case: no ground truth and no predictions 464 | precision, recall, f1 = [100] * 3 465 | 466 | scores = [p[-1] for p in pred] 467 | pred = [p[:-1] for p in pred] 468 | union = set(gt + pred) 469 | 470 | # true positives 471 | tp = [] 472 | # false negatives 473 | fn = [] 474 | # false positives 475 | fp = [] 476 | 477 | for s in union: 478 | type_verbose = s[2].verbose_name 479 | 480 | if s in gt: 481 | if s in pred: 482 | score = scores[pred.index(s)] 483 | tp.append((to_html(s, encoding), type_verbose, score)) 484 | else: 485 | fn.append((to_html(s, encoding), type_verbose, -1)) 486 | else: 487 | score = scores[pred.index(s)] 488 | fp.append((to_html(s, encoding), type_verbose, score)) 489 | 490 | tp = sorted(tp, key=lambda p: p[-1], reverse=True) 491 | fp = sorted(fp, key=lambda p: p[-1], reverse=True) 492 | 493 | text = self._prettify(self._text_encoder.decode(encoding)) 494 | return dict(text=text, tp=tp, fn=fn, fp=fp, precision=precision, recall=recall, f1=f1, length=len(doc.tokens)) 495 | 496 | def _term_to_html(self, term: Tuple, encoding: List[int]): 497 | start, end = term[:2] 498 | term_type = term[2].verbose_name 499 | 500 | tag_start = ' ' 501 | tag_start += '%s' % term_type 502 | 503 | ctx_before = self._text_encoder.decode(encoding[:start]) 504 | e1 = self._text_encoder.decode(encoding[start:end]) 505 | ctx_after = self._text_encoder.decode(encoding[end:]) 506 | 507 | html = ctx_before + tag_start + e1 + ' ' + ctx_after 508 | html = self._prettify(html) 509 | 510 | return html 511 | 512 | def _rel_to_html(self, relation: Tuple, encoding: List[int]): 513 | head, tail = relation[:2] 514 | head_tag = ' %s' 515 | tail_tag = ' %s' 516 | 517 | if head[0] < tail[0]: 518 | e1, e2 = head, tail 519 | e1_tag, e2_tag = head_tag % head[2].verbose_name, tail_tag % tail[2].verbose_name 520 | else: 521 | e1, e2 = tail, head 522 | e1_tag, e2_tag = tail_tag % tail[2].verbose_name, head_tag % head[2].verbose_name 523 | 524 | segments = [encoding[:e1[0]], encoding[e1[0]:e1[1]], encoding[e1[1]:e2[0]], 525 | encoding[e2[0]:e2[1]], encoding[e2[1]:]] 526 | 527 | ctx_before = self._text_encoder.decode(segments[0]) 528 | e1 = self._text_encoder.decode(segments[1]) 529 | ctx_between = self._text_encoder.decode(segments[2]) 530 | e2 = self._text_encoder.decode(segments[3]) 531 | ctx_after = self._text_encoder.decode(segments[4]) 532 | 533 | html = (ctx_before + e1_tag + e1 + ' ' 534 | + ctx_between + e2_tag + e2 + ' ' + ctx_after) 535 | html = self._prettify(html) 536 | 537 | return html 538 | 539 | def _prettify(self, text: str): 540 | text = text.replace('_start_', '').replace('_classify_', '').replace('', '').replace('⁇', '') 541 | text = text.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '') 542 | return text 543 | 544 | def _store_examples(self, examples: List[Dict], file_path: str, template: str): 545 | template_path = os.path.join(SCRIPT_PATH, 'templates', template) 546 | 547 | # read template 548 | with 
open(os.path.join(SCRIPT_PATH, template_path)) as f: 549 | template = jinja2.Template(f.read()) 550 | 551 | # write to disc 552 | template.stream(examples=examples).dump(file_path) 553 | -------------------------------------------------------------------------------- /SynFue/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn as nn 4 | from transformers import BertConfig 5 | from transformers import BertModel 6 | from transformers import BertPreTrainedModel 7 | 8 | from SynFue import sampling 9 | from SynFue import util 10 | from SynFue import Encoder 11 | from SynFue import cross_attn 12 | 13 | 14 | def get_token(h: torch.tensor, x: torch.tensor, token: int): 15 | """ Get specific token embedding (e.g. [CLS]) """ 16 | emb_size = h.shape[-1] 17 | 18 | token_h = h.view(-1, emb_size) 19 | flat = x.contiguous().view(-1) 20 | 21 | # get contextualized embedding of given token 22 | token_h = token_h[flat == token, :] 23 | 24 | return token_h 25 | 26 | 27 | def get_head_tail_rep(h, head_tail_index): 28 | """ 29 | 30 | :param h: torch.tensor [batch size, seq_len, feat_dim] 31 | :param head_tail_index: [batch size, term_num, 2] 32 | :return: 33 | """ 34 | res = [] 35 | batch_size = head_tail_index.size(0) 36 | term_num = head_tail_index.size(1) 37 | for b in range(batch_size): 38 | temp = [] 39 | for t in range(term_num): 40 | temp.append(torch.index_select(h[b], 0, head_tail_index[b][t]).view(-1)) 41 | res.append(torch.stack(temp, dim=0)) 42 | res = torch.stack(res) 43 | return res 44 | 45 | 46 | class SynFueBERT(BertPreTrainedModel): 47 | """ Span-based model to jointly extract terms and relations """ 48 | 49 | VERSION = '1.2' 50 | 51 | def __init__(self, config: BertConfig, cls_token: int, relation_types: int, term_types: int, 52 | size_embedding: int, prop_drop: float, freeze_transformer: bool, args, max_pairs: int = 100, 53 | beta: float = 0.3, alpha: float = 1.0, sigma: float = 1.0): 54 | super(SynFueBERT, self).__init__(config) 55 | 56 | # BERT model 57 | self.bert = BertModel(config) 58 | self.bert_dropout = nn.Dropout(args.bert_dropout) 59 | self.SynFue = Encoder.SynFueEncoder(opt=args) 60 | self.cc = cross_attn.CA_module(config.hidden_size, config.hidden_size, 1, dropout=1.0) 61 | 62 | # layers 63 | self.rel_classifier = nn.Linear(config.hidden_size * 4 + size_embedding * 2, relation_types) 64 | self.rel_classifier3 = nn.Linear(config.hidden_size * 3 + size_embedding * 3, relation_types) 65 | self.term_linear = nn.Linear(config.hidden_size * 7 + size_embedding, config.hidden_size) 66 | self.term_classifier = nn.Linear(config.hidden_size, term_types) 67 | self.dep_linear = nn.Linear(config.hidden_size, relation_types) 68 | self.size_embeddings = nn.Embedding(100, size_embedding) 69 | self.dropout = nn.Dropout(prop_drop) 70 | 71 | self._cls_token = cls_token 72 | self._relation_types = relation_types 73 | self._term_types = term_types 74 | self._max_pairs = max_pairs 75 | self._beta = beta 76 | self._alpha = alpha 77 | self._sigma = sigma 78 | 79 | # weight initialization 80 | self.init_weights() 81 | 82 | if freeze_transformer: 83 | print("Freeze transformer weights") 84 | 85 | # freeze all transformer weights 86 | for param in self.bert.parameters(): 87 | param.requires_grad = False 88 | 89 | # def init_weights(self): 90 | # for name, param in self.SynFue.named_parameters(): 91 | # if name.find('weight') != -1: 92 | # torch.nn.init.normal_(param, 0, 1) 93 | # elif name.find('bias') != 
-1:
94 |     #             torch.nn.init.constant_(param, 0)
95 |     #         else:
96 |     #             torch.nn.init.xavier_normal_(param)
97 |     #     for name, param in self.cc.named_parameters():
98 |     #         if name.find('weight') != -1:
99 |     #             torch.nn.init.normal_(param, 0, 1)
100 |     #         elif name.find('bias') != -1:
101 |     #             torch.nn.init.constant_(param, 0)
102 |     #         else:
103 |     #             torch.nn.init.xavier_normal_(param)
104 | 
105 |     def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
106 |                        term_sizes: torch.tensor, term_spans: torch.tensor, term_types: torch.tensor,
107 |                        relations: torch.tensor, rel_masks: torch.tensor,
108 |                        simple_graph: torch.tensor, graph: torch.tensor,
109 |                        relations3: torch.tensor, rel_masks3: torch.tensor, pair_mask: torch.tensor,
110 |                        pos: torch.tensor = None, pieces2word: torch.tensor = None):
111 |         """
112 | 
113 |         :param encodings: [B, L']
114 |         :param context_masks: [B, L']
115 |         :param term_masks: [B, max_term, L]
116 |         :param term_sizes: [B, max_term]
117 |         :param term_spans: [B, max_term, 2]
118 |         :param term_types: [B, max_term]
119 |         :param relations: [B, max_relation, 2]
120 |         :param rel_masks: [B, max_relation, L]
121 |         :param simple_graph: [B, L, L]
122 |         :param graph: [B, L, L]
123 |         :param relations3: [B, max_relation, 3]
124 |         :param rel_masks3: [B, max_relation, max_relation, L]
125 |         :param pair_mask: [B, max_relation]
126 |         :param pos: [B, L]
127 |         :param pieces2word: [B, L, L']
128 |         :return:
129 |         """
130 |         # get contextualized token embeddings from last transformer layer
131 |         word_reps, cls_output = self._bert_encoder(input_ids=encodings, pieces2word=pieces2word)
132 |         h, dep_output = self.SynFue(word_reps=word_reps, simple_graph=simple_graph, graph=graph, pos=pos)
133 | 
134 |         batch_size = encodings.shape[0]
135 | 
136 |         # classify terms
137 |         size_embeddings = self.size_embeddings(term_sizes)  # embed term candidate sizes
138 |         term_clf, term_reps = self._classify_terms(h, term_masks, size_embeddings, cls_output)
139 | 
140 |         # classify relations
141 |         h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)  # B, max_rel, L, H
142 |         rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
143 |             self.rel_classifier.weight.device)  # B, max_rel, relation_types
144 | 
145 |         # get span representation
146 |         # dep_output = [B, L, L, H] -> [batch size, span num, span num, feat_dim]
147 |         span_repr, mapping_list = self.get_span_repr(term_spans, term_types, dep_output)
148 |         # apply cross-attention to compute the syntax-aware inter-span representation
149 |         cross_attn_span = self.cc(span_repr)  # batch size, span_num, span_num, feat_dim
150 | 
151 |         # obtain relation logits
152 |         # chunk processing to reduce memory usage
153 |         for i in range(0, relations.shape[1], self._max_pairs):
154 |             # classify relation candidates
155 |             chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(cross_attn_span,
156 |                                                                                          term_reps,
157 |                                                                                          size_embeddings,
158 |                                                                                          relations, rel_masks,
159 |                                                                                          h_large, i,
160 |                                                                                          relations3, rel_masks3,
161 |                                                                                          pair_mask, mapping_list)
162 |             # apply sigmoid
163 |             chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
164 |             chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
165 |             rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
166 | 
167 |         max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
168 |         min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
169 |         eps = torch.full_like(rel_clf, 1e-18)  # small epsilon guarding against a zero denominator
170 |         rel_clf = torch.div(rel_clf - min_clf + eps, max_clf - min_clf + eps)
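        # NOTE (illustrative sketch): the fused score alpha * pair + beta * triplet
        # + sigma * dep can leave [0, 1], so it is min-max rescaled over the batch;
        # e.g. raw scores [0.3, 0.7, 0.5] map to [0.0, 1.0, 0.5], and eps only
        # matters when all scores are equal.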
171 | 
172 |         return term_clf, rel_clf
173 | 
174 |     def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
175 |                       term_sizes: torch.tensor, term_spans: torch.tensor, term_sample_masks: torch.tensor,
176 |                       simple_graph: torch.tensor, graph: torch.tensor, pos: torch.tensor = None,
177 |                       pieces2word: torch.tensor = None):
178 |         # get contextualized token embeddings from last transformer layer
179 |         # context_masks = context_masks.float()
180 |         word_reps, cls_output = self._bert_encoder(input_ids=encodings, pieces2word=pieces2word)
181 |         h, dep_output = self.SynFue(word_reps=word_reps, simple_graph=simple_graph, graph=graph, pos=pos)
182 | 
183 |         batch_size = encodings.shape[0]
184 |         ctx_size = term_masks.shape[-1]
185 | 
186 |         # classify terms
187 |         size_embeddings = self.size_embeddings(term_sizes)  # embed term candidate sizes
188 |         term_clf, term_reps = self._classify_terms(h, term_masks, size_embeddings, cls_output)
189 | 
190 |         # ignore term candidates that do not constitute an actual term for relations (based on classifier)
191 |         relations, rel_masks, rel_sample_masks, relations3, rel_masks3, \
192 |             rel_sample_masks3, pair_mask, span_repr, mapping_list = self._filter_spans(term_clf, term_spans,
193 |                                                                                        term_sample_masks,
194 |                                                                                        ctx_size, dep_output)
195 | 
196 |         rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
197 |         # h = self.rel_bert(input_ids=encodings, attention_mask=context_masks)[0]
198 |         h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
199 |         rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
200 |             self.rel_classifier.weight.device)
201 | 
202 |         # get span representation
203 |         cross_attn_span = self.cc(span_repr)  # batch size, span_num, span_num, feat_dim
204 | 
205 |         # obtain relation logits
206 |         # chunk processing to reduce memory usage
207 |         for i in range(0, relations.shape[1], self._max_pairs):
208 |             # classify relation candidates
209 |             chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(cross_attn_span,
210 |                                                                                          term_reps,
211 |                                                                                          size_embeddings,
212 |                                                                                          relations, rel_masks,
213 |                                                                                          h_large, i,
214 |                                                                                          relations3, rel_masks3,
215 |                                                                                          pair_mask, mapping_list)
216 |             # apply sigmoid
217 |             chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
218 |             chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
219 |             rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
220 | 
221 |         max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
222 |         min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
223 |         eps = torch.full_like(rel_clf, 1e-18)  # small epsilon guarding against a zero denominator
224 |         rel_clf = torch.div(rel_clf - min_clf + eps, max_clf - min_clf + eps)
225 | 
226 |         rel_clf = rel_clf * rel_sample_masks  # mask
227 | 
228 |         # apply softmax
229 |         term_clf = torch.softmax(term_clf, dim=2)
230 | 
231 |         return term_clf, rel_clf, relations
232 | 
233 |     def _classify_terms(self, h, term_masks, size_embeddings, cls_output):
234 |         """
235 | 
236 |         :param h: [b, L, H*2]
237 |         :param term_masks: [B, max_term, L]
238 |         :param size_embeddings: [b, max_term, size_embedding]
239 |         :param cls_output: [B, H]
240 |         :return:
241 |             term_clf: [B, max_term, term_types]
242 |             term_spans_pool: [B, max_term, H]
243 |         """
244 |         # max pool term candidate spans
245 |         # m = (term_masks.unsqueeze(-1) == 0).float() * (-1e30)
246 |         # term_spans_pool = m + h.unsqueeze(1).repeat(1, term_masks.shape[1], 1, 1)
247 |         # term_spans_pool = term_spans_pool.max(dim=2)[0]
248 | 
249 |         min_value = torch.min(h).item()
250 |         _h = h.unsqueeze(1).expand(-1, term_masks.size(1), -1, -1)
251 |         _h = torch.masked_fill(_h, term_masks.eq(0).unsqueeze(-1), min_value)
252 |         term_spans_pool, _ = torch.max(_h, dim=2)
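        # NOTE (illustrative sketch): masked max-pooling over each candidate span.
        # Tokens outside a span are filled with the global minimum of h so they can
        # never win the max; e.g. for a span mask [0, 1, 1, 0], only positions 1
        # and 2 contribute to that candidate's pooled vector.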
253 | 
254 |         # get cls token as candidate context representation
255 |         # term_ctx = get_token(h, encodings, self._cls_token)
256 | 
257 |         # get head and tail token representation
258 |         m = term_masks.to(dtype=torch.long)
259 |         k = torch.tensor(np.arange(0, term_masks.size(-1)), dtype=torch.long)
260 |         k = k.unsqueeze(0).unsqueeze(0).repeat(term_masks.size(0), term_masks.size(1), 1).to(m.device)
261 |         mk = torch.mul(m, k)  # element-wise multiply
262 |         mk_max = torch.argmax(mk, dim=-1, keepdim=True)
263 |         mk_min = torch.argmin(mk, dim=-1, keepdim=True)
264 |         mk = torch.cat([mk_min, mk_max], dim=-1)
265 |         head_tail_rep = get_head_tail_rep(h, mk)  # [batch size, term_num, bert_dim*2]
266 | 
267 |         # create candidate representations including context, max pooled span and size embedding, head and tail
268 |         term_repr = torch.cat([cls_output.unsqueeze(1).repeat(1, term_spans_pool.shape[1], 1),
269 |                                term_spans_pool, size_embeddings, head_tail_rep], dim=2)
270 |         term_repr = self.dropout(term_repr)
271 |         term_repr = self.term_linear(term_repr)
272 | 
273 |         # classify term candidates
274 |         term_clf = self.term_classifier(term_repr)
275 | 
276 |         return term_clf, term_repr
277 | 
278 |     def _classify_relations(self, spans_matrix, term_spans_repr, size_embeddings, relations, rel_masks,
279 |                             h, chunk_start, relations3, rel_masks3, pair_mask, rel_to_span):
280 |         """
281 | 
282 |         :param spans_matrix: [B, max_term, max_term, H]
283 |         :param term_spans_repr: [B, max_term, H]
284 |         :param size_embeddings: [B, max_term, term_size]
285 |         :param relations: [B, max_rel, 2]
286 |         :param rel_masks: [B, max_rel, L]
287 |         :param h: [B, max_rel, L, H*2]
288 |         :param chunk_start:
289 |         :param relations3: [B, max_rel, 3]
290 |         :param rel_masks3: [B, max_rel, max_rel, L]
291 |         :param pair_mask:
292 |         :param rel_to_span:
293 |         :return:
294 |         """
295 |         batch_size = relations.shape[0]
296 |         feat_dim = spans_matrix.size(-1)
297 | 
298 |         # create chunks if necessary
299 |         if relations.shape[1] > self._max_pairs:
300 |             relations = relations[:, chunk_start:chunk_start + self._max_pairs]
301 |             rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
302 |             h = h[:, :relations.shape[1], :]
303 | 
304 |         def get_span_idx(mapping_list, idx1, idx2):
305 |             for x in mapping_list:
306 |                 if idx1 == x[0][0] and idx2 == x[0][1]:
307 |                     return x[1][0], x[1][1]
308 | 
309 |         batch_dep_score = []
310 |         for i in range(batch_size):
311 |             _rel = relations[i]
312 |             dep_score_list = []
313 |             r_2_s = rel_to_span[i]
314 |             for r in _rel:
315 |                 i1, i2 = r[0].item(), r[1].item()
316 |                 idx1, idx2 = get_span_idx(r_2_s, i1, i2)
317 |                 try:
318 |                     feat = spans_matrix[i][idx1][idx2]
319 |                 except Exception:
320 |                     print('Out of boundary', spans_matrix.size(), i, i1, i2)
321 |                     feat = torch.zeros(feat_dim, device=spans_matrix.device)  # keep the fallback on the same device
322 |                 dep_score = self.dep_linear(feat).item()
323 |                 dep_score_list.append([dep_score])
324 |             batch_dep_score.append(dep_score_list)
325 | 
326 |         batch_dep_score = torch.sigmoid(
327 |             torch.tensor(batch_dep_score).to(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
328 | 
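        # NOTE (illustrative sketch): the loop above scores each candidate pair
        # (i1, i2) by looking it up in rel_to_span, whose entries [[i1, i2], [x1, x2]]
        # map span-list indices to the (x1, x2) cell of the syntax-aware span matrix,
        # and projecting that cell to a scalar dependency score via dep_linear.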
329 |         # get pairs of term candidate representations
330 |         term_pairs = util.batch_index(term_spans_repr, relations)
331 |         term_pairs = term_pairs.view(batch_size, term_pairs.shape[1], -1)
332 | 
333 |         # get corresponding size embeddings
334 |         size_pair_embeddings = util.batch_index(size_embeddings, relations)
335 |         size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)
336 | 
337 |         # relation context (context between term candidate pair)
338 |         # mask non term candidate tokens
339 |         m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
340 |         rel_ctx = m + h
341 |         # max pooling
342 |         rel_ctx = rel_ctx.max(dim=2)[0]
343 |         # set the context vector of neighboring or adjacent term candidates to zero
344 |         rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0
345 | 
346 |         # create relation candidate representations including context, max pooled term candidate pairs
347 |         # and corresponding size embeddings
348 |         rel_repr = torch.cat([rel_ctx, term_pairs, size_pair_embeddings], dim=2)
349 |         rel_repr = self.dropout(rel_repr)
350 |         # classify relation candidates
351 |         chunk_rel_logits = self.rel_classifier(rel_repr)
352 | 
353 |         if relations3.shape[1] > self._max_pairs:
354 |             relations3 = relations3[:, chunk_start:chunk_start + self._max_pairs]
355 |             # rel_masks3 = rel_masks3[:, chunk_start:chunk_start + self._max_pairs]
356 | 
357 |         p_num = relations3.size(1)
358 |         p_tris = relations3.size(2)
359 | 
360 |         relations3 = relations3.view(batch_size, -1, 3)
361 | 
362 |         # get triplet candidate representations
363 |         term_pairs3 = util.batch_index(term_spans_repr, relations3)
364 |         term_pairs3 = term_pairs3.view(batch_size, term_pairs3.shape[1], -1)
365 | 
366 |         size_pair_embeddings3 = util.batch_index(size_embeddings, relations3)
367 |         size_pair_embeddings3 = size_pair_embeddings3.view(batch_size, size_pair_embeddings3.shape[1], -1)
368 | 
369 |         rel_repr = torch.cat([term_pairs3, size_pair_embeddings3], dim=2)
370 |         rel_repr = self.dropout(rel_repr)
371 |         # classify relation candidates
372 |         chunk_rel_logits3 = self.rel_classifier3(rel_repr)
373 | 
374 |         chunk_rel_clf3 = chunk_rel_logits3.view(batch_size, p_num, p_tris, -1)
375 |         chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)
376 | 
377 |         chunk_rel_clf3 = torch.sum(chunk_rel_clf3, dim=2)
378 |         chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)
379 | 
380 |         return chunk_rel_logits, chunk_rel_clf3, batch_dep_score
381 | 
382 |     def _filter_spans(self, term_clf, term_spans, term_sample_masks, ctx_size, token_repr):
383 |         """
384 |         According to the results of term type detection, we first filter out invalid terms (i.e., keep only
385 |         aspect terms and opinion terms), and then construct term pairs and term triplets for the subsequent relation detection.
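        For example, if two spans i and j survive the filter, the pair candidates are
        (i, j) and (j, i); each pair is additionally expanded with every remaining span k
        into triplets (i, j, k), and a pair with no valid third span receives a
        placeholder triplet (i, j, 0) that is masked out via pair_mask.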
386 | :param term_clf: [B, max_term, term_types] 387 | :param term_spans: [B, max_term, 2] 388 | :param term_sample_masks: 389 | :param ctx_size: L 390 | :param token_repr: [B, L, H] 391 | :return: 392 | """ 393 | batch_size = term_clf.shape[0] 394 | feat_dim = token_repr.size(-1) 395 | term_logits_max = term_clf.argmax(dim=-1) * term_sample_masks.long() # get term type (including none) 396 | batch_relations = [] 397 | batch_rel_masks = [] 398 | batch_rel_sample_masks = [] 399 | 400 | batch_relations3 = [] 401 | batch_rel_masks3 = [] 402 | batch_rel_sample_masks3 = [] 403 | batch_pair_mask = [] 404 | 405 | batch_span_repr = [] 406 | batch_rel_to_span = [] 407 | 408 | for i in range(batch_size): 409 | rels = [] 410 | rel_masks = [] 411 | sample_masks = [] 412 | rels3 = [] 413 | rel_masks3 = [] 414 | sample_masks3 = [] 415 | 416 | span_repr = [] 417 | rel_to_span = [] 418 | 419 | # get spans classified as terms 420 | non_zero_indices = (term_logits_max[i] != 0).nonzero().view(-1) 421 | non_zero_spans = term_spans[i][non_zero_indices].tolist() 422 | non_zero_indices = non_zero_indices.tolist() 423 | 424 | # create relations and masks 425 | pair_mask = [] 426 | for idx1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)): 427 | temp = [] 428 | for idx2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)): 429 | if i1 != i2: 430 | rels.append((i1, i2)) 431 | rel_masks.append(sampling.create_rel_mask(s1, s2, ctx_size)) 432 | sample_masks.append(1) 433 | p_rels3 = [] 434 | p_masks3 = [] 435 | for i3, s3 in zip(non_zero_indices, non_zero_spans): 436 | if i1 != i2 and i1 != i3 and i2 != i3: 437 | p_rels3.append((i1, i2, i3)) 438 | p_masks3.append(sampling.create_rel_mask3(s1, s2, s3, ctx_size)) 439 | sample_masks3.append(1) 440 | if len(p_rels3) > 0: 441 | rels3.append(p_rels3) 442 | rel_masks3.append(p_masks3) 443 | pair_mask.append(1) 444 | else: 445 | rels3.append([(i1, i2, 0)]) 446 | rel_masks3.append([sampling.create_rel_mask3(s1, s2, (0, 0), ctx_size)]) 447 | pair_mask.append(0) 448 | rel_to_span.append([[i1, i2], [idx1, idx2]]) 449 | feat = \ 450 | torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view(-1, feat_dim), 451 | dim=0)[0] 452 | temp.append(feat) 453 | span_repr.append(temp) 454 | 455 | if not rels: 456 | # case: no more than two spans classified as terms 457 | batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long)) 458 | batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool)) 459 | batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool)) 460 | batch_span_repr.append(torch.tensor([[[0] * feat_dim]], dtype=torch.float)) 461 | batch_rel_to_span.append([[[0, 0], [0, 0]]]) 462 | else: 463 | # case: more than two spans classified as terms 464 | batch_relations.append(torch.tensor(rels, dtype=torch.long)) 465 | batch_rel_masks.append(torch.stack(rel_masks)) 466 | batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool)) 467 | batch_span_repr.append(torch.stack([torch.stack(x) for x in span_repr])) 468 | batch_rel_to_span.append(rel_to_span) 469 | 470 | if not rels3: 471 | batch_relations3.append(torch.tensor([[[0, 0, 0]]], dtype=torch.long)) 472 | batch_rel_masks3.append(torch.tensor([[0] * ctx_size], dtype=torch.bool)) 473 | batch_rel_sample_masks3.append(torch.tensor([0], dtype=torch.bool)) 474 | batch_pair_mask.append(torch.tensor([0], dtype=torch.bool)) 475 | 476 | else: 477 | max_tri = max([len(x) for x in rels3]) 478 | # print(max_tri) 479 | for idx, r in enumerate(rels3): 480 | 
471 |             if not rels:
472 |                 # case: fewer than two spans classified as terms
473 |                 batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long))
474 |                 batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
475 |                 batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool))
476 |                 batch_span_repr.append(torch.tensor([[[0] * feat_dim]], dtype=torch.float))
477 |                 batch_rel_to_span.append([[[0, 0], [0, 0]]])
478 |             else:
479 |                 # case: at least two spans classified as terms
480 |                 batch_relations.append(torch.tensor(rels, dtype=torch.long))
481 |                 batch_rel_masks.append(torch.stack(rel_masks))
482 |                 batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool))
483 |                 batch_span_repr.append(torch.stack([torch.stack(x) for x in span_repr]))
484 |                 batch_rel_to_span.append(rel_to_span)
485 | 
486 |             if not rels3:
487 |                 batch_relations3.append(torch.tensor([[[0, 0, 0]]], dtype=torch.long))
488 |                 batch_rel_masks3.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
489 |                 batch_rel_sample_masks3.append(torch.tensor([0], dtype=torch.bool))
490 |                 batch_pair_mask.append(torch.tensor([0], dtype=torch.bool))
491 |             else:
492 |                 max_tri = max(len(x) for x in rels3)
493 |                 # pad each pair's triplet list to the maximum triplet count
494 |                 for idx, r in enumerate(rels3):
495 |                     r_len = len(r)
496 |                     if r_len < max_tri:
497 |                         rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len))
498 |                         rel_masks3[idx].extend(
499 |                             [rel_masks3[idx][0]] * (max_tri - r_len))
500 |                 batch_relations3.append(torch.tensor(rels3, dtype=torch.long))
501 |                 batch_rel_masks3.append(torch.stack([torch.stack(x) for x in rel_masks3]))
502 |                 batch_rel_sample_masks3.append(torch.tensor(sample_masks3, dtype=torch.bool))
503 |                 batch_pair_mask.append(torch.tensor(pair_mask, dtype=torch.bool))
504 | 
505 |         # stack
506 |         device = self.rel_classifier.weight.device
507 |         batch_relations = util.padded_stack(batch_relations).to(device)
508 |         batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
509 |         batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(device)
510 |         batch_span_repr = util.padded_stack(batch_span_repr).to(device)
511 | 
512 |         batch_relations3 = util.padded_stack(batch_relations3).to(device)
513 |         batch_rel_masks3 = util.padded_stack(batch_rel_masks3).to(device)
514 |         batch_rel_sample_masks3 = util.padded_stack(batch_rel_sample_masks3).to(device)
515 |         batch_pair_mask = util.padded_stack(batch_pair_mask).to(device)
516 | 
517 |         return batch_relations, batch_rel_masks, batch_rel_sample_masks, \
518 |             batch_relations3, batch_rel_masks3, batch_rel_sample_masks3, \
519 |             batch_pair_mask, batch_span_repr, batch_rel_to_span
520 | 
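521 |     # A small sketch (illustration only, with hypothetical values) of what get_span_repr
522 |     # below returns: for a sentence with positive term spans at indices i1=3 and i2=7,
523 |     # batch_span_repr[b] is a 2 x 2 grid of max-pooled features, and the mapping list is
524 |     #
525 |     #     [[[3, 3], [0, 0]], [[3, 7], [0, 1]],
526 |     #      [[7, 3], [1, 0]], [[7, 7], [1, 1]]]
527 |     #
528 |     # i.e., each entry pairs the span-list indices (i1, i2) with grid coordinates (x1, x2).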
529 |     def get_span_repr(self, term_spans, term_types, token_repr):
530 |         """
531 |         Build pairwise representations for all positive term spans via max pooling over token_repr.
532 |         :param term_spans: [B, span_num, 2]
533 |         :param term_types: [B, span_num]
534 |         :param token_repr: [B, L, L, feat_dim]
535 |         :return:
536 |             batch_span_repr: [B, span_num, span_num, feat_dim]
537 |             batch_mapping_list: length B; stores the correspondence between an index pair in the span
538 |                 matrix (x1, x2) and the corresponding indices in the positive term span list (i1, i2).
539 |         """
540 |         batch_size = term_spans.size(0)
541 |         feat_dim = token_repr.size(-1)
542 |         batch_span_repr = []
543 |         batch_mapping_list = []
544 |         for i in range(batch_size):
545 |             span_repr = []
546 |             mapping_list = []
547 |             # get target spans classified as aspect terms or opinion terms
548 |             non_zero_indices = (term_types[i] != 0).nonzero().view(-1)
549 |             non_zero_spans = term_spans[i][non_zero_indices].tolist()
550 |             non_zero_indices = non_zero_indices.tolist()
551 |             for x1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
552 |                 temp = []
553 |                 for x2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)):
554 |                     feat = \
555 |                         torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]: s2[-1] + 1, :].contiguous().view(-1, feat_dim),
556 |                                   dim=0)[0]
557 |                     temp.append(feat)
558 |                     mapping_list.append([[i1, i2], [x1, x2]])
559 | 
560 |                 span_repr.append(torch.stack(temp))
561 |             batch_span_repr.append(torch.stack(span_repr))
562 |             batch_mapping_list.append(mapping_list)
563 | 
564 |         device = self.rel_classifier.weight.device
565 |         batch_span_repr = util.padded_stack(batch_span_repr).to(device)
566 | 
567 |         return batch_span_repr, batch_mapping_list
568 | 
569 |     def _bert_encoder(self, input_ids, pieces2word):
570 |         """
571 |         Encode word pieces with BERT, then max-pool piece representations into word representations.
572 |         :param input_ids: [B, L'], where the piece-level length L' is generally not equal to the word-level length L
573 |         :param pieces2word: [B, L, L']
574 |         :return:
575 |             word_reps: [B, L, H]
576 |             pooler_output: [B, H]
577 |         """
578 |         bert_output = self.bert(input_ids=input_ids, attention_mask=input_ids.ne(0).float())
579 |         sequence_output, pooler_output = bert_output[0], bert_output[1]
580 |         bert_embs = self.bert_dropout(sequence_output)
581 | 
582 |         length = pieces2word.size(1)
583 |         min_value = torch.min(bert_embs).item()
584 | 
585 |         # max-pool word representations from their pieces: positions that do not belong
586 |         # to a word are filled with the global minimum so they never win the max
587 |         _bert_embs = bert_embs.unsqueeze(1).expand(-1, length, -1, -1)
588 |         _bert_embs = torch.masked_fill(_bert_embs, pieces2word.eq(0).unsqueeze(-1), min_value)
589 |         word_reps, _ = torch.max(_bert_embs, dim=2)
590 |         return word_reps, pooler_output
591 | 
592 |     def forward(self, *args, evaluate=False, **kwargs):
593 |         if not evaluate:
594 |             return self._forward_train(*args, **kwargs)
595 |         else:
596 |             return self._forward_eval(*args, **kwargs)
597 | 
598 | 
599 | # Model access
600 | 
601 | _MODELS = {
602 |     'synfue': SynFueBERT,
603 | }
604 | 
605 | 
606 | def get_model(name):
607 |     return _MODELS[name]
608 | 
--------------------------------------------------------------------------------