├── SynFue
│   ├── __init__.py
│   ├── opt.py
│   ├── loss.py
│   ├── cross_attn.py
│   ├── Encoder.py
│   ├── base_trainer.py
│   ├── util.py
│   ├── templates
│   │   ├── term_examples.html
│   │   └── relation_examples.html
│   ├── input_reader.py
│   ├── terms.py
│   ├── sampling.py
│   ├── trainer.py
│   ├── evaluator.py
│   └── models.py
├── data
│   ├── datasets
│   │   ├── towe
│   │   │   └── 16res
│   │   │       ├── types.json
│   │   │       ├── pos_vocab.txt
│   │   │       ├── dep_type_vocab.txt
│   │   │       └── char_vocab.txt
│   │   ├── data_config.yaml
│   │   ├── str_utils.py
│   │   ├── io_utils.py
│   │   ├── vocab.py
│   │   └── preprocess.py
│   └── log
│       └── towe_16res
│           └── eval_valid.csv
├── configs
│   ├── 16res_eval.conf
│   └── 16res_train.conf
├── synfue.py
├── README.md
├── config_reader.py
└── args.py
/SynFue/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/datasets/towe/16res/types.json:
--------------------------------------------------------------------------------
1 | {"terms": {"Asp": {"short": "Asp", "verbose": "Aspect"}, "Opi": {"short": "Opi", "verbose": "Opinion"}}, "relations": {"Pair": {"short": "Pair", "verbose": "Pair", "symmetric": true}}}
--------------------------------------------------------------------------------
/SynFue/opt.py:
--------------------------------------------------------------------------------
1 | # optional packages
2 |
3 | try:
4 | import tensorboardx
5 | except ImportError:
6 | tensorboardx = None
7 |
8 |
9 | try:
10 | import jinja2
11 | except ImportError:
12 | jinja2 = None
13 |
--------------------------------------------------------------------------------
/data/datasets/data_config.yaml:
--------------------------------------------------------------------------------
1 | ori_data_dir: 'original/'
2 | new_data_dir: 'towe/'
3 |
4 | vocab_dir: 'data/vocab'
5 | token_vocab_file: 'token_vocab.txt'
6 | char_vocab_file: 'char_vocab.txt'
7 | pos_vocab_file: 'pos_vocab.txt'
8 | dep_type_vocab_file: 'dep_type_vocab.txt'
9 | prd_type_vocab_file: 'prd_type_vocab.txt'
10 | role_type_vocab_file: 'role_type_vocab.txt'
11 | action_vocab_file: 'action_vocab.txt'
12 |
13 | normalize_digits: false
14 | lower_case: true
15 | random_seed: 2942435
--------------------------------------------------------------------------------
/configs/16res_eval.conf:
--------------------------------------------------------------------------------
1 | [1]
2 | label = towe_16res
3 | model_type = synfue
4 | model_path = data/models/towe_16res
5 | tokenizer_path = data/models/towe_16res
6 | dataset_path = data/datasets/towe/16res/test.json
7 | types_path = data/datasets/towe/16res/types.json
8 | eval_batch_size = 1
9 | rel_filter_threshold = 0.4
10 | size_embedding = 25
11 | prop_drop = 0.1
12 | max_span_size = 10
13 | store_predictions = true
14 | store_examples = true
15 | sampling_processes = 4
16 | sampling_limit = 100
17 | max_pairs = 1000
18 | log_path = data/log/
19 | alpha = 1.0
20 | beta = 0.4
21 | sigma = 1.0
22 |
--------------------------------------------------------------------------------
/data/datasets/towe/16res/pos_vocab.txt:
--------------------------------------------------------------------------------
1 | special_token_size 1
2 | token_size 44
3 | singleton_size 0
4 | singleton_max_count 1
5 | 0 *PAD* 1
6 | 1 NN 3380
7 | 2 JJ 2407
8 | 3 DT 2164
9 | 4 IN 1507
10 | 5 . 1390
11 | 6 RB 1390
12 | 7 PRP 1005
13 | 8 CC 947
14 | 9 VBD 924
15 | 10 , 872
16 | 11 VBZ 675
17 | 12 NNS 674
18 | 13 VB 541
19 | 14 NNP 532
20 | 15 VBP 491
21 | 16 VBN 303
22 | 17 PRP$ 276
23 | 18 TO 192
24 | 19 MD 177
25 | 20 VBG 162
26 | 21 CD 145
27 | 22 HYPH 140
28 | 23 JJS 125
29 | 24 WDT 113
30 | 25 : 102
31 | 26 -RRB- 77
32 | 27 -LRB- 71
33 | 28 RP 61
34 | 29 JJR 49
35 | 30 WP 47
36 | 31 WRB 45
37 | 32 POS 43
38 | 33 SYM 41
39 | 34 UH 33
40 | 35 '' 26
41 | 36 EX 26
42 | 37 $ 25
43 | 38 PDT 25
44 | 39 RBS 19
45 | 40 RBR 16
46 | 41 FW 15
47 | 42 NFP 10
48 | 43 NNPS 7
49 | 44 `` 5
50 |
--------------------------------------------------------------------------------
/data/datasets/towe/16res/dep_type_vocab.txt:
--------------------------------------------------------------------------------
1 | special_token_size 2
2 | token_size 39
3 | singleton_size 0
4 | singleton_max_count 1
5 | 0 *PAD* 1
6 | 1 SELF_LOOP 1
7 | 2 punct 2685
8 | 3 nsubj 2107
9 | 4 det 2045
10 | 5 advmod 1452
11 | 6 ROOT 1408
12 | 7 amod 1393
13 | 8 case 1245
14 | 9 cop 1152
15 | 10 conj 1020
16 | 11 cc 933
17 | 12 compound 758
18 | 13 obj 734
19 | 14 obl 592
20 | 15 dep 547
21 | 16 nmod 485
22 | 17 mark 429
23 | 18 aux 392
24 | 19 nmod:poss 297
25 | 20 parataxis 207
26 | 21 advcl 207
27 | 22 ccomp 172
28 | 23 xcomp 168
29 | 24 nummod 104
30 | 25 aux:pass 103
31 | 26 acl:relcl 103
32 | 27 appos 91
33 | 28 nsubj:pass 84
34 | 29 acl 76
35 | 30 compound:prt 65
36 | 31 obl:npmod 35
37 | 32 obl:tmod 34
38 | 33 det:predet 27
39 | 34 fixed 27
40 | 35 discourse 25
41 | 36 csubj 23
42 | 37 expl 23
43 | 38 iobj 19
44 | 39 cc:preconj 6
45 | 40 csubj:pass 2
46 |
--------------------------------------------------------------------------------
/data/datasets/towe/16res/char_vocab.txt:
--------------------------------------------------------------------------------
1 | special_token_size 1
2 | token_size 59
3 | singleton_size 3
4 | singleton_max_count 1
5 | 0 *UNK* 1
6 | 1 e 10416
7 | 2 t 7362
8 | 3 a 7017
9 | 4 i 5843
10 | 5 o 5708
11 | 6 s 5465
12 | 7 r 4772
13 | 8 n 4667
14 | 9 h 3974
15 | 10 d 3312
16 | 11 l 3235
17 | 12 c 2385
18 | 13 u 2153
19 | 14 f 1874
20 | 15 w 1744
21 | 16 y 1744
22 | 17 m 1632
23 | 18 g 1631
24 | 19 p 1624
25 | 20 . 1400
26 | 21 b 1265
27 | 22 v 1193
28 | 23 , 834
29 | 24 k 598
30 | 25 z 284
31 | 26 ' 249
32 | 27 - 218
33 | 28 x 189
34 | 29 ! 184
35 | 30 j 126
36 | 31 q 86
37 | 32 ) 76
38 | 33 ( 71
39 | 34 – 40
40 | 35 0 36
41 | 36 : 25
42 | 37 $ 25
43 | 38 1 23
44 | 39 2 20
45 | 40 5 20
46 | 41 / 16
47 | 42 ’ 15
48 | 43 ; 15
49 | 44 3 12
50 | 45 8 10
51 | 46 4 9
52 | 47 & 8
53 | 48 ? 7
54 | 49 6 7
55 | 50 ` 6
56 | 51 7 5
57 | 52 * 5
58 | 53 9 4
59 | 54 + 3
60 | 55 % 2
61 | 56 é 2
62 | 57 @ 1
63 | 58 # 1
64 | 59 = 1
65 |
--------------------------------------------------------------------------------
/configs/16res_train.conf:
--------------------------------------------------------------------------------
1 | [1]
2 | label = towe_16res
3 | model_type = synfue
4 | model_path = bert-base-cased
5 | tokenizer_path = bert-base-cased
6 | train_path = data/datasets/towe/16res/train.json
7 | valid_path = data/datasets/towe/16res/dev.json
8 | types_path = data/datasets/towe/16res/types.json
9 | train_batch_size = 16
10 | eval_batch_size = 1
11 | neg_term_count = 100
12 | neg_relation_count = 100
13 | epochs = 20
14 | lr = 4e-5
15 | lr_warmup = 0.1
16 | weight_decay = 0.01
17 | max_grad_norm = 1.0
18 | rel_filter_threshold = 0.4
19 | size_embedding = 25
20 | prop_drop = 0.4
21 | max_span_size = 8
22 | store_predictions = true
23 | store_examples = true
24 | sampling_processes = 4
25 | sampling_limit = 100
26 | max_pairs = 800
27 | final_eval = false
28 | log_path = data/log/
29 | save_path = data/save/
30 | bert_dim = 768
31 | dep_dim = 100
32 | dep_num = 42
33 | pos_num = 45
34 | pos_dim = 100
35 | w_size = 5
36 | bert_dropout = 0.1
37 | output_dropout = 0.1
38 | num_layer = 2
39 | alpha = 1.0
40 | beta = 0.4
41 | sigma = 1.0
42 |
--------------------------------------------------------------------------------
/data/datasets/str_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import re
4 |
5 |
6 | def normalize_sent(sent):
7 | sent = str(sent).replace('``', '"')
8 | sent = str(sent).replace("''", '"')
9 | sent = str(sent).replace('-LRB-', '(')
10 | sent = str(sent).replace('-RRB-', ')')
11 | sent = str(sent).replace('-LSB-', '(')
12 | sent = str(sent).replace('-RSB-', ')')
13 | return sent
14 |
15 |
16 | def collapse_role_type(role_type):
17 | '''
18 | collapse role types from 36 to 28 following Bishan Yang 2016
19 | we also have to handle types like 'Beneficiary#Recipient'
20 | :param role_type:
21 | :return:
22 | '''
23 | if role_type.startswith('Time-'):
24 | return 'Time'
25 | idx = role_type.find('#')
26 | if idx != -1:
27 | role_type = role_type[:idx]
28 |
29 | return role_type
30 |
31 |
32 | def normalize_tok(tok, lower=False, normalize_digits=False):
33 |
34 | if lower:
35 | tok = tok.lower()
36 | if normalize_digits:
37 | tok = re.sub(r"\d", "0", tok)
38 | tok = re.sub(r"^(\d+[,])*(\d+)$", "0", tok)
39 | return tok
40 |
41 |
42 | def capitalize_first_char(sent):
43 | sent = str(sent[0]).upper() + sent[1:]
44 | return sent
45 |
46 |
--------------------------------------------------------------------------------
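A quick, illustrative check of the helpers above (a sketch, assuming it is run from the repository root so that `data/datasets` resolves as a namespace package):

```python
from data.datasets.str_utils import normalize_sent, normalize_tok

# PTB-style brackets and quotes are mapped back to plain characters.
print(normalize_sent("-LRB- great food -RRB- , `` cheap ''"))
# ( great food ) , " cheap "

# Lower-casing and digit normalization, as controlled by data_config.yaml.
print(normalize_tok("Ate2Eat", lower=True, normalize_digits=True))   # ate0eat
print(normalize_tok("1,200", normalize_digits=True))                 # 0
```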
/synfue.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from args import train_argparser, eval_argparser
4 | from config_reader import process_configs
5 | from SynFue import input_reader
6 | from SynFue.trainer import SynFueTrainer
7 |
8 |
9 | def __train(run_args):
10 | trainer = SynFueTrainer(run_args)
11 | trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
12 | types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)
13 |
14 |
15 | def _train():
16 | arg_parser = train_argparser()
17 | process_configs(target=__train, arg_parser=arg_parser)
18 |
19 |
20 | def __eval(run_args):
21 | trainer = SynFueTrainer(run_args)
22 | trainer.eval(dataset_path=run_args.dataset_path, types_path=run_args.types_path,
23 | input_reader_cls=input_reader.JsonInputReader)
24 |
25 |
26 | def _eval():
27 | arg_parser = eval_argparser()
28 | process_configs(target=__eval, arg_parser=arg_parser)
29 |
30 |
31 | if __name__ == '__main__':
32 | arg_parser = argparse.ArgumentParser(add_help=False)
33 | arg_parser.add_argument('train_mode', type=str, default='train', help="Mode: 'train' or 'eval'")
34 | args, _ = arg_parser.parse_known_args()
35 |
36 | if args.train_mode == 'train':
37 | _train()
38 | elif args.train_mode == 'eval':
39 | _eval()
40 | else:
41 |         raise Exception("Mode not in ['train', 'eval'], e.g. 'python synfue.py train ...'")
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Syntax Fusion Encoder for the PAOTE task (SynFue)
2 | This repository implements the method described in the paper [Learn from Syntax: Improving Pair-wise Aspect and Opinion Terms Extraction with Rich Syntactic Knowledge](https://www.ijcai.org/proceedings/2021/0545.pdf).
3 |
4 |
5 | ## Prerequisites
6 | * Python (3.8.0) with the [PyTorch library](https://pytorch.org/)
7 | * [transformers](https://huggingface.co/transformers/model_doc/bert.html) (4.5.1)
8 | * [corenlp](https://stanfordnlp.github.io/CoreNLP/) (4.2)
9 | * torch (1.7.1)
10 | * numpy (1.20.2)
11 | * gensim (4.0.1)
12 | * pandas (1.2.4)
13 | * scikit-learn (0.24.1)
14 |
15 | ## Usage (by examples)
16 | ### Data
17 | Original data comes from [TOWE](https://github.com/NJUNLP/TOWE/tree/master/data).
18 |
19 |
20 | ### Preprocessing
21 | We need to obtain the dependency structure and POS tags for each sentence and save them in JSON format.
22 | Pay attention to the file paths and modify them as needed.
23 |
24 | #### Get Dependency and POS
25 | To parse the dependency structure and POS tags, we employ [CoreNLP](https://stanfordnlp.github.io/CoreNLP/) provided by Stanford NLP.
26 | Please download the relevant files first and put them in `data/datasets/original`.
27 | We use the NLTK package to obtain the dependency and POS parses, so modify the code in `preprocess.py` line 24 as follows:
28 | ```
29 | depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000')
30 | ```
31 | Set the URL according to your actual CoreNLP server.
32 |
33 | #### Save
34 | ```
35 | python preprocess.py
36 | ```
37 | The processed data will be stored in the directory `data/datasets/towe/`.
38 | We also provide some preprocessed examples.
39 |
40 | ### Train
41 | We use the pretrained [bert-base-cased](https://huggingface.co/bert-base-cased) embeddings (768d).
42 |
43 | ```
44 | python synfue.py train --config configs/16res_train.conf
45 | ```
46 | ### Test
47 | ```
48 | python synfue.py eval --config configs/16res_eval.conf
49 | ```
50 |
51 | ## Note
52 | This code is based on [SpERT](https://github.com/lavis-nlp/spert).
53 |
--------------------------------------------------------------------------------
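For reference, a minimal sketch of the parse step described in the README above, assuming a local CoreNLP server is already running on port 9000 (e.g. started with `java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000`) and that NLTK is installed; the exact tokenization and output handling in `preprocess.py` may differ:

```python
from nltk.parse.corenlp import CoreNLPDependencyParser, CoreNLPParser

depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000')       # as in preprocess.py line 24
pos_tagger = CoreNLPParser(url='http://127.0.0.1:9000', tagtype='pos')

tokens = "The food was great but the service was slow .".split()
parse, = depparser.parse(tokens)           # one DependencyGraph per sentence
triples = list(parse.triples())            # [((head, head_POS), dep_label, (dependent, dep_POS)), ...]
pos_tags = list(pos_tagger.tag(tokens))    # [(token, POS), ...]
print(triples[:3])
print(pos_tags[:3])
```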
/SynFue/loss.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | import torch
3 |
4 |
5 | class Loss(ABC):
6 | def compute(self, *args, **kwargs):
7 | pass
8 |
9 |
10 | class SynFueLoss(Loss):
11 | def __init__(self, rel_criterion, term_criterion, model, optimizer, scheduler, max_grad_norm):
12 | self._rel_criterion = rel_criterion
13 | self._term_criterion = term_criterion
14 | self._model = model
15 | self._optimizer = optimizer
16 | self._scheduler = scheduler
17 | self._max_grad_norm = max_grad_norm
18 |
19 | def compute(self, term_logits, rel_logits, term_types, rel_types, term_sample_masks, rel_sample_masks):
20 | # term loss
21 | term_logits = term_logits.view(-1, term_logits.shape[-1])
22 | term_types = term_types.view(-1)
23 | term_sample_masks = term_sample_masks.view(-1).float()
24 |
25 | term_loss = self._term_criterion(term_logits, term_types)
26 | term_loss = (term_loss * term_sample_masks).sum() / term_sample_masks.sum()
27 |
28 | # relation loss
29 | rel_sample_masks = rel_sample_masks.view(-1).float()
30 | rel_count = rel_sample_masks.sum()
31 |
32 | if rel_count.item() != 0:
33 | rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
34 | rel_types = rel_types.view(-1, rel_types.shape[-1])
35 |
36 | rel_loss = self._rel_criterion(rel_logits, rel_types)
37 | rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
38 | rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count
39 |
40 | # joint loss
41 | train_loss = term_loss + rel_loss
42 | else:
43 | # corner case: no positive/negative relation samples
44 | train_loss = term_loss
45 |
46 | train_loss.backward()
47 | torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._max_grad_norm)
48 | self._optimizer.step()
49 | self._scheduler.step()
50 | self._model.zero_grad()
51 | return train_loss.item()
52 |
--------------------------------------------------------------------------------
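A minimal sketch of how `SynFueLoss` is wired up and called. The toy linear model, the tensor shapes, and the choice of criteria with `reduction='none'` are illustrative assumptions; the real model, criteria and batches are built in `models.py`, `trainer.py` and `sampling.py`:

```python
import torch
from transformers import get_linear_schedule_with_warmup
from SynFue.loss import SynFueLoss

# Toy stand-ins for the real model / optimizer / scheduler.
model = torch.nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)

compute_loss = SynFueLoss(rel_criterion=torch.nn.BCEWithLogitsLoss(reduction='none'),
                          term_criterion=torch.nn.CrossEntropyLoss(reduction='none'),
                          model=model, optimizer=optimizer, scheduler=scheduler,
                          max_grad_norm=1.0)

term_logits = model(torch.randn(2, 4, 8))   # [batch, term samples, term types]
rel_logits = model(torch.randn(2, 5, 8))    # [batch, relation samples, relation types]
loss = compute_loss.compute(term_logits, rel_logits,
                            term_types=torch.randint(0, 3, (2, 4)),
                            rel_types=torch.randint(0, 2, (2, 5, 3)).float(),
                            term_sample_masks=torch.ones(2, 4),
                            rel_sample_masks=torch.ones(2, 5))
print(loss)   # scalar joint loss; one optimizer/scheduler step has already been taken
```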
/config_reader.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import multiprocessing as mp
3 |
4 |
5 | def process_configs(target, arg_parser):
6 | args, _ = arg_parser.parse_known_args()
7 | ctx = mp.get_context('spawn')
8 |
9 | for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args):
10 | p = ctx.Process(target=target, args=(run_args,))
11 | p.start()
12 | p.join()
13 |
14 |
15 | def _read_config(path):
16 | lines = open(path).readlines()
17 |
18 | runs = []
19 | run = [1, dict()]
20 | for line in lines:
21 | stripped_line = line.strip()
22 |
23 | # continue in case of comment
24 | if stripped_line.startswith('#'):
25 | continue
26 |
27 | if not stripped_line:
28 | if run[1]:
29 | runs.append(run)
30 |
31 | run = [1, dict()]
32 | continue
33 |
34 | if stripped_line.startswith('[') and stripped_line.endswith(']'):
35 | repeat = int(stripped_line[1:-1])
36 | run[0] = repeat
37 | else:
38 | key, value = stripped_line.split('=')
39 | key, value = (key.strip(), value.strip())
40 | run[1][key] = value
41 |
42 | if run[1]:
43 | runs.append(run)
44 |
45 | return runs
46 |
47 |
48 | def _convert_config(config):
49 | config_list = []
50 | for k, v in config.items():
51 | if v.lower() == 'true':
52 | config_list.append('--' + k)
53 | elif v.lower() != 'false':
54 | config_list.extend(['--' + k] + v.split(' '))
55 |
56 | return config_list
57 |
58 |
59 | def _yield_configs(arg_parser, args, verbose=True):
60 | _print = (lambda x: print(x)) if verbose else lambda x: x
61 |
62 | if args.config:
63 | config = _read_config(args.config)
64 |
65 | for run_repeat, run_config in config:
66 | print("-" * 50)
67 | print("Config:")
68 | print(run_config)
69 |
70 | args_copy = copy.deepcopy(args)
71 | config_list = _convert_config(run_config)
72 | run_args = arg_parser.parse_args(config_list, namespace=args_copy)
73 | run_args_dict = vars(run_args)
74 |
75 | # set boolean values
76 | for k, v in run_config.items():
77 | if v.lower() == 'false':
78 | run_args_dict[k] = False
79 |
80 | print("Repeat %s times" % run_repeat)
81 | print("-" * 50)
82 |
83 | for iteration in range(run_repeat):
84 | _print("Iteration %s" % iteration)
85 | _print("-" * 50)
86 |
87 | yield run_args, run_config, run_repeat
88 |
89 | else:
90 | yield args, None, None
91 |
--------------------------------------------------------------------------------
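For illustration, this is how `_convert_config` turns one parsed run section into CLI-style arguments before they are handed back to argparse: `true` values become bare flags, and `false` entries are dropped here and reset to `False` afterwards in `_yield_configs`. The key/value pairs below are taken from the configs above:

```python
from config_reader import _convert_config

run_config = {'label': 'towe_16res', 'epochs': '20',
              'store_predictions': 'true', 'final_eval': 'false'}
print(_convert_config(run_config))
# ['--label', 'towe_16res', '--epochs', '20', '--store_predictions']
```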
/SynFue/cross_attn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 |
9 |
10 | class MultiHeadAttentionLayer(nn.Module):
11 | def __init__(self, hid_dim, n_heads, dropout, device):
12 | super().__init__()
13 |
14 | assert hid_dim % n_heads == 0
15 |
16 | self.hid_dim = hid_dim
17 | self.n_heads = n_heads
18 | self.head_dim = hid_dim // n_heads
19 |
20 | self.fc_q = nn.Linear(hid_dim, hid_dim)
21 | self.fc_k = nn.Linear(hid_dim, hid_dim)
22 | self.fc_v = nn.Linear(hid_dim, hid_dim)
23 |
24 | self.fc_o = nn.Linear(hid_dim, hid_dim)
25 |
26 | self.dropout = nn.Dropout(dropout)
27 |
28 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
29 |
30 | def forward(self, query, key, value, mask=None):
31 | batch_size = query.shape[0]
32 |
33 | # query = [batch size, query len, hid dim]
34 | # key = [batch size, key len, hid dim]
35 | # value = [batch size, value len, hid dim]
36 |
37 | Q = self.fc_q(query)
38 | K = self.fc_k(key)
39 | V = self.fc_v(value)
40 |
41 | # Q = [batch size, query len, hid dim]
42 | # K = [batch size, key len, hid dim]
43 | # V = [batch size, value len, hid dim]
44 |
45 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
46 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
47 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
48 |
49 | # Q = [batch size, n heads, query len, head dim]
50 | # K = [batch size, n heads, key len, head dim]
51 | # V = [batch size, n heads, value len, head dim]
52 |
53 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
54 |
55 | # energy = [batch size, n heads, query len, key len]
56 |
57 | if mask is not None:
58 | energy = energy.masked_fill(mask == 0, -1e10)
59 |
60 | attention = torch.softmax(energy, dim=-1)
61 |
62 | # attention = [batch size, n heads, query len, key len]
63 | x = torch.matmul(self.dropout(attention), V)
64 |
65 | # x = [batch size, n heads, query len, head dim]
66 |
67 | x = x.permute(0, 2, 1, 3).contiguous()
68 |
69 | # x = [batch size, query len, n heads, head dim]
70 |
71 | x = x.view(batch_size, -1, self.hid_dim)
72 |
73 | # x = [batch size, query len, hid dim]
74 |
75 | x = self.fc_o(x)
76 |
77 | # x = [batch size, query len, hid dim]
78 |
79 | return x, attention
80 |
81 |
82 | class PositionwiseFeedforwardLayer(nn.Module):
83 | def __init__(self, hid_dim, pf_dim, dropout):
84 | super().__init__()
85 |
86 | self.fc_1 = nn.Linear(hid_dim, pf_dim)
87 | self.fc_2 = nn.Linear(pf_dim, hid_dim)
88 |
89 | self.dropout = nn.Dropout(dropout)
90 |
91 | def forward(self, x):
92 | # x = [batch size, seq len, hid dim]
93 |
94 | x = self.dropout(torch.relu(self.fc_1(x)))
95 |
96 | # x = [batch size, seq len, pf dim]
97 |
98 | x = self.fc_2(x)
99 |
100 | # x = [batch size, seq len, hid dim]
101 |
102 | return x
103 |
104 |
105 | class CA_module(nn.Module):
106 | def __init__(self, hid_dim, pf_dim, n_heads, dropout):
107 | super(CA_module, self).__init__()
108 | self.n_hidden = hid_dim
109 | self.multi_head_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
110 | self.FFN = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
111 | self.layer_norm = nn.LayerNorm(hid_dim)
112 |
113 | def forward(self, inputs):
114 |
115 | z_hat = self.CR_2Datt(inputs)
116 | z = self.layer_norm(z_hat + inputs)
117 | outputs_hat = self.FFN(z)
118 | outputs = self.layer_norm(outputs_hat + z)
119 | return outputs
120 |
121 | def row_2Datt(self, inputs):
122 | max_len = inputs.size(1)
123 | z_row = inputs.reshape(-1, max_len, self.n_hidden)
124 | z_row, _ = self.multi_head_attention(z_row, z_row, z_row)
125 | return z_row.reshape(-1, max_len, max_len, self.n_hidden)
126 |
127 | def col_2Datt(self, inputs):
128 | z_col = inputs.permute(0, 2, 1, 3)
129 | z_col = self.row_2Datt(z_col)
130 | z_col = z_col.permute(0, 2, 1, 3)
131 | return z_col
132 |
133 | def CR_2Datt(self, inputs):
134 | z_row = self.row_2Datt(inputs)
135 | z_col = self.col_2Datt(inputs)
136 | outputs = (z_row + z_col) / 2.
137 | return outputs
--------------------------------------------------------------------------------
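A minimal shape check for `CA_module`: it expects a 4-D table of pair-wise representations and attends over it along both the row and the column axis. The hyper-parameters below are illustrative, not the values used in training:

```python
import torch
from SynFue.cross_attn import CA_module

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, seq_len, hid_dim = 2, 6, 64

table = torch.randn(batch_size, seq_len, seq_len, hid_dim).to(device)   # pair-wise features
ca = CA_module(hid_dim=hid_dim, pf_dim=256, n_heads=8, dropout=0.1).to(device)

out = ca(table)
print(out.shape)   # torch.Size([2, 6, 6, 64]) -- same shape as the input table
```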
/data/datasets/io_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import numpy as np
5 |
6 | import pickle
7 | import json
8 | import yaml
9 | import sys
10 | import logging
11 |
12 |
13 | def read_json_lines(path):
14 | res = []
15 | for line in open(path, 'r'):
16 | res.append(json.loads(line))
17 | return res
18 |
19 |
20 | def read_yaml(path, encoding='utf-8'):
21 | with open(path, 'r', encoding=encoding) as file:
22 |         return yaml.safe_load(file)
23 |
24 |
25 | def read_lines(path, encoding='utf-8', return_list=False):
26 | with open(path, 'r', encoding=encoding) as file:
27 | if return_list:
28 | return file.readlines()
29 | for line in file:
30 | yield line.strip()
31 |
32 |
33 | def read_multi_line_sent(path, skip_line_start='-DOCSTART-'):
34 | sent_lines = []
35 | for line in read_lines(path):
36 | line = line.strip()
37 | if len(line) == 0 or line.startswith(skip_line_start):
38 | if sent_lines:
39 | yield sent_lines
40 | sent_lines = []
41 | continue
42 | sent_lines.append(line)
43 | if sent_lines:
44 | yield sent_lines
45 |
46 |
47 | def read_pickle(path):
48 | with open(path, 'rb') as file:
49 | return pickle.load(file)
50 |
51 |
52 | def save_pickle(path, obj):
53 | with open(path, 'wb') as file:
54 | pickle.dump(obj, file)
55 |
56 |
57 | def write_lines(txt_path, lines):
58 | with open(txt_path, 'w', encoding='utf-8') as file:
59 | for line in lines:
60 | file.write(line + '\n')
61 |
62 |
63 | def write_texts(txt_path, lines):
64 | with open(txt_path, 'w', encoding='utf-8') as file:
65 | for line in lines:
66 | file.write(line)
67 |
68 |
69 | def load_embed_txt(embed_file):
70 | """Load embed_file into a python dictionary.
71 | Note: the embed_file should be a Glove formated txt file. Assuming
72 | embed_size=5, for example:
73 | the -0.071549 0.093459 0.023738 -0.090339 0.056123
74 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037
75 | and 0.20327 0.47348 0.050877 0.002103 0.060547
76 | Args:
77 | embed_file: file path to the embedding file.
78 | Returns:
79 | a dictionary that maps word to vector, and the size of embedding dimensions.
80 | """
81 | emb_dict = dict()
82 | emb_size = None
83 | for line in read_lines(embed_file):
84 | tokens = line.strip().split(" ")
85 | word = tokens[0]
86 | vec = list(map(float, tokens[1:]))
87 | emb_dict[word] = vec
88 | if emb_size:
89 | assert emb_size == len(vec), "All embedding size should be same."
90 | else:
91 | emb_size = len(vec)
92 | return emb_dict, emb_size
93 |
94 |
95 | def txt_to_npy(dirname, fname, output_name):
96 | emb_dict, emb_size = load_embed_txt(os.path.join(dirname,fname))
97 | words = []
98 | vec = np.empty((len(emb_dict), emb_size))
99 | i = 0
100 | for word, vec_list in emb_dict.items():
101 | words.append(word)
102 | vec[i] = vec_list
103 | i += 1
104 | with open(os.path.join(dirname,output_name+'.vocab.txt'), 'w') as file:
105 | for word in words:
106 | file.write(word+'\n')
107 | np.save(os.path.join(dirname,output_name+'.npy'), vec)
108 |
109 |
110 | def get_logger(name, log_dir=None, log_name=None, file_model='a',
111 | level=logging.INFO, handler=sys.stdout,
112 | formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
113 | logger = logging.getLogger(name)
114 |     logger.setLevel(level)
115 | formatter = logging.Formatter(formatter)
116 | stream_handler = logging.StreamHandler(handler)
117 | stream_handler.setLevel(level)
118 | stream_handler.setFormatter(formatter)
119 | logger.addHandler(stream_handler)
120 | if log_dir and log_name:
121 | filename = os.path.join(log_dir, log_name)
122 | file_handler = logging.FileHandler(filename, encoding='utf-8', mode=file_model)
123 | file_handler.setLevel(level)
124 | file_handler.setFormatter(formatter)
125 | logger.addHandler(file_handler)
126 |
127 | return logger
128 |
129 |
130 | def trigger_tag_to_list(trigger_tags, O_idx):
131 | trigger_list = []
132 | for i, trigger_type in enumerate(trigger_tags):
133 | if trigger_type != O_idx:
134 | trigger_list.append([i, trigger_type])
135 | return trigger_list
136 |
137 |
138 | def relative_position(ent_start, ent_end, tok_idx, max_position_len=150):
139 | if ent_start <= tok_idx <= ent_end:
140 | return 0
141 | elif tok_idx < ent_start:
142 | return ent_start - tok_idx
143 | elif tok_idx > ent_end:
144 | return tok_idx - ent_end + max_position_len
145 |
146 | return None
147 |
148 |
149 | def to_set(*list_vars):
150 | sets = []
151 | for list_var in list_vars:
152 | sets.append({tuple(sub_list) for sub_list in list_var})
153 | return sets
--------------------------------------------------------------------------------
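A small round trip through `load_embed_txt` / `txt_to_npy` with a toy 3-dimensional embedding file. The temporary paths are illustrative; note that importing `io_utils` requires PyYAML, since `yaml` is imported at module level:

```python
import os
import tempfile
from data.datasets.io_utils import load_embed_txt, txt_to_npy

tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, 'toy_emb.txt'), 'w', encoding='utf-8') as f:
    f.write('the 0.1 0.2 0.3\nfood -0.5 0.0 0.25\n')

emb_dict, emb_size = load_embed_txt(os.path.join(tmp_dir, 'toy_emb.txt'))
print(emb_size, emb_dict['food'])               # 3 [-0.5, 0.0, 0.25]

txt_to_npy(tmp_dir, 'toy_emb.txt', 'toy_emb')   # writes toy_emb.vocab.txt and toy_emb.npy
```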
/SynFue/Encoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | class LabelAwareGCN(nn.Module):
10 | """
11 | Simple GCN layer
12 | """
13 | def __init__(self, dep_dim, in_features, out_features, pos_dim=None, bias=True):
14 | super(LabelAwareGCN, self).__init__()
15 | self.dep_dim = dep_dim
16 | self.pos_dim = pos_dim
17 | self.in_features = in_features
18 | self.out_features = out_features
19 |
20 | self.dep_attn = nn.Linear(dep_dim + pos_dim + in_features, out_features)
21 | self.dep_fc = nn.Linear(dep_dim, out_features)
22 | self.pos_fc = nn.Linear(pos_dim, out_features)
23 |
24 | def forward(self, text, adj, dep_embed, pos_embed=None):
25 | """
26 |
27 | :param text: [batch size, seq_len, feat_dim]
28 | :param adj: [batch size, seq_len, seq_len]
29 | :param dep_embed: [batch size, seq_len, seq_len, dep_type_dim]
30 | :param pos_embed: [batch size, seq_len, pos_dim]
31 | :return: [batch size, seq_len, feat_dim]
32 | """
33 | batch_size, seq_len, feat_dim = text.shape
34 |
35 | val_us = text.unsqueeze(dim=2)
36 | val_us = val_us.repeat(1, 1, seq_len, 1)
37 | pos_us = pos_embed.unsqueeze(dim=2).repeat(1, 1, seq_len, 1)
38 | # [batch size, seq_len, seq_len, feat_dim+pos_dim+dep_dim]
39 | val_sum = torch.cat([val_us, pos_us, dep_embed], dim=-1)
40 |
41 | r = self.dep_attn(val_sum)
42 |
43 | p = torch.sum(r, dim=-1)
44 | mask = (adj == 0).float() * (-1e30)
45 | p = p + mask
46 | p = torch.softmax(p, dim=2)
47 | p_us = p.unsqueeze(3).repeat(1, 1, 1, feat_dim)
48 |
49 | output = val_us + self.pos_fc(pos_us) + self.dep_fc(dep_embed)
50 | output = torch.mul(p_us, output)
51 |
52 | output_sum = torch.sum(output, dim=2)
53 |
54 | return r, output_sum, p
55 |
56 |
57 | class nLaGCN(nn.Module):
58 | def __init__(self, opt):
59 | super(nLaGCN, self).__init__()
60 | self.opt = opt
61 | self.model = nn.ModuleList([LabelAwareGCN(opt.dep_dim, opt.bert_dim,
62 | opt.bert_dim, opt.pos_dim)
63 | for i in range(self.opt.num_layer)])
64 | self.dep_embedding = nn.Embedding(opt.dep_num, opt.dep_dim, padding_idx=0)
65 |
66 | def forward(self, x, simple_graph, graph, pos_embed=None, output_attention=False):
67 |
68 | dep_embed = self.dep_embedding(graph)
69 |
70 | attn_list = []
71 | for lagcn in self.model:
72 | r, x, attn = lagcn(x, simple_graph, dep_embed, pos_embed=pos_embed)
73 | attn_list.append(attn)
74 |
75 | if output_attention is True:
76 | return x, r, attn_list
77 | else:
78 | return x, r
79 |
80 |
81 | class SynFueEncoder(nn.Module):
82 | def __init__(self, opt):
83 | super(SynFueEncoder, self).__init__()
84 | self.opt = opt
85 | self.lagcn = nLaGCN(opt)
86 |
87 | self.fc = nn.Linear(opt.bert_dim*2 + opt.pos_dim, opt.bert_dim*2)
88 | self.output_dropout = nn.Dropout(opt.output_dropout)
89 |
90 | self.pod_embedding = nn.Embedding(opt.pos_num, opt.pos_dim, padding_idx=0)
91 |
92 | def forward(self, word_reps, simple_graph, graph, pos=None, output_attention=False):
93 | """
94 |
95 | :param word_reps: [B, L, H]
96 | :param simple_graph: [B, L, L]
97 | :param graph: [B, L, L]
98 | :param pos: [B, L]
99 | :param output_attention: bool
100 | :return:
101 | output: [B, L, H]
102 | dep_reps: [B, L, H]
103 | cls_reps: [B, H]
104 | """
105 |
106 | pos_embed = self.pod_embedding(pos)
107 |
108 | lagcn_output = self.lagcn(word_reps, simple_graph, graph, pos_embed, output_attention)
109 |
110 | pos_output = self.local_attn(word_reps, pos_embed, self.opt.num_layer, self.opt.w_size)
111 |
112 | output = torch.cat((lagcn_output[0], pos_output, word_reps), dim=-1)
113 | output = self.fc(output)
114 | output = self.output_dropout(output)
115 | return output, lagcn_output[1]
116 |
117 | def local_attn(self, x, pos_embed, num_layer, w_size):
118 | """
119 |
120 | :param x:
121 | :param pos_embed:
122 | :return:
123 | """
124 | batch_size, seq_len, feat_dim = x.shape
125 | pos_dim = pos_embed.size(-1)
126 | output = pos_embed
127 | for i in range(num_layer):
128 | val_sum = torch.cat([x, output], dim=-1) # [batch size, seq_len, feat_dim+pos_dim]
129 | attn = torch.matmul(val_sum, val_sum.transpose(1, 2)) # [batch size, seq_len, seq_len]
130 |             # pad on both sides: pad_size = seq_len + 2 * w_size
131 | pad_size = seq_len + w_size * 2
132 | mask = torch.zeros((batch_size, seq_len, pad_size), dtype=torch.float).to(
133 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
134 |             for j in range(seq_len):
135 |                 mask[:, j, j:j + w_size] = 1.0
136 | pad_attn = torch.full((batch_size, seq_len, w_size), -1e18).to(
137 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
138 | attn = torch.cat([pad_attn, attn, pad_attn], dim=-1)
139 | local_attn = torch.softmax(torch.mul(attn, mask), dim=-1)
140 | local_attn = local_attn[:, :, w_size:pad_size - w_size] # [batch size, seq_len, seq_len]
141 | local_attn = local_attn.unsqueeze(dim=3).repeat(1, 1, 1, pos_dim)
142 | output = output.unsqueeze(dim=2).repeat(1, 1, seq_len, 1)
143 | output = torch.sum(torch.mul(output, local_attn), dim=2) # [batch size, seq_len, pos_dim]
144 | return output
145 |
--------------------------------------------------------------------------------
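A minimal shape check for `SynFueEncoder`. The hyper-parameters mirror the defaults in `args.py` and the 16res training config, and the random graphs/tags below are placeholders for the real dependency and POS annotations produced during preprocessing:

```python
import argparse
import torch
from SynFue.Encoder import SynFueEncoder

opt = argparse.Namespace(bert_dim=768, dep_dim=100, pos_dim=100, dep_num=42,
                         pos_num=45, num_layer=2, w_size=5, output_dropout=0.1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = SynFueEncoder(opt).to(device)

B, L = 2, 10
word_reps = torch.randn(B, L, opt.bert_dim).to(device)       # BERT token representations
simple_graph = torch.ones(B, L, L).to(device)                 # unlabeled adjacency matrix
graph = torch.randint(0, opt.dep_num, (B, L, L)).to(device)   # dependency-label ids
pos = torch.randint(0, opt.pos_num, (B, L)).to(device)        # POS-tag ids

output, dep_reps = encoder(word_reps, simple_graph, graph, pos=pos)
print(output.shape)    # torch.Size([2, 10, 1536])  -> [B, L, 2 * bert_dim]
print(dep_reps.shape)  # torch.Size([2, 10, 10, 768])
```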
/SynFue/base_trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import logging
4 | import os
5 | import sys
6 | from typing import List, Dict, Tuple
7 |
8 | import torch
9 | from torch.nn import DataParallel
10 | from torch.optim import optimizer
11 | from transformers import PreTrainedModel
12 | from transformers import PreTrainedTokenizer
13 |
14 | from SynFue import util
15 | from SynFue.opt import tensorboardx
16 |
17 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
18 |
19 |
20 | class BaseTrainer:
21 | """ Trainer base class with common methods """
22 |
23 | def __init__(self, args: argparse.Namespace):
24 | self.args = args
25 | self._debug = self.args.debug
26 |
27 | # logging
28 | name = str(datetime.datetime.now()).replace(' ', '_')
29 | self._log_path = os.path.join(self.args.log_path, self.args.label, name)
30 | util.create_directories_dir(self._log_path)
31 |
32 | if hasattr(args, 'save_path'):
33 | self._save_path = os.path.join(self.args.save_path, self.args.label)
34 | util.create_directories_dir(self._save_path)
35 |
36 | self._log_paths = dict()
37 |
38 | # file + console logging
39 | log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s")
40 | self._logger = logging.getLogger()
41 | util.reset_logger(self._logger)
42 |
43 | file_handler = logging.FileHandler(os.path.join(self._log_path, 'all.log'))
44 | file_handler.setFormatter(log_formatter)
45 | self._logger.addHandler(file_handler)
46 |
47 | console_handler = logging.StreamHandler(sys.stdout)
48 | console_handler.setFormatter(log_formatter)
49 | self._logger.addHandler(console_handler)
50 |
51 | if self._debug:
52 | self._logger.setLevel(logging.DEBUG)
53 | else:
54 | self._logger.setLevel(logging.INFO)
55 |
56 | # tensorboard summary
57 | self._summary_writer = tensorboardx.SummaryWriter(self._log_path) if tensorboardx is not None else None
58 |
59 | self._best_results = dict()
60 | self._log_arguments()
61 |
62 | # CUDA devices
63 | self._device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
64 | self._gpu_count = torch.cuda.device_count()
65 |
66 | # set seed
67 | if args.seed is not None:
68 | util.set_seed(args.seed)
69 |
70 | def _add_dataset_logging(self, *labels, data: Dict[str, List[str]]):
71 | for label in labels:
72 | dic = dict()
73 |
74 | for key, columns in data.items():
75 | path = os.path.join(self._log_path, '%s_%s.csv' % (key, label))
76 | util.create_csv(path, *columns)
77 | dic[key] = path
78 |
79 | self._log_paths[label] = dic
80 | self._best_results[label] = 0
81 |
82 | def _log_arguments(self):
83 | util.save_dict(self._log_path, self.args, 'args')
84 | if self._summary_writer is not None:
85 | util.summarize_dict(self._summary_writer, self.args, 'args')
86 |
87 | def _log_tensorboard(self, dataset_label: str, data_label: str, data: object, iteration: int):
88 | if self._summary_writer is not None:
89 | self._summary_writer.add_scalar('data/%s/%s' % (dataset_label, data_label), data, iteration)
90 |
91 | def _log_csv(self, dataset_label: str, data_label: str, *data: Tuple[object]):
92 | logs = self._log_paths[dataset_label]
93 | util.append_csv(logs[data_label], *data)
94 |
95 | def _save_best(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, optimizer: optimizer,
96 | accuracy: float, iteration: int, label: str, extra=None):
97 | if accuracy > self._best_results[label]:
98 | self._logger.info("[%s] Best model in iteration %s: %s%% accuracy" % (label, iteration, accuracy))
99 | self._save_model(self._save_path, model, tokenizer, iteration,
100 | optimizer=optimizer if self.args.save_optimizer else None,
101 | save_as_best=True, name='model_%s' % label, extra=extra)
102 | self._best_results[label] = accuracy
103 |
104 | def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
105 | iteration: int, optimizer: optimizer = None, save_as_best: bool = False,
106 |                     extra: dict = None, include_iteration: bool = True, name: str = 'model'):
107 | extra_state = dict(iteration=iteration)
108 |
109 | if optimizer:
110 | extra_state['optimizer'] = optimizer.state_dict()
111 |
112 | if extra:
113 | extra_state.update(extra)
114 |
115 | # if save_as_best:
116 | # dir_path = os.path.join(save_path, '%s_best' % name)
117 | # else:
118 | # dir_name = '%s_%s' % (name, iteration) if include_iteration else name
119 | # dir_path = os.path.join(save_path, dir_name)
120 |
121 | util.create_directories_dir(save_path)
122 |
123 | # save model
124 | if isinstance(model, DataParallel):
125 | model.module.save_pretrained(save_path)
126 | else:
127 | model.save_pretrained(save_path)
128 |
129 | # save vocabulary
130 | tokenizer.save_pretrained(save_path)
131 |
132 | # save extra
133 | state_path = os.path.join(save_path, 'extra.state')
134 | torch.save(extra_state, state_path)
135 |
136 | # def _load_model(self, save_path: str):
137 | #
138 | # model = torch.load()
139 |
140 | def _get_lr(self, optimizer):
141 | lrs = []
142 | for group in optimizer.param_groups:
143 | lr_scheduled = group['lr']
144 | lrs.append(lr_scheduled)
145 | return lrs
146 |
147 | def _close_summary_writer(self):
148 | if self._summary_writer is not None:
149 | self._summary_writer.close()
150 |
--------------------------------------------------------------------------------
/data/log/towe_16res/eval_valid.csv:
--------------------------------------------------------------------------------
1 | ner_prec_micro;ner_rec_micro;ner_f1_micro;ner_prec_macro;ner_rec_macro;ner_f1_macro;rel_prec_micro;rel_rec_micro;rel_f1_micro;rel_prec_macro;rel_rec_macro;rel_f1_macro;rel_nec_prec_micro;rel_nec_rec_micro;rel_nec_f1_micro;rel_nec_prec_macro;rel_nec_rec_macro;rel_nec_f1_macro;epoch;iteration;global_iteration
2 | 64.01326699834162;82.30277185501066;72.01492537313432;64.17537970504073;82.15470141071893;72.02142739668096;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;36.92307692307693;60.0;45.714285714285715;1;0;269
3 | 82.13891951488424;79.42430703624733;80.75880758807588;82.65960576640188;79.52273936956651;80.78248656685518;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;66.46216768916156;62.5;64.42021803766104;2;0;538
4 | 82.32804232804233;82.94243070362474;82.63409453000531;82.28590291750504;82.86574741716974;82.56053056751291;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;67.3469387755102;69.8076923076923;68.55524079320115;3;0;807
5 | 81.32530120481928;86.35394456289978;83.76421923474663;81.4069844856457;86.38321876833912;83.76537823178138;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;64.321608040201;73.84615384615385;68.75559534467322;4;0;1076
6 | 73.52185089974293;91.47121535181236;81.52019002375297;73.49442269770728;91.40557827647544;81.46932857501848;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;49.467455621301774;80.38461538461539;61.24542124542125;5;0;1345
7 | 79.41747572815534;87.20682302771856;83.13008130081302;79.40115154463084;87.1875241678305;83.10578512396695;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;63.8801261829653;77.88461538461539;70.19064124783363;6;0;1614
8 | 82.12512413108243;88.16631130063965;85.03856041131107;82.18998350577297;88.17220688117844;85.03409992069786;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;72.62969588550983;78.07692307692308;75.25486561631139;7;0;1883
9 | 78.9622641509434;89.23240938166312;83.78378378378379;79.18097806121175;89.2444624392108;83.81268288772196;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;58.55354659248957;80.96153846153847;67.95803066989508;8;0;2152
10 | 79.73484848484848;89.76545842217483;84.45336008024073;79.83453870035713;89.76421295896132;84.45565830456094;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;65.8267716535433;80.38461538461539;72.3809523809524;9;0;2421
11 | 82.72357723577237;86.78038379530916;84.70343392299688;82.89277220556302;86.78264192487387;84.71655535053753;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;68.62068965517241;76.53846153846153;72.36363636363636;10;0;2690
12 | 80.3073967339097;89.12579957356077;84.48711470439616;80.37625078462504;89.11867598957315;84.48152596659804;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;65.60509554140127;79.23076923076923;71.77700348432056;11;0;2959
13 | 81.21974830590513;89.4456289978678;85.13444951801115;81.2567949554251;89.43052630142346;85.12008221028701;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;68.90756302521008;78.84615384615384;73.54260089686098;12;0;3228
14 | 78.87453874538745;91.15138592750533;84.56973293768546;78.98414482535465;91.13194156957833;84.57265291746566;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;63.813813813813816;81.73076923076923;71.66947723440134;13;0;3497
15 | 82.14990138067061;88.80597014925374;85.34836065573771;82.21549895908322;88.7849893320353;85.33761652809271;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;72.7112676056338;79.42307692307692;75.91911764705883;14;0;3766
16 | 80.3639846743295;89.4456289978678;84.6619576185671;80.4291802953712;89.43598538784534;84.65509657133082;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;69.24342105263158;80.96153846153847;74.64539007092199;15;0;4035
17 | 81.16504854368932;89.12579957356077;84.95934959349594;81.16504854368931;89.09138055746371;84.92984282810254;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;69.49152542372882;78.84615384615384;73.87387387387386;16;0;4304
18 | 82.66129032258065;87.42004264392324;84.97409326424871;82.6448027901645;87.38996528930883;84.94307196562835;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;73.49177330895796;77.3076923076923;75.35145267104029;17;0;4573
19 | 82.73453093812375;88.37953091684435;85.4639175257732;82.80607364897179;88.36918891623486;85.45676385468668;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;72.15411558669001;79.23076923076923;75.52703941338221;18;0;4842
20 | 82.7037773359841;88.69936034115139;85.59670781893004;82.73430490171665;88.67012105524141;85.57540858667117;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;72.63157894736842;79.61538461538461;75.96330275229359;19;0;5111
21 | 82.76892430278885;88.59275053304904;85.58187435633367;82.7915236413421;88.56071186486942;85.5577241291527;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;73.00177619893428;79.03846153846153;75.90027700831024;20;0;5380
22 |
--------------------------------------------------------------------------------
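The log above is a semicolon-separated CSV (see `CSV_DELIMETER` in `SynFue/util.py`), so it can be inspected directly with pandas, e.g. to pick the epoch with the best relation F1. The path below is the one shown in this dump; in a fresh run the log lands in a timestamped subdirectory under `data/log/<label>/`:

```python
import pandas as pd

df = pd.read_csv('data/log/towe_16res/eval_valid.csv', sep=';')
best = df.loc[df['rel_f1_micro'].idxmax()]
print(int(best['epoch']), best['rel_f1_micro'])   # 19 75.96330275229359
```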
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def _add_common_args(arg_parser):
5 | arg_parser.add_argument('--config', type=str)
6 |
7 | # Input
8 | arg_parser.add_argument('--types_path', type=str, help="Path to type specifications")
9 |
10 | # Preprocessing
11 | arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer")
12 | arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans")
13 | arg_parser.add_argument('--lowercase', action='store_true', default=False,
14 | help="If true, input is lowercased during preprocessing")
15 | arg_parser.add_argument('--sampling_processes', type=int, default=4,
16 | help="Number of sampling processes. 0 = no multiprocessing for sampling")
17 | arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue")
18 |
19 | # Logging
20 | arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
21 |     arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
22 | arg_parser.add_argument('--store_predictions', action='store_true', default=False,
23 | help="If true, store predictions on disc (in log directory)")
24 | arg_parser.add_argument('--store_examples', action='store_true', default=False,
25 | help="If true, store evaluation examples on disc (in log directory)")
26 | arg_parser.add_argument('--example_count', type=int, default=None,
27 | help="Count of evaluation example to store (if store_examples == True)")
28 | arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")
29 |
30 | # Model / Training / Evaluation
31 | arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints")
32 | arg_parser.add_argument('--model_type', type=str, default="SynFue", help="Type of model")
33 | arg_parser.add_argument('--cpu', action='store_true', default=False,
34 | help="If true, train/evaluate on CPU even if a CUDA device is available")
35 | arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size")
36 | arg_parser.add_argument('--max_pairs', type=int, default=1000,
37 | help="Maximum term pairs to process during training/evaluation")
38 | arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations")
39 | arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding")
40 | arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT")
41 | arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights")
42 | arg_parser.add_argument('--no_overlapping', action='store_true', default=False,
43 | help="If true, do not evaluate on overlapping terms "
44 | "and relations with overlapping terms")
45 |
46 | # LaGCN
47 | arg_parser.add_argument('--bert_dim', type=int, default=768)
48 | arg_parser.add_argument('--dep_dim', type=int, default=100)
49 | arg_parser.add_argument('--bert_dropout', type=float, default=0.5)
50 | arg_parser.add_argument('--output_dropout', type=float, default=0.5)
51 | arg_parser.add_argument('--dep_num', type=int, default=42) # 40 + 1(None) + 1(self-loop)
52 | arg_parser.add_argument('--num_layer', type=int, default=3)
53 |
54 | # Misc
55 | arg_parser.add_argument('--seed', type=int, default=58986, help="Seed")
56 | arg_parser.add_argument('--cache_path', type=str, default=None,
57 | help="Path to cache transformer models (for HuggingFace transformers library)")
58 |
59 | arg_parser.add_argument('--beta', type=float, default=0.3, help='weight for triaffine')
60 | arg_parser.add_argument('--alpha', type=float, default=0.3, help='weight for biaffine')
61 | arg_parser.add_argument('--sigma', type=float, default=0.3, help='weight for syntactic-aware score')
62 |
63 | # pos
64 | # arg_parser.add_argument('--use_pos', type=bool, default=True)
65 | arg_parser.add_argument('--pos_num', type=int, default=45)
66 | arg_parser.add_argument('--pos_dim', type=int, default=100)
67 | arg_parser.add_argument('--w_size', type=int, default=1, help='the window size of local attention')
68 |
69 |
70 | def train_argparser():
71 | arg_parser = argparse.ArgumentParser()
72 |
73 | # Input
74 | arg_parser.add_argument('--train_path', type=str, help="Path to train dataset")
75 | arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset")
76 |
77 | # Logging
78 | arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored")
79 | arg_parser.add_argument('--init_eval', action='store_true', default=False,
80 | help="If true, evaluate validation set before training")
81 | arg_parser.add_argument('--save_optimizer', action='store_true', default=False,
82 | help="Save optimizer alongside model")
83 | arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations")
84 | arg_parser.add_argument('--final_eval', action='store_true', default=False,
85 | help="Evaluate the model only after training, not at every epoch")
86 |
87 | # Model / Training
88 | arg_parser.add_argument('--train_batch_size', type=int, default=2, help="Training batch size")
89 | arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs")
90 | arg_parser.add_argument('--neg_term_count', type=int, default=100,
91 | help="Number of negative term samples per document (sentence)")
92 | arg_parser.add_argument('--neg_relation_count', type=int, default=100,
93 | help="Number of negative relation samples per document (sentence)")
94 | arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate")
95 | arg_parser.add_argument('--lr_warmup', type=float, default=0.1,
96 | help="Proportion of total train iterations to warmup in linear increase/decrease schedule")
97 | arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply")
98 | arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm")
99 |
100 | # arg_parser.add_argument('--beta', type=float, default=0.3, help='weight for triaffine')
101 |
102 | _add_common_args(arg_parser)
103 |
104 | return arg_parser
105 |
106 |
107 | def eval_argparser():
108 | arg_parser = argparse.ArgumentParser()
109 |
110 | # Input
111 | arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset")
112 |
113 | _add_common_args(arg_parser)
114 |
115 | return arg_parser
116 |
--------------------------------------------------------------------------------
/SynFue/util.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 | import random
5 | import shutil
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from SynFue.terms import TokenSpan
11 |
12 | CSV_DELIMETER = ';'
13 |
14 |
15 | def create_directories_file(f):
16 | d = os.path.dirname(f)
17 |
18 | if d and not os.path.exists(d):
19 | os.makedirs(d)
20 |
21 | return f
22 |
23 |
24 | def create_directories_dir(d):
25 | if d and not os.path.exists(d):
26 | os.makedirs(d)
27 |
28 | return d
29 |
30 |
31 | def create_csv(file_path, *column_names):
32 | if not os.path.exists(file_path):
33 | with open(file_path, 'w', newline='') as csv_file:
34 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
35 |
36 | if column_names:
37 | writer.writerow(column_names)
38 |
39 |
40 | def append_csv(file_path, *row):
41 | if not os.path.exists(file_path):
42 | raise Exception("The given file doesn't exist")
43 |
44 | with open(file_path, 'a', newline='') as csv_file:
45 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
46 | writer.writerow(row)
47 |
48 |
49 | def append_csv_multiple(file_path, *rows):
50 | if not os.path.exists(file_path):
51 | raise Exception("The given file doesn't exist")
52 |
53 | with open(file_path, 'a', newline='') as csv_file:
54 | writer = csv.writer(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
55 | for row in rows:
56 | writer.writerow(row)
57 |
58 |
59 | def read_csv(file_path):
60 | lines = []
61 | with open(file_path, 'r') as csv_file:
62 | reader = csv.reader(csv_file, delimiter=CSV_DELIMETER, quotechar='|', quoting=csv.QUOTE_MINIMAL)
63 | for row in reader:
64 | lines.append(row)
65 |
66 | return lines[0], lines[1:]
67 |
68 |
69 | def copy_python_directory(source, dest, ignore_dirs=None):
70 | source = source if source.endswith('/') else source + '/'
71 | for (dir_path, dir_names, file_names) in os.walk(source):
72 | tail = '/'.join(dir_path.split(source)[1:])
73 | new_dir = os.path.join(dest, tail)
74 |
75 | if ignore_dirs and True in [(ignore_dir in tail) for ignore_dir in ignore_dirs]:
76 | continue
77 |
78 | create_directories_dir(new_dir)
79 |
80 | for file_name in file_names:
81 | if file_name.endswith('.py'):
82 | file_path = os.path.join(dir_path, file_name)
83 | shutil.copy2(file_path, new_dir)
84 |
85 |
86 | def save_dict(log_path, dic, name):
87 | # save arguments
88 | # 1. as json
89 | path = os.path.join(log_path, '%s.json' % name)
90 | f = open(path, 'w')
91 | json.dump(vars(dic), f)
92 | f.close()
93 |
94 | # 2. as string
95 | path = os.path.join(log_path, '%s.txt' % name)
96 | f = open(path, 'w')
97 | args_str = ["%s = %s" % (key, value) for key, value in vars(dic).items()]
98 | f.write('\n'.join(args_str))
99 | f.close()
100 |
101 |
102 | def summarize_dict(summary_writer, dic, name):
103 | table = 'Argument|Value\n-|-'
104 |
105 | for k, v in vars(dic).items():
106 | row = '\n%s|%s' % (k, v)
107 | table += row
108 | summary_writer.add_text(name, table)
109 |
110 |
111 | def set_seed(seed):
112 | random.seed(seed)
113 | np.random.seed(seed)
114 | torch.manual_seed(seed)
115 | torch.cuda.manual_seed_all(seed)
116 |
117 |
118 | def reset_logger(logger):
119 | for handler in logger.handlers[:]:
120 | logger.removeHandler(handler)
121 |
122 | for f in logger.filters[:]:
123 |         logger.removeFilter(f)
124 |
125 |
126 | def flatten(l):
127 | return [i for p in l for i in p]
128 |
129 |
130 | def get_as_list(dic, key):
131 | if key in dic:
132 | return [dic[key]]
133 | else:
134 | return []
135 |
136 |
137 | def extend_tensor(tensor, extended_shape, fill=0):
138 | tensor_shape = tensor.shape
139 |
140 | extended_tensor = torch.zeros(extended_shape, dtype=tensor.dtype).to(tensor.device)
141 | extended_tensor = extended_tensor.fill_(fill)
142 |
143 | if len(tensor_shape) == 1:
144 | extended_tensor[:tensor_shape[0]] = tensor
145 | elif len(tensor_shape) == 2:
146 | extended_tensor[:tensor_shape[0], :tensor_shape[1]] = tensor
147 | elif len(tensor_shape) == 3:
148 | extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2]] = tensor
149 | elif len(tensor_shape) == 4:
150 | extended_tensor[:tensor_shape[0], :tensor_shape[1], :tensor_shape[2], :tensor_shape[3]] = tensor
151 |
152 | return extended_tensor
153 |
154 |
155 | def padded_stack(tensors, padding=0):
156 | dim_count = len(tensors[0].shape)
157 |
158 | max_shape = [max([t.shape[d] for t in tensors]) for d in range(dim_count)]
159 | padded_tensors = []
160 |
161 | for t in tensors:
162 | e = extend_tensor(t, max_shape, fill=padding)
163 | padded_tensors.append(e)
164 |
165 | stacked = torch.stack(padded_tensors)
166 | return stacked
167 |
168 |
169 | def batch_index(tensor, index, pad=False):
170 | if tensor.shape[0] != index.shape[0]:
171 | raise Exception()
172 |
173 | if not pad:
174 | return torch.stack([tensor[i][index[i]] for i in range(index.shape[0])])
175 | else:
176 | return padded_stack([tensor[i][index[i]] for i in range(index.shape[0])])
177 |
178 |
179 | def padded_nonzero(tensor, padding=0):
180 | indices = padded_stack([tensor[i].nonzero().view(-1) for i in range(tensor.shape[0])], padding)
181 | return indices
182 |
183 |
184 | def swap(v1, v2):
185 | return v2, v1
186 |
187 |
188 | def get_span_tokens(tokens, span):
189 | inside = False
190 | span_tokens = []
191 |
192 | for t in tokens:
193 | if t.span[0] == span[0]:
194 | inside = True
195 |
196 | if inside:
197 | span_tokens.append(t)
198 |
199 | if inside and t.span[1] == span[1]:
200 | return TokenSpan(span_tokens)
201 |
202 | return None
203 |
204 |
205 | def to_device(batch, device):
206 | converted_batch = dict()
207 | for key in batch.keys():
208 | converted_batch[key] = batch[key].to(device)
209 |
210 | return converted_batch
211 |
212 |
213 | def check_version(config, model_class, model_path):
214 | if os.path.exists(model_path):
215 | model_path = model_path if model_path.endswith('.bin') else os.path.join(model_path, 'pytorch_model.bin')
216 | state_dict = torch.load(model_path, map_location=torch.device('cpu'))
217 | config_dict = config.to_dict()
218 |
219 | # version check
220 | loaded_version = config_dict.get('spert_version', '1.2')
221 | if 'rel_classifier.weight' in state_dict and loaded_version != model_class.VERSION:
222 | msg = ("Current SpERT version (%s) does not match the version of the loaded model (%s). "
223 | % (model_class.VERSION, loaded_version))
224 | msg += "Use the code matching your version or train a new model."
225 | raise Exception(msg)
226 |
227 |
228 |
--------------------------------------------------------------------------------
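A minimal usage sketch of the padding helpers above (illustrative only; the tensors and shapes are invented):

import torch

from SynFue import util

a = torch.tensor([1, 2])           # shape (2,)
b = torch.tensor([3, 4, 5, 6])     # shape (4,)

# padded_stack pads every tensor to the per-dimension maximum before stacking
stacked = util.padded_stack([a, b], padding=0)
print(stacked.shape)   # torch.Size([2, 4])
print(stacked[0])      # tensor([1, 2, 0, 0])

# batch_index gathers per-sample rows by index
batch = torch.arange(12).view(2, 3, 2)       # (batch=2, rows=3, dim=2)
index = torch.tensor([[0, 2], [1, 1]])       # two row indices per sample
print(util.batch_index(batch, index).shape)  # torch.Size([2, 2, 2])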
/SynFue/templates/term_examples.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Entity Extraction Examples
6 |
7 |
9 |
10 |
11 |
12 |
104 |
105 |
106 |
107 | Entity Extraction Examples ({{ examples|length }})
108 |
109 |
110 |
111 | check_circle_outline F1 = 100.000
112 | check_circle_outline F1 >= 50.00
113 | highlight_off F1 < 50.00
114 |
115 |
116 |
117 |
118 | True Positive
119 |
False Positive
120 |
False Negative
121 |
122 |
123 |
124 |
125 |
126 |
127 | {% for example in examples %}
128 | {% set outer_loop = loop %}
129 |
130 |
131 |
133 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 | | Score |
174 | Text |
175 |
176 |
177 |
178 |
179 |
180 | {% for tp in example["tp"] %}
181 |
182 | |
183 | {{ "%.4f"|format(tp[2]) }}
184 | |
185 | {{ tp[0] | safe }} |
186 |
187 | {% endfor %}
188 |
189 | {% for fp in example["fp"] %}
190 |
191 | |
192 | {{ "%.4f"|format(fp[2]) }}
193 | |
194 | {{ fp[0] | safe }} |
195 |
196 | {% endfor %}
197 |
198 | {% for fn in example["fn"] %}
199 |
200 | |
201 | {{ fn[0] | safe }} |
202 |
203 | {% endfor %}
204 |
205 |
206 | |
207 | |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 | {% endfor %}
216 |
217 |
218 |
219 |
222 |
225 |
228 |
229 |
--------------------------------------------------------------------------------
/data/datasets/vocab.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import Counter
3 |
4 |
5 | class Vocab(object):
6 | PAD = '*PAD*'
7 | UNK = '*UNK*'
8 | NULL = '*NULL*'
9 | START = '*START*'
10 | END = '*END*'
11 | ROOT = '*ROOT*'
12 | SELF_LOOP = 'SELF_LOOP'
13 |
14 | def __init__(self):
15 |
16 | self.tok2idx = {}
17 | self.idx2count = {}
18 | self.idx2tok = {}
19 | self.special_token_size = 0
20 | self.singleton_size = 0
21 | self.singleton_max_count = 0
22 | self.min_count = 0
23 | self.max_count = 0
24 |
25 | def __len__(self):
26 | return len(self.tok2idx)
27 |
28 | def items(self):
29 | for k, v in self.tok2idx.items():
30 | yield k, v
31 |
32 | def __getitem__(self, item):
33 | return self.tok2idx[item]
34 |
35 | def keys(self):
36 | return self.tok2idx.keys()
37 |
38 | def vals(self):
39 | return self.tok2idx.values()
40 |
41 | def add_counter(self,
42 | counter,
43 | min_count=1,
44 | max_count=1e7,
45 | singleton_max_count=1,
46 | update_count=False
47 | ):
48 | '''
49 | Add the tokens of a Counter to the vocabulary.
50 | :param counter: collections.Counter mapping tokens to their frequencies
51 | :param min_count: only add tokens that occur at least this many times
52 | :param max_count: take at most this many tokens from the counter
53 | :param singleton_max_count: int, we treat a token as a singleton
54 | when 0 < token count <= singleton_max_count;
55 | this supports UNK-replacement strategies
56 | :param update_count: if True, accumulate counts of tokens already in the vocabulary
57 | '''
58 | self.min_count = min_count
59 | self.max_count = max_count
60 | self.singleton_max_count = singleton_max_count
61 |
62 | for tok, count in counter.most_common(n=int(max_count)):
63 | if count >= self.min_count:
64 | self.add_token(tok, count, update_count)
65 |
66 | def add_spec_toks(self,
67 | pad_tok=True,
68 | unk_tok=True,
69 | start_tok=False,
70 | end_tok=False,
71 | root_tok=False,
72 | null_tok=False,
73 | self_loop_tok=False):
74 | if pad_tok:
75 | self.add_token(Vocab.PAD)
76 | self.special_token_size += 1
77 |
78 | if unk_tok:
79 | self.add_token(Vocab.UNK)
80 | self.special_token_size += 1
81 |
82 | if start_tok:
83 | self.add_token(Vocab.START)
84 | self.special_token_size += 1
85 |
86 | if end_tok:
87 | self.add_token(Vocab.END)
88 | self.special_token_size += 1
89 |
90 | if root_tok:
91 | self.add_token(Vocab.ROOT)
92 | self.special_token_size += 1
93 |
94 | if null_tok:
95 | self.add_token(Vocab.NULL)
96 | self.special_token_size += 1
97 |
98 | if self_loop_tok:
99 | self.add_token(Vocab.SELF_LOOP)
100 | self.special_token_size += 1
101 |
102 | def add_token(self, token, count=1, update_count=False):
103 | idx = self.tok2idx.get(token, None)
104 | if idx is None:
105 | idx = len(self.tok2idx)
106 | self.tok2idx[token] = idx
107 | self.idx2count[idx] = count
108 | if count <= self.singleton_max_count:
109 | self.singleton_size += 1
110 | elif update_count:
111 | new_count = self.idx2count[idx] + count
112 | self.idx2count[idx] = new_count
113 | if new_count > self.singleton_max_count:
114 | self.singleton_size -= 1
115 |
116 | return idx
117 |
118 | def get_vocab_size(self):
119 | return len(self.tok2idx)
120 |
121 | def get_vocab_size_without_spec(self):
122 | return len(self.tok2idx) - self.special_token_size
123 |
124 | def get_index(self, token, default_value='*UNK*'):
125 | idx = self.tok2idx.get(token, None)
126 | if idx is None:
127 | if default_value:
128 | try:
129 | return self.tok2idx[default_value]
130 | except KeyError:
131 | print('Token %s not found' % token)
132 | exit(1)
133 | else:
134 | raise RuntimeError('Token %s not found' % token)
135 | else:
136 | return idx
137 |
138 | def __iter__(self):
139 | for tok, idx in self.tok2idx.items():
140 | yield idx, tok
141 |
142 | def get_token(self, index):
143 | if len(self.idx2tok) == 0:
144 | for tok, idx in self.tok2idx.items():
145 | self.idx2tok[idx] = tok
146 |
147 | return self.idx2tok[index]
148 |
149 | def get_token_set(self):
150 | return self.tok2idx.keys()
151 |
152 | def recount_singleton_size(self):
153 | singleton_size = 0
154 | for count in self.idx2count.values():
155 | if count <= self.singleton_max_count:
156 | singleton_size += 1
157 | self.singleton_size = singleton_size
158 |
159 | def get_singleton_size(self, re_count=False):
160 | if re_count:
161 | self.recount_singleton_size()
162 | return self.singleton_size
163 |
164 | def is_singleton(self, token_or_index):
165 | if isinstance(token_or_index, str):
166 | idx = self.tok2idx.get(token_or_index, None)
167 | if idx is None:
168 | # we treat OOV as singletons
169 | return True
170 | elif isinstance(token_or_index, int):
171 | idx = token_or_index
172 | else:
173 | raise TypeError('Unknown type %s' % (type(token_or_index)))
174 |
175 | count = self.idx2count[idx]
176 | return count <= self.singleton_max_count
177 |
178 | def __contains__(self, token):
179 | return token in self.tok2idx
180 |
181 | def __str__(self):
182 | spec_tok_size_str = 'special_token_size\t' + str(self.special_token_size) + '\n'
183 | tok_size_str = 'token_size\t' + str(self.get_vocab_size_without_spec()) + '\n'
184 | singleton_size_str = 'singleton_size\t' + str(self.singleton_size) + '\n'
185 | singleton_max_count_str = 'singleton_max_count\t' + str(self.singleton_max_count) + '\n'
186 | return spec_tok_size_str + tok_size_str + singleton_size_str + singleton_max_count_str
187 |
188 | def save(self, file_path, format='text'):
189 | '''
190 | Save the vocabulary to file_path.
191 | :param file_path: output file path
192 | :param format: 'pickle' or 'text'
193 | '''
194 |
195 | if format == 'pickle':
196 | with open(file_path, 'wb') as file:
197 | pickle.dump(self, file)
198 |
199 | elif format == 'text':
200 | with open(file_path, 'w', encoding='utf-8') as file:
201 | file.write(str(self))
202 | # write tokens by index increase order
203 | tok2idx_list = sorted(list(self.tok2idx.items()), key=lambda x: x[1])
204 | for tok, idx in tok2idx_list:
205 | count = self.idx2count[idx]
206 | file.write(str(idx) + '\t' + tok + '\t' + str(count) + '\n')
207 | else:
208 | raise RuntimeError('Unknown save format')
209 |
210 | @staticmethod
211 | def load(file_path, format='text'):
212 |
213 | if format == 'pickle':
214 | with open(file_path, 'rb') as file:
215 | return pickle.load(file)
216 |
217 | elif format == 'text':
218 | with open(file_path, 'r', encoding='utf-8') as file:
219 | vocab = Vocab()
220 | lines = list(file.readlines())
221 | spec_tok_size = int(lines[0].split('\t')[1])
222 | tok_size = int(lines[1].split('\t')[1])
223 | singleton_size = int(lines[2].split('\t')[1])
224 | singleton_max_count = int(lines[3].split('\t')[1])
225 |
226 | vocab.special_token_size = spec_tok_size
227 | vocab.singleton_size = singleton_size
228 | vocab.singleton_max_count = singleton_max_count
229 |
230 | offset = 4 # skip head information
231 | for i in range(offset, offset + tok_size + spec_tok_size):
232 | line_arr = lines[i].strip().split('\t')
233 | idx, tok, count = int(line_arr[0]), line_arr[1], int(line_arr[2])
234 | vocab.tok2idx[tok] = idx
235 | vocab.idx2count[idx] = count
236 | return vocab
237 |
238 | else:
239 | raise RuntimeError('Unknown load format')
240 |
--------------------------------------------------------------------------------
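A short, hedged sketch of how Vocab is typically built and round-tripped in the text format (the tokens are invented, and the plain import assumes data/datasets is on the Python path):

from collections import Counter

from vocab import Vocab

counter = Counter(['the', 'food', 'food', 'was', 'great'])

vocab = Vocab()
vocab.add_spec_toks(pad_tok=True, unk_tok=True)   # adds *PAD*, *UNK*
vocab.add_counter(counter, min_count=1)

print(len(vocab))                # 6 (2 special tokens + 4 distinct words)
print(vocab.get_index('food'))   # most frequent non-special token -> index 2
print(vocab.get_index('pasta'))  # unknown token falls back to the *UNK* index

vocab.save('token_vocab.txt', format='text')
restored = Vocab.load('token_vocab.txt', format='text')
assert restored.get_index('food') == vocab.get_index('food')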
/SynFue/input_reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | from abc import abstractmethod, ABC
3 | from collections import OrderedDict
4 | from logging import Logger
5 | from typing import Iterable, List
6 |
7 | from tqdm import tqdm
8 | from transformers import BertTokenizer
9 |
10 | from SynFue import util
11 | from SynFue.terms import Dataset, TermType, RelationType, Term, Relation, Document
12 |
13 |
14 | class BaseInputReader(ABC):
15 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None,
16 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None):
17 | types = json.load(open(types_path), object_pairs_hook=OrderedDict) # term + relation types
18 |
19 | self._term_types = OrderedDict()
20 | self._idx2term_type = OrderedDict()
21 | self._relation_types = OrderedDict()
22 | self._idx2relation_type = OrderedDict()
23 |
24 | # terms
25 | # add 'None' term type
26 | none_term_type = TermType('None', 0, 'None', 'No Term')
27 | self._term_types['None'] = none_term_type
28 | self._idx2term_type[0] = none_term_type
29 |
30 | # specified term types
31 | for i, (key, v) in enumerate(types['terms'].items()):
32 | term_type = TermType(key, i + 1, v['short'], v['verbose'])
33 | self._term_types[key] = term_type
34 | self._idx2term_type[i + 1] = term_type
35 |
36 | # relations
37 | # add 'None' relation type
38 | none_relation_type = RelationType('None', 0, 'None', 'No Relation')
39 | self._relation_types['None'] = none_relation_type
40 | self._idx2relation_type[0] = none_relation_type
41 |
42 | # specified relation types
43 | for i, (key, v) in enumerate(types['relations'].items()):
44 | relation_type = RelationType(key, i + 1, v['short'], v['verbose'], v['symmetric'])
45 | self._relation_types[key] = relation_type
46 | self._idx2relation_type[i + 1] = relation_type
47 |
48 | self._neg_term_count = neg_term_count
49 | self._neg_rel_count = neg_rel_count
50 | self._max_span_size = max_span_size
51 |
52 | self._datasets = dict()
53 |
54 | self._tokenizer = tokenizer
55 | self._logger = logger
56 |
57 | self._vocabulary_size = tokenizer.vocab_size
58 | self._context_size = -1
59 |
60 | @abstractmethod
61 | def read(self, datasets):
62 | pass
63 |
64 | def get_dataset(self, label) -> Dataset:
65 | return self._datasets[label]
66 |
67 | def get_term_type(self, idx) -> TermType:
68 | term = self._idx2term_type[idx]
69 | return term
70 |
71 | def get_relation_type(self, idx) -> RelationType:
72 | relation = self._idx2relation_type[idx]
73 | return relation
74 |
75 | def _calc_context_size(self, datasets: Iterable[Dataset]):
76 | sizes = []
77 |
78 | for dataset in datasets:
79 | for doc in dataset.documents:
80 | sizes.append(len(doc.encoding))
81 |
82 | context_size = max(sizes)
83 | return context_size
84 |
85 | def _log(self, text):
86 | if self._logger is not None:
87 | self._logger.info(text)
88 |
89 | @property
90 | def datasets(self):
91 | return self._datasets
92 |
93 | @property
94 | def term_types(self):
95 | return self._term_types
96 |
97 | @property
98 | def relation_types(self):
99 | return self._relation_types
100 |
101 | @property
102 | def relation_type_count(self):
103 | return len(self._relation_types)
104 |
105 | @property
106 | def term_type_count(self):
107 | return len(self._term_types)
108 |
109 | @property
110 | def vocabulary_size(self):
111 | return self._vocabulary_size
112 |
113 | @property
114 | def context_size(self):
115 | return self._context_size
116 |
117 | def __str__(self):
118 | string = ""
119 | for dataset in self._datasets.values():
120 | string += "Dataset: %s\n" % dataset.label
121 | string += str(dataset)
122 |
123 | return string
124 |
125 | def __repr__(self):
126 | return self.__str__()
127 |
128 |
129 | class JsonInputReader(BaseInputReader):
130 | def __init__(self, types_path: str, tokenizer: BertTokenizer, neg_term_count: int = None,
131 | neg_rel_count: int = None, max_span_size: int = None, logger: Logger = None):
132 | super().__init__(types_path, tokenizer, neg_term_count, neg_rel_count, max_span_size, logger)
133 |
134 | def read(self, dataset_paths):
135 | for dataset_label, dataset_path in dataset_paths.items():
136 | dataset = Dataset(dataset_label, self._relation_types, self._term_types, self._neg_term_count,
137 | self._neg_rel_count, self._max_span_size)
138 | self._parse_dataset(dataset_path, dataset)
139 | self._datasets[dataset_label] = dataset
140 |
141 | self._context_size = self._calc_context_size(self._datasets.values())
142 |
143 | def _parse_dataset(self, dataset_path, dataset):
144 | documents = json.load(open(dataset_path))
145 | for document in tqdm(documents, desc="Parse dataset '%s'" % dataset.label):
146 | self._parse_document(document, dataset)
147 |
148 | def _parse_document(self, doc, dataset) -> Document:
149 | jtokens = doc['tokens']
150 | jrelations = doc['relations']
151 | jterms = doc['entities']
152 | jdep_label = doc['dep_label']
153 | jdep_label_indices = doc['dep_label_indices']
154 | jdep = doc['dep']
155 | jpos = doc['pos']
156 | jpos_indices = doc['pos_indices']
157 |
158 | # parse tokens
159 | doc_tokens, doc_encoding = self._parse_tokens(jtokens, dataset)
160 |
161 | # parse term mentions
162 | terms = self._parse_terms(jterms, doc_tokens, dataset)
163 |
164 | # parse relations
165 | relations = self._parse_relations(jrelations, terms, dataset)
166 |
167 | # create document
168 | document = dataset.create_document(doc_tokens, terms, relations, doc_encoding, jdep_label,
169 | jdep_label_indices, jdep, jpos, jpos_indices)
170 |
171 | return document
172 |
173 | def _parse_tokens(self, jtokens, dataset):
174 | doc_tokens = []
175 |
176 | # full document encoding including special tokens ([CLS] and [SEP]) and byte-pair encodings of original tokens
177 | doc_encoding = [self._tokenizer.convert_tokens_to_ids('[CLS]')]
178 |
179 | # parse tokens
180 | for i, token_phrase in enumerate(jtokens):
181 | token_encoding = self._tokenizer.encode(token_phrase, add_special_tokens=False)
182 | span_start, span_end = i, i + 1
183 | sub_token_start, sub_token_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding))
184 |
185 | token = dataset.create_token(i, span_start, span_end, token_phrase, sub_token_start, sub_token_end)
186 |
187 | doc_tokens.append(token)
188 | doc_encoding += token_encoding
189 |
190 | doc_encoding += [self._tokenizer.convert_tokens_to_ids('[SEP]')]
191 |
192 | return doc_tokens, doc_encoding
193 |
194 | def _parse_terms(self, jterms, doc_tokens, dataset) -> List[Term]:
195 | terms = []
196 |
197 | for term_idx, jterm in enumerate(jterms):
198 | term_type = self._term_types[jterm['type']]
199 | start, end = jterm['start'], jterm['end']
200 |
201 | # create term mention
202 | tokens = doc_tokens[start:end+1]
203 | phrase = " ".join([t.phrase for t in tokens])
204 | term = dataset.create_term(term_type, tokens, phrase)
205 | terms.append(term)
206 |
207 | return terms
208 |
209 | def _parse_relations(self, jrelations, terms, dataset) -> List[Relation]:
210 | relations = []
211 |
212 | for jrelation in jrelations:
213 | relation_type = self._relation_types[jrelation['type']]
214 |
215 | head_idx = jrelation['head']
216 | tail_idx = jrelation['tail']
217 |
218 | # create relation
219 | head = terms[head_idx]
220 | tail = terms[tail_idx]
221 |
222 | reverse = int(tail.tokens[0].index) < int(head.tokens[0].index)
223 |
224 | # for symmetric relations: head occurs before tail in sentence
225 | if relation_type.symmetric and reverse:
226 | head, tail = util.swap(head, tail)
227 |
228 | relation = dataset.create_relation(relation_type, head_term=head, tail_term=tail, reverse=reverse)
229 | relations.append(relation)
230 |
231 | return relations
232 |
--------------------------------------------------------------------------------
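A minimal usage sketch of the reader above. The tokenizer checkpoint and file paths are placeholders; any BertTokenizer and a dataset JSON produced by data/datasets/preprocess.py should work:

from transformers import BertTokenizer

from SynFue.input_reader import JsonInputReader

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
reader = JsonInputReader('path/to/types.json', tokenizer,
                         neg_term_count=100, neg_rel_count=100, max_span_size=10)

# read() takes a mapping from dataset label to JSON file path
reader.read({'train': 'path/to/train.json'})

train = reader.get_dataset('train')
print(train.document_count, train.term_count, train.relation_count)
print([t.identifier for t in reader.term_types.values()])  # e.g. ['None', 'Asp', 'Opi']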
/SynFue/templates/relation_examples.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Relation Extraction Examples
6 |
7 |
9 |
10 |
11 |
12 |
112 |
113 |
114 |
115 | Relation Extraction Examples ({{ examples|length }})
116 |
117 |
118 |
119 | check_circle_outline F1 = 100.000
120 | check_circle_outline F1 >= 50.00
121 | highlight_off F1 < 50.00
122 |
123 |
124 |
125 |
126 | True Positive
127 |
False Positive
128 |
False Negative
129 |
130 |
131 |
132 |
133 |
134 |
135 | {% for example in examples %}
136 | {% set outer_loop = loop %}
137 |
138 |
139 |
141 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | | Score |
182 | Relation |
183 | Text (Head - Tail) |
184 |
185 |
186 |
187 |
188 |
189 | {% for tp in example["tp"] %}
190 |
191 | |
192 | {{ "%.4f"|format(tp[2]) }}
193 | |
194 |
195 | {{ tp[1] }} |
196 | {{ tp[0] | safe }} |
197 |
198 | {% endfor %}
199 |
200 | {% for fp in example["fp"] %}
201 |
202 | |
203 | {{ "%.4f"|format(fp[2]) }}
204 | |
205 | {{ fp[1] }} |
206 | {{ fp[0] | safe }} |
207 |
208 | {% endfor %}
209 |
210 | {% for fn in example["fn"] %}
211 |
212 | |
213 | {{ fn[1] }} |
214 | {{ fn[0] | safe }} |
215 |
216 | {% endfor %}
217 |
218 |
219 | |
220 | |
221 | |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 | {% endfor %}
230 |
231 |
232 |
233 |
236 |
239 |
242 |
243 |
--------------------------------------------------------------------------------
/SynFue/terms.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import List
3 | from torch.utils.data import Dataset as TorchDataset
4 |
5 | from SynFue import sampling
6 |
7 |
8 | class RelationType:
9 | def __init__(self, identifier, index, short_name, verbose_name, symmetric=False):
10 | self._identifier = identifier
11 | self._index = index
12 | self._short_name = short_name
13 | self._verbose_name = verbose_name
14 | self._symmetric = symmetric
15 |
16 | @property
17 | def identifier(self):
18 | return self._identifier
19 |
20 | @property
21 | def index(self):
22 | return self._index
23 |
24 | @property
25 | def short_name(self):
26 | return self._short_name
27 |
28 | @property
29 | def verbose_name(self):
30 | return self._verbose_name
31 |
32 | @property
33 | def symmetric(self):
34 | return self._symmetric
35 |
36 | def __int__(self):
37 | return self._index
38 |
39 | def __eq__(self, other):
40 | if isinstance(other, RelationType):
41 | return self._identifier == other._identifier
42 | return False
43 |
44 | def __hash__(self):
45 | return hash(self._identifier)
46 |
47 |
48 | class TermType:
49 | def __init__(self, identifier, index, short_name, verbose_name):
50 | self._identifier = identifier
51 | self._index = index
52 | self._short_name = short_name
53 | self._verbose_name = verbose_name
54 |
55 | @property
56 | def identifier(self):
57 | return self._identifier
58 |
59 | @property
60 | def index(self):
61 | return self._index
62 |
63 | @property
64 | def short_name(self):
65 | return self._short_name
66 |
67 | @property
68 | def verbose_name(self):
69 | return self._verbose_name
70 |
71 | def __int__(self):
72 | return self._index
73 |
74 | def __eq__(self, other):
75 | if isinstance(other, TermType):
76 | return self._identifier == other._identifier
77 | return False
78 |
79 | def __hash__(self):
80 | return hash(self._identifier)
81 |
82 |
83 | class Token:
84 | def __init__(self, tid: int, index: int, span_start: int, span_end: int, phrase: str,
85 | sub_token_start: int, sub_token_end: int):
86 | self._tid = tid # ID within the corresponding dataset
87 | self._index = index # original token index in document
88 |
89 | self._span_start = span_start # start of token in document (inclusive)
90 | self._span_end = span_end # end of token in document (exclusive)
91 | self._phrase = phrase
92 | self._sub_token_start = sub_token_start # start of sub word in a document after a WordPiece Tokenizer
93 | self._sub_token_end = sub_token_end # end of sub word in a document after a WordPiece Tokenizer
94 |
95 | @property
96 | def index(self):
97 | return self._index
98 |
99 | @property
100 | def span_start(self):
101 | return self._span_start
102 |
103 | @property
104 | def span_end(self):
105 | return self._span_end
106 |
107 | @property
108 | def span(self):
109 | return self._span_start, self._span_end
110 |
111 | @property
112 | def sub_token_start(self):
113 | return self._sub_token_start
114 |
115 | @property
116 | def sub_token_end(self):
117 | return self._sub_token_end
118 |
119 | @property
120 | def sub_token(self):
121 | return self._sub_token_start, self._sub_token_end
122 |
123 | @property
124 | def phrase(self):
125 | return self._phrase
126 |
127 | def __eq__(self, other):
128 | if isinstance(other, Token):
129 | return self._tid == other._tid
130 | return False
131 |
132 | def __hash__(self):
133 | return hash(self._tid)
134 |
135 | def __str__(self):
136 | return self._phrase
137 |
138 | def __repr__(self):
139 | return self._phrase
140 |
141 |
142 | class TokenSpan:
143 | def __init__(self, tokens):
144 | self._tokens = tokens
145 |
146 | @property
147 | def span_start(self):
148 | return self._tokens[0].span_start
149 |
150 | @property
151 | def span_end(self):
152 | return self._tokens[-1].span_end
153 |
154 | @property
155 | def span(self):
156 | return self.span_start, self.span_end
157 |
158 | @property
159 | def sub_token_start(self):
160 | return self._tokens[0].sub_token_start
161 |
162 | @property
163 | def sub_token_end(self):
164 | return self._tokens[-1].sub_token_end
165 |
166 | @property
167 | def sub_token(self):
168 | return self.sub_token_start, self.sub_token_end
169 |
170 | def __getitem__(self, s):
171 | if isinstance(s, slice):
172 | return TokenSpan(self._tokens[s.start:s.stop:s.step])
173 | else:
174 | try:
175 | return self._tokens[s]
176 | except Exception:
177 | print(self._tokens)
178 | print(len(self._tokens), s)
179 | raise
180 |
181 | def __iter__(self):
182 | return iter(self._tokens)
183 |
184 | def __len__(self):
185 | return len(self._tokens)
186 |
187 |
188 | class Term:
189 | def __init__(self, eid: int, term_type: TermType, tokens: List[Token], phrase: str):
190 | self._eid = eid # ID within the corresponding dataset
191 |
192 | self._term_type = term_type
193 |
194 | self._tokens = tokens
195 | self._phrase = phrase
196 |
197 | def as_tuple(self):
198 | return self.span_start, self.span_end, self._term_type
199 |
200 | @property
201 | def term_type(self):
202 | return self._term_type
203 |
204 | @property
205 | def tokens(self):
206 | return TokenSpan(self._tokens)
207 |
208 | @property
209 | def span_start(self):
210 | return self._tokens[0].span_start
211 |
212 | @property
213 | def span_end(self):
214 | return self._tokens[-1].span_end
215 |
216 | @property
217 | def span(self):
218 | return self.span_start, self.span_end
219 |
220 | @property
221 | def sub_token_start(self):
222 | return self._tokens[0].sub_token_start
223 |
224 | @property
225 | def sub_token_end(self):
226 | return self._tokens[-1].sub_token_end
227 |
228 | @property
229 | def sub_token(self):
230 | return self.sub_token_start, self.sub_token_end
231 |
232 |
233 | @property
234 | def phrase(self):
235 | return self._phrase
236 |
237 | def __eq__(self, other):
238 | if isinstance(other, Term):
239 | return self._eid == other._eid
240 | return False
241 |
242 | def __hash__(self):
243 | return hash(self._eid)
244 |
245 | def __str__(self):
246 | return self._phrase
247 |
248 |
249 | class Relation:
250 | def __init__(self, rid: int, relation_type: RelationType, head_term: Term,
251 | tail_term: Term, reverse: bool = False):
252 | self._rid = rid # ID within the corresponding dataset
253 | self._relation_type = relation_type
254 |
255 | self._head_term = head_term
256 | self._tail_term = tail_term
257 |
258 | self._reverse = reverse
259 |
260 | self._first_term = head_term if not reverse else tail_term
261 | self._second_term = tail_term if not reverse else head_term
262 |
263 | def as_tuple(self):
264 | head = self._head_term
265 | tail = self._tail_term
266 | head_start, head_end = (head.span_start, head.span_end)
267 | tail_start, tail_end = (tail.span_start, tail.span_end)
268 |
269 | t = ((head_start, head_end, head.term_type),
270 | (tail_start, tail_end, tail.term_type), self._relation_type)
271 | return t
272 |
273 | @property
274 | def relation_type(self):
275 | return self._relation_type
276 |
277 | @property
278 | def head_term(self):
279 | return self._head_term
280 |
281 | @property
282 | def tail_term(self):
283 | return self._tail_term
284 |
285 | @property
286 | def first_term(self):
287 | return self._first_term
288 |
289 | @property
290 | def second_term(self):
291 | return self._second_term
292 |
293 | @property
294 | def reverse(self):
295 | return self._reverse
296 |
297 | def __eq__(self, other):
298 | if isinstance(other, Relation):
299 | return self._rid == other._rid
300 | return False
301 |
302 | def __hash__(self):
303 | return hash(self._rid)
304 |
305 |
306 | class Document:
307 | def __init__(self, doc_id: int, tokens: List[Token], terms: List[Term], relations: List[Relation],
308 | encoding: List[int], dep_label: List[int], dep_label_indices: List[int], dep: List[int],
309 | pos: List[str], pos_indices: List[int]):
310 | self._doc_id = doc_id # ID within the corresponding dataset
311 |
312 | self._tokens = tokens
313 | self._terms = terms
314 | self._relations = relations
315 |
316 | # byte-pair document encoding including special tokens ([CLS] and [SEP])
317 | self._encoding = encoding
318 |
319 | self._dep_label = dep_label
320 | self._dep_label_indices = dep_label_indices
321 | self._dep = dep
322 |
323 | self._pos = pos
324 | self._pos_indices = pos_indices
325 |
326 | @property
327 | def doc_id(self):
328 | return self._doc_id
329 |
330 | @property
331 | def terms(self):
332 | return self._terms
333 |
334 | @property
335 | def relations(self):
336 | return self._relations
337 |
338 | @property
339 | def tokens(self):
340 | return TokenSpan(self._tokens)
341 |
342 | @property
343 | def encoding(self):
344 | return self._encoding
345 |
346 | @property
347 | def dep_label(self):
348 | return self._dep_label
349 |
350 | @property
351 | def dep_label_indices(self):
352 | return self._dep_label_indices
353 |
354 | @property
355 | def dep(self):
356 | return self._dep
357 |
358 | @property
359 | def pos_indices(self):
360 | return self._pos_indices
361 |
362 | @property
363 | def pos(self):
364 | return self._pos
365 |
366 | @encoding.setter
367 | def encoding(self, value):
368 | self._encoding = value
369 |
370 | def __eq__(self, other):
371 | if isinstance(other, Document):
372 | return self._doc_id == other._doc_id
373 | return False
374 |
375 | def __hash__(self):
376 | return hash(self._doc_id)
377 |
378 |
379 | class BatchIterator:
380 | def __init__(self, terms, batch_size, order=None, truncate=False):
381 | self._terms = terms
382 | self._batch_size = batch_size
383 | self._truncate = truncate
384 | self._length = len(self._terms)
385 | self._order = order
386 |
387 | if order is None:
388 | self._order = list(range(len(self._terms)))
389 |
390 | self._i = 0
391 |
392 | def __iter__(self):
393 | return self
394 |
395 | def __next__(self):
396 | if self._truncate and self._i + self._batch_size > self._length:
397 | raise StopIteration
398 | elif not self._truncate and self._i >= self._length:
399 | raise StopIteration
400 | else:
401 | terms = [self._terms[n] for n in self._order[self._i:self._i + self._batch_size]]
402 | self._i += self._batch_size
403 | return terms
404 |
405 |
406 | class Dataset(TorchDataset):
407 | TRAIN_MODE = 'train'
408 | EVAL_MODE = 'eval'
409 |
410 | def __init__(self, label, rel_types, term_types, neg_term_count,
411 | neg_rel_count, max_span_size):
412 | self._label = label
413 | self._rel_types = rel_types
414 | self._term_types = term_types
415 | self._neg_term_count = neg_term_count
416 | self._neg_rel_count = neg_rel_count
417 | self._max_span_size = max_span_size
418 | self._mode = Dataset.TRAIN_MODE
419 |
420 | self._documents = OrderedDict()
421 | self._terms = OrderedDict()
422 | self._relations = OrderedDict()
423 |
424 | # current ids
425 | self._doc_id = 0
426 | self._rid = 0
427 | self._eid = 0
428 | self._tid = 0
429 |
430 | def iterate_documents(self, batch_size, order=None, truncate=False):
431 | return BatchIterator(self.documents, batch_size, order=order, truncate=truncate)
432 |
433 | def iterate_relations(self, batch_size, order=None, truncate=False):
434 | return BatchIterator(self.relations, batch_size, order=order, truncate=truncate)
435 |
436 | def create_token(self, idx, span_start, span_end, phrase, sub_token_start, sub_token_end) -> Token:
437 | token = Token(self._tid, idx, span_start, span_end, phrase, sub_token_start, sub_token_end)
438 | self._tid += 1
439 | return token
440 |
441 | def create_document(self, tokens, term_mentions, relations, doc_encoding, dep_label, dep_label_indices, dep,
442 | pos, pos_indices) -> Document:
443 | document = Document(self._doc_id, tokens, term_mentions, relations, doc_encoding, dep_label,
444 | dep_label_indices, dep, pos, pos_indices)
445 | self._documents[self._doc_id] = document
446 | self._doc_id += 1
447 |
448 | return document
449 |
450 | def create_term(self, term_type, tokens, phrase) -> Term:
451 | mention = Term(self._eid, term_type, tokens, phrase)
452 | self._terms[self._eid] = mention
453 | self._eid += 1
454 | return mention
455 |
456 | def create_relation(self, relation_type, head_term, tail_term, reverse=False) -> Relation:
457 | relation = Relation(self._rid, relation_type, head_term, tail_term, reverse)
458 | self._relations[self._rid] = relation
459 | self._rid += 1
460 | return relation
461 |
462 | def __len__(self):
463 | return len(self._documents)
464 |
465 | def __getitem__(self, index: int):
466 | doc = self._documents[index]
467 |
468 | if self._mode == Dataset.TRAIN_MODE:
469 | return sampling.create_train_sample(doc, self._neg_term_count, self._neg_rel_count,
470 | self._max_span_size, len(self._rel_types))
471 | else:
472 | return sampling.create_eval_sample(doc, self._max_span_size)
473 |
474 | def switch_mode(self, mode):
475 | self._mode = mode
476 |
477 | @property
478 | def label(self):
479 | return self._label
480 |
481 | @property
482 | def input_reader(self):
483 | return self._input_reader
484 |
485 | @property
486 | def documents(self):
487 | return list(self._documents.values())
488 |
489 | @property
490 | def terms(self):
491 | return list(self._terms.values())
492 |
493 | @property
494 | def relations(self):
495 | return list(self._relations.values())
496 |
497 | @property
498 | def document_count(self):
499 | return len(self._documents)
500 |
501 | @property
502 | def term_count(self):
503 | return len(self._terms)
504 |
505 | @property
506 | def relation_count(self):
507 | return len(self._relations)
508 |
--------------------------------------------------------------------------------
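An illustrative sketch of the data model above (the type objects, token offsets and phrases are invented for the example):

from SynFue.terms import Dataset, TermType, RelationType

asp = TermType('Asp', 1, 'Asp', 'Aspect')
opi = TermType('Opi', 2, 'Opi', 'Opinion')
pair = RelationType('Pair', 1, 'Pair', 'Pair', symmetric=True)

ds = Dataset('demo', {'Pair': pair}, {'Asp': asp, 'Opi': opi},
             neg_term_count=10, neg_rel_count=10, max_span_size=10)

# "great food": token index, span (start, end), phrase, sub-token range
t0 = ds.create_token(0, 0, 1, 'great', 1, 2)
t1 = ds.create_token(1, 1, 2, 'food', 2, 3)

opinion = ds.create_term(opi, [t0], 'great')
aspect = ds.create_term(asp, [t1], 'food')
rel = ds.create_relation(pair, head_term=aspect, tail_term=opinion, reverse=True)

print(aspect.span, opinion.span)         # (1, 2) (0, 1)
print(rel.first_term.phrase)             # 'great' -- reverse=True, so the tail comes first
print(ds.term_count, ds.relation_count)  # 2 1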
/data/datasets/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | '''
4 | Read the original data files and convert them to JSON;
5 | at the same time, preprocess the tokens (e.g. lower-casing and digit normalization)
6 | '''
7 | import os
8 | import json
9 | from collections import Counter
10 | from nltk.parse import CoreNLPDependencyParser
11 | import numpy as np
12 | import argparse
13 | from tqdm import tqdm
14 | from io_utils import read_yaml
15 | from str_utils import normalize_tok
16 | from vocab import Vocab
17 | from sklearn.model_selection import train_test_split
18 |
19 | config = read_yaml('data_config.yaml')
20 | print('seed:', config['random_seed'])
21 | normalize_digits = config['normalize_digits']
22 | lower_case = config['lower_case']
23 |
24 | depparser = CoreNLPDependencyParser(url='http://127.0.0.1:9000')
25 |
26 |
27 | def build_vocab(data_list, file_path):
28 | token_list = []
29 | char_list = []
30 | pos_list = []
31 | dep_list = []
32 | for inst in tqdm(data_list, total=len(data_list)):
33 | words = inst['words']
34 |
35 | temp_parser_res = depparser.parse(words)
36 | parser_res = []
37 | for i in temp_parser_res:
38 | temp = i.to_conll(4).strip().split('\n')
39 | for t in temp:
40 | parser_res.append(t.split('\t'))
41 | pos_list.extend([a[1] for a in parser_res])
42 | dep_list.extend([a[3] for a in parser_res])
43 |
44 | for word in words:
45 | word = normalize_tok(word, lower_case, normalize_digits)
46 | token_list.append(word)
47 | char_list.extend(list(word))
48 |
49 | token_vocab_file = os.path.join(file_path, config['token_vocab_file'])
50 | char_vocab_file = os.path.join(file_path, config['char_vocab_file'])
51 | pos_vocab_file = os.path.join(file_path, config['pos_vocab_file'])
52 | dep_type_vocab_file = os.path.join(file_path, config['dep_type_vocab_file'])
53 |
54 | print('--------token_vocab---------------')
55 | token_vocab = Vocab()
56 | token_vocab.add_spec_toks(unk_tok=True, pad_tok=False)
57 | token_vocab.add_counter(Counter(token_list))
58 | token_vocab.save(token_vocab_file)
59 | print(token_vocab)
60 |
61 | print('--------char_vocab---------------')
62 | char_vocab = Vocab()
63 | char_vocab.add_spec_toks(unk_tok=True, pad_tok=False)
64 | char_vocab.add_counter(Counter(char_list))
65 | char_vocab.save(char_vocab_file)
66 | print(char_vocab)
67 |
68 | print('--------pos_vocab---------------')
69 | pos_vocab = Vocab()
70 | pos_vocab.add_spec_toks(pad_tok=True, unk_tok=False)
71 | pos_vocab.add_counter(Counter(pos_list))
72 | pos_vocab.save(pos_vocab_file)
73 | print(pos_vocab)
74 |
75 | print('--------dep_vocab---------------')
76 | dep_vocab = Vocab()
77 | dep_vocab.add_spec_toks(pad_tok=True, unk_tok=False, self_loop_tok=True)
78 | dep_vocab.add_counter(Counter(dep_list))
79 | dep_vocab.save(dep_type_vocab_file)
80 | print(dep_vocab)
81 |
82 | return dep_vocab, pos_vocab
83 |
84 |
85 | def load_data_from_ori(file_path):
86 | pre_id = ''
87 | data_list = []
88 | data_dic = {}
89 | token = []
90 | opinion = []
91 | opinion_idx = []
92 | aspect = []
93 | aspect_idx = []
94 | pair = []
95 | pair_idx = []
96 | with open(file_path, 'r', encoding='utf-8') as f:
97 | f.readline()
98 | while True:
99 | line = f.readline()
100 | if line == '':
101 | token = token_temp
102 | aspect.append(a_temp)
103 | aspect_idx.append(a_temp_idx)
104 | opinion.append(o_temp)
105 | opinion_idx.append(o_temp_idx)
106 | pair.extend(pair_temp)
107 | pair_idx.extend(pair_idx_temp)
108 | data_dic['words'] = token
109 | data_dic['aspects'] = aspect
110 | data_dic['aspects_idx'] = aspect_idx
111 | data_dic['opinions'] = opinion
112 | data_dic['opinions_idx'] = opinion_idx
113 | data_dic['pair'] = pair
114 | data_dic['pair_idx'] = pair_idx
115 | data_list.append(data_dic)
116 | break
117 | line = line.split('\t')
118 | if pre_id == line[0]:
119 | token = token_temp
120 | aspect.append(a_temp)
121 | aspect_idx.append(a_temp_idx)
122 | opinion.append(o_temp)
123 | opinion_idx.append(o_temp_idx)
124 | pair.extend(pair_temp)
125 | pair_idx.extend(pair_idx_temp)
126 | elif pre_id != '' and pre_id != line[0]:
127 | token = token_temp
128 | aspect.append(a_temp)
129 | aspect_idx.append(a_temp_idx)
130 | opinion.append(o_temp)
131 | opinion_idx.append(o_temp_idx)
132 | pair.extend(pair_temp)
133 | pair_idx.extend(pair_idx_temp)
134 | data_dic['words'] = token
135 | data_dic['aspects'] = aspect
136 | data_dic['aspects_idx'] = aspect_idx
137 | data_dic['opinions'] = opinion
138 | data_dic['opinions_idx'] = opinion_idx
139 | data_dic['pair'] = pair
140 | data_dic['pair_idx'] = pair_idx
141 |
142 | data_list.append(data_dic)
143 | data_dic = {}
144 | token = []
145 | opinion = []
146 | opinion_idx = []
147 | aspect = []
148 | aspect_idx = []
149 | pair = []
150 | pair_idx = []
151 | try:
152 | token_temp = line[1].strip().split()
153 | except:
154 | print(line)
155 | # aspect term
156 | a_temp = []
157 | a_temp_idx = []
158 | for idx, a in enumerate(line[2].strip().split()):
159 | if a.strip().split('\\')[1] != 'O':
160 | a_temp.append(a.strip().split('\\')[0])
161 | a_temp_idx.append(idx)
162 | # opinion term
163 | o_temp = []
164 | o_temp_idx = []
165 | for idx, o in enumerate(line[3].strip().split()):
166 | if o.strip().split('\\')[1] != 'O':
167 | o_temp.append(o.strip().split('\\')[0])
168 | o_temp_idx.append(idx)
169 | pair_temp = []
170 | pair_idx_temp = []
171 | pair_temp.append((a_temp, o_temp))
172 | pair_idx_temp.append((a_temp_idx, o_temp_idx))
173 | pre_id = line[0]
174 | return data_list
175 |
176 |
177 | def sep_discontinuous_term(term_idx):
178 | """
179 | according to the source term idx, get the new idx, term rep = [start_idx, end_idx] and it is consequent
180 | :param term_idx:
181 | :return:
182 | """
183 | bert_idx = []
184 | flag = 0 # Whether to include discontinuous terms
185 | for idx in term_idx:
186 | if len(idx) < 2:
187 | if [idx[0], idx[-1]] not in bert_idx:
188 | bert_idx.append([idx[0], idx[-1]])
189 | else:
190 | if (idx[0] + len(idx) - 1) == idx[-1]:
191 | if [idx[0], idx[-1]] not in bert_idx:
192 | bert_idx.append([idx[0], idx[-1]])
193 | else:
194 | temp_flag = 0
195 | s_idx = e_idx = idx[0] # start_idx = end_idx
196 | for i in idx[1:]:
197 | if i == e_idx + 1:
198 | e_idx = i
199 | else:
200 | temp_flag = 1
201 | if [s_idx, e_idx] not in bert_idx:
202 | bert_idx.append([s_idx, e_idx])
203 | s_idx = e_idx = i
204 | if [s_idx, e_idx] not in bert_idx:
205 | bert_idx.append([s_idx, e_idx])
206 |
207 | if temp_flag == 1:
208 | flag += 1
209 | temp_flag = 0
210 |
211 | return bert_idx, flag
212 |
213 |
214 | def sep_pair_idx(pair_idx):
215 | """
216 | according to the source pair idx, get the new idx, term rep = [[start_idx, end_idx], [start_idx, end_idx]]
217 | and it is consequent
218 | :param pair_idx:
219 | :return:
220 | """
221 | bert_idx = []
222 | flag = a_flag = o_flag = 0
223 | for p in pair_idx:
224 | bert_a_idx, a_flag = sep_discontinuous_term([p[0]])
225 | bert_o_idx, o_flag = sep_discontinuous_term([p[1]])
226 | for a in bert_a_idx:
227 | for b in bert_o_idx:
228 | bert_idx.append((a, b))
229 |
230 | if a_flag != 0 or o_flag != 0:
231 | flag = 1
232 | return bert_idx, flag
233 |
234 |
235 | def construct_instance(inst_list, dep_vocab, pos_vocab):
236 | data = []
237 | idx = 0
238 | for inst in tqdm(inst_list, total=len(inst_list)):
239 | inst_dict = {}
240 | words = inst['words']
241 | words_processed = [normalize_tok(w, lower_case, normalize_digits) for w in words]
242 | temp_parser_res = depparser.parse(words_processed)
243 | parser_res = []
244 | for i in temp_parser_res:
245 | temp = i.to_conll(4).strip().split('\n')
246 | for t in temp:
247 | parser_res.append(t.split('\t'))
248 | words = [a[0] for a in parser_res]
249 | inst['words'] = words
250 | inst_dict['tokens'] = words
251 | aspects_idx = inst['aspects_idx']
252 | opinions_idx = inst['opinions_idx']
253 | pair_idx = inst['pair_idx']
254 |
255 | new_aspects_idx, _ = sep_discontinuous_term(aspects_idx)
256 | new_opinions_idx, _ = sep_discontinuous_term(opinions_idx)
257 | new_pair_idx, _ = sep_pair_idx(pair_idx)
258 |
259 | s_to_t = {}
260 | i = j = 0
261 | while i < len(inst['words']):
262 | if inst['words'][i] == words[j]:
263 | s_to_t[i] = [j]
264 | i += 1
265 | j += 1
266 | else:
267 | s_to_t[i] = []
268 | if i + 1 > len(inst['words']) - 1:
269 | s_to_t[i] = [x for x in range(j, len(words))]
270 | else:
271 | next_token = inst['words'][i + 1]
272 | while words[j] != '-RRB-' and words[j] != next_token and words[j] not in next_token and j <= len(
273 | words) - 1:
274 | s_to_t[i].append(j)
275 | j += 1
276 | i += 1
277 |
278 | def get_new_term(old_term):
279 | new_term = []
280 | for i in old_term:
281 | temp = []
282 | for j in i:
283 | temp.extend(s_to_t[j])
284 | new_term.append(temp)
285 | return new_term
286 |
287 | new_aspects = get_new_term(new_aspects_idx)
288 | new_opinions = get_new_term(new_opinions_idx)
289 | inst['aspects_idx'] = new_aspects
290 | inst['opinions_idx'] = new_opinions
291 |
292 | new_pairs = []
293 | for p in new_pair_idx:
294 | new_p_a = []
295 | for a in p[0]:
296 | new_p_a.extend(s_to_t[a])
297 |
298 | new_p_o = []
299 | for a in p[1]:
300 | new_p_o.extend(s_to_t[a])
301 | new_pairs.append((new_p_a, new_p_o))
302 | inst['pair_idx'] = new_pairs
303 |
304 | entities = []
305 | for a in new_aspects_idx:
306 | entities.append({'type': "Asp", "start": a[0], "end": a[-1]})
307 | for o in new_opinions_idx:
308 | entities.append({'type': "Opi", "start": o[0], "end": o[-1]})
309 | inst_dict['entities'] = entities.copy()
310 |
311 | relations = []
312 | for p in new_pair_idx:
313 | s1, s2 = p[0], p[-1]
314 | head = new_aspects_idx.index(s1) # the index of entities
315 | tail = new_opinions_idx.index(s2)
316 | relations.append({"type": "Pair", "head": head, "tail": tail+len(new_aspects_idx)})
317 | inst_dict['relations'] = relations
318 | inst_dict['orig_id'] = idx
319 | inst_dict['dep'] = [a[2] for a in parser_res]
320 | inst_dict['dep_label'] = [a[3] for a in parser_res]
321 | inst_dict['dep_label_indices'] = [dep_vocab.get_index(a[3], default_value='dep') for a in parser_res]
322 | inst_dict['pos'] = [a[1] for a in parser_res]
323 | inst_dict['pos_indices'] = [pos_vocab.get_index(a[1]) for a in parser_res]
324 | assert len(inst_dict['tokens']) == len(inst_dict['pos']) == len(inst_dict['dep'])
325 | # try:
326 | # assert len(inst['pos']) == len(inst['pos_indices']) == len(inst['words'])
327 | # except:
328 | # print(inst['pos'])
329 | # print(len(inst['pos']))
330 | # print(inst['words'])
331 | # print(len(inst['words']))
332 | # print(len(inst['dep_label']))
333 | # print(inst['dep_label'])
334 | # print(len(inst['dep']))
335 | # idx = inst['words'].index('1/2')
336 | # inst_dict['tokens'] = inst['words'][:idx] + inst['words'][idx+1:]
337 | # assert len(inst['tag_type']) == len(inst['tag_type_indices']) == len(inst_dict['tokens'])
338 | # print(inst_dict['tokens'])
339 | data.append(inst_dict.copy())
340 | idx += 1
341 | return data
342 |
343 |
344 | def save_data_to_json(train_list, dev_list, test_list, target_dir, dep_vocab, pos_vocab):
345 | train_path = os.path.join(target_dir, 'train.json')
346 | dev_path = os.path.join(target_dir, 'dev.json')
347 | test_path = os.path.join(target_dir, 'test.json')
348 |
349 | train_data = construct_instance(train_list, dep_vocab, pos_vocab)
350 | dev_data = construct_instance(dev_list, dep_vocab, pos_vocab)
351 | test_data = construct_instance(test_list, dep_vocab, pos_vocab)
352 |
353 | json.dump(train_data, open(train_path, 'w', encoding='utf-8'))
354 | json.dump(dev_data, open(dev_path, 'w', encoding='utf-8'))
355 | json.dump(test_data, open(test_path, 'w', encoding='utf-8'))
356 | print("preprocess data successful")
357 |
358 |
359 | def count_men_len_and_rel_dis(inst_lists):
360 | """
361 | Count the average length of mention and the distance between the head token of pairs
362 | Args:
363 | inst_lists:
364 |
365 | Returns:
366 | avg_mention_len: float
367 | avg_rel_distance: float
368 | """
369 |
370 | mention_lens = []
371 | rel_distances = []
372 | for inst in tqdm(inst_lists, total=len(inst_lists)):
373 | for p in inst["pair_idx"]:
374 | mention_lens.append(p[0][-1]-p[0][0]+1)
375 | mention_lens.append(p[1][-1]-p[1][0]+1)
376 | rel_distances.append(abs(p[0][0]-p[1][0])+1)
377 |
378 | # orl = inst['orl']
379 | # for x in orl:
380 | # if x[-1] == 'DSE':
381 | # mention_lens.append(x[3] - x[2] + 1)
382 | # elif x[-1] == 'AGENT' or x[-1] == 'TARGET':
383 | # mention_lens.append(x[3] - x[2] + 1)
384 | # rel_distances.append(abs(x[0]-x[2])+1)
385 | # else:
386 | # raise KeyError('annotation error, check {}'.format(' '.join(inst['sentences'])))
387 | avg_mention_len = sum(mention_lens) / len(mention_lens)
388 | avg_rel_distance = sum(rel_distances) / len(rel_distances)
389 | print("the average length of mentions: ", avg_mention_len)
390 | print("the average distance between the head tokens of pairs: ", avg_rel_distance)
391 | return avg_mention_len, avg_rel_distance
392 |
393 |
394 | if __name__ == '__main__':
395 |
396 | train_list = []
397 | dev_list = []
398 | test_list = []
399 | dataset = ['14lap', '14res', '15res', '16res']
400 | for d_name in dataset:
401 | print(d_name)
402 | file_path = os.path.join(config['ori_data_dir'], d_name)
403 | train = load_data_from_ori(os.path.join(file_path, 'train.tsv'))
404 | train, dev = train_test_split(train, test_size=0.2, shuffle=True)
405 | test = load_data_from_ori(os.path.join(file_path, 'test.tsv'))
406 | # count_men_len_and_rel_dis(train+dev+test)
407 | dep_vocab, pos_vocab = build_vocab(train+dev+test, os.path.join(config['new_data_dir'], d_name))
408 | save_data_to_json(train, dev, test, os.path.join(config['new_data_dir'], d_name),
409 | dep_vocab, pos_vocab)
410 |
411 |
--------------------------------------------------------------------------------
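A worked example of the span-splitting helpers above. Note that importing preprocess runs its module-level config loading, so this sketch assumes it is executed from data/datasets/ with a local data_config.yaml; the index lists are invented:

from preprocess import sep_discontinuous_term, sep_pair_idx

spans, flag = sep_discontinuous_term([[3, 4, 5], [7, 9, 10]])
print(spans)  # [[3, 5], [7, 7], [9, 10]] -- the gap in [7, 9, 10] splits it in two
print(flag)   # 1 -- one discontinuous term was encountered

pairs, flag = sep_pair_idx([([3, 4], [7, 9, 10])])
print(pairs)  # [([3, 4], [7, 7]), ([3, 4], [9, 10])] -- one pair per contiguous piece
print(flag)   # 1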
/SynFue/sampling.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 |
5 | from SynFue import util
6 |
7 |
8 | def create_train_sample(doc, neg_term_count: int, neg_rel_count: int, max_span_size: int, rel_type_count: int):
9 | encodings = doc.encoding
10 | token_count = len(doc.tokens) # number of tokens in the document, NOT including [CLS] and [SEP]
11 | context_size = len(
12 | encodings) # number of sub-tokens after the WordPiece tokenizer, including [CLS] and [SEP]
13 |
14 | pieces2word = torch.zeros((token_count, context_size), dtype=torch.bool)
15 | start = 0
16 | for i, token in enumerate(doc.tokens):
17 | pieces = list(range(token.sub_token_start, token.sub_token_end))
18 | pieces2word[i, pieces[0]: pieces[-1] + 1] = 1
19 | start += len(pieces)
20 |
21 | # positive terms
22 | pos_term_spans, pos_term_types, pos_term_masks, pos_term_sizes = [], [], [], []
23 | for e in doc.terms:
24 | pos_term_spans.append(e.span)
25 | pos_term_types.append(e.term_type.index)
26 | pos_term_masks.append(create_term_mask(*e.span, token_count))
27 | pos_term_sizes.append(len(e.tokens))
28 |
29 | # positive relations
30 | pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
31 | pos_rels3, pos_rel_masks3, pos_rel_spans3 = [], [], []
32 | pos_pair_mask = [] # which triplet rel is true
33 | for rel in doc.relations:
34 | s1, s2 = rel.head_term.span, rel.tail_term.span
35 | pos_rels.append((pos_term_spans.index(s1), pos_term_spans.index(s2)))
36 | pos_rel_spans.append((s1, s2))
37 | pos_rel_types.append(rel.relation_type)
38 | pos_rel_masks.append(create_rel_mask(s1, s2, token_count))
39 |
40 | def is_in_relation(head, tail, relations):
41 | for _rel in relations:
42 | _s1, _s2 = _rel.head_term.span, _rel.tail_term.span
43 | if (_s1 == head and _s2 == tail) or (_s1 == tail and _s2 == head):
44 | return 1
45 | return 0
46 |
47 | for x in range(len(doc.relations)):
48 | s1, s2 = doc.relations[x].head_term, doc.relations[x].tail_term
49 | x1, x2 = pos_term_spans.index(s1.span), pos_term_spans.index(s2.span)
50 | t_p_rels3 = []
51 | t_p_rels3_mask = []
52 | t_p_rel_span3 = []
53 | for idx, _term_span in enumerate(pos_term_spans):
54 | if idx != x1 and idx != x2:
55 | if is_in_relation(s1, _term_span, doc.relations) or is_in_relation(s2, _term_span, doc.relations):
56 | t_p_rels3.append((x1, x2, idx))
57 | t_p_rels3_mask.append(create_rel_mask3(s1.span, s2.span, _term_span, token_count))
58 | t_p_rel_span3.append((s1.span, s2.span, _term_span))
59 | # t_p_rel_types3.append(1)
60 | if len(t_p_rels3) > 0:
61 | pos_rels3.append(t_p_rels3)
62 | pos_rel_masks3.append(t_p_rels3_mask)
63 | pos_pair_mask.append(1)
64 | pos_rel_spans3.append(t_p_rel_span3)
65 | # pos_rel_types3.append(t_p_rel_types3)
66 | else:
67 | pos_rels3.append([(x1, x2, 0)])
68 | pos_rel_masks3.append([(create_rel_mask3(s1.span, s2.span, (0, 0), token_count))])
69 | pos_pair_mask.append(0)
70 |
71 | assert len(pos_rels) == len(pos_rels3) == len(pos_pair_mask)
72 |
73 | # negative terms
74 | neg_term_spans, neg_term_sizes = [], []
75 | for size in range(1, max_span_size + 1):
76 | for i in range(0, (token_count - size) + 1):
77 | span = doc.tokens[i:i + size].span
78 | if span not in pos_term_spans:
79 | neg_term_spans.append(span)
80 | neg_term_sizes.append(size)
81 |
82 | # sample negative terms
83 | neg_term_samples = random.sample(list(zip(neg_term_spans, neg_term_sizes)),
84 | min(len(neg_term_spans), neg_term_count))
85 | neg_term_spans, neg_term_sizes = zip(*neg_term_samples) if neg_term_samples else ([], [])
86 |
87 | neg_term_masks = [create_term_mask(*span, token_count) for span in neg_term_spans]
88 | neg_term_types = [0] * len(neg_term_spans)
89 |
90 | # negative relations
91 | # use only strong negative relations, i.e. pairs of actual (labeled) terms that are not related
92 | # neg_rels3 = []
93 | neg_rel_spans = []
94 | neg_rel_spans3 = []
95 | neg_pair_mask = []
96 |
97 | for i1, s1 in enumerate(pos_term_spans):
98 | for i2, s2 in enumerate(pos_term_spans):
99 | rev = (s2, s1)
100 | rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric
101 |
102 | # do not add as negative relation sample:
103 | # neg. relations from a term to itself
104 | # term pairs that are related according to gt
105 | # term pairs whose reverse exists as a symmetric relation in gt
106 | if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
107 | neg_rel_spans.append((s1, s2))
108 |
109 | p_rel_span3 = []
110 | for i3, s3 in enumerate(pos_term_spans):
111 | # the three spans must be pairwise distinct and must not already exist in pos_rel_spans3
112 | if s1 != s2 and s1 != s3 and s2 != s3 and (s1, s2, s3) not in pos_rel_spans3:
113 | p_rel_span3.append((s1, s2, s3))
114 | if len(p_rel_span3) > 0:
115 | neg_rel_spans3.append(p_rel_span3)
116 | neg_pair_mask.append(1)
117 | else:
118 | neg_rel_spans3.append([(s1, s2, (0, 0))])
119 | neg_pair_mask.append(0)
120 |
121 | # sample negative relations
122 |
123 | assert len(neg_rel_spans) == len(neg_rel_spans3) == len(neg_pair_mask)
124 |
125 | neg_rel_spans_samples = random.sample(list(zip(neg_rel_spans, neg_rel_spans3, neg_pair_mask)),
126 | min(len(neg_rel_spans), neg_rel_count))
127 | neg_rel_spans, neg_rel_spans3, neg_pair_mask = zip(*neg_rel_spans_samples) if neg_rel_spans_samples else (
128 | [], [], [])
129 |
130 | neg_rels = [(pos_term_spans.index(s1), pos_term_spans.index(s2)) for s1, s2 in neg_rel_spans]
131 | neg_rels3 = [[(pos_term_spans.index(s1), pos_term_spans.index(s2), pos_term_spans.index(s3)) for s1, s2, s3 in x]
132 | for x in neg_rel_spans3]
133 |
134 | assert len(neg_rels3) == len(neg_rel_spans3) == len(neg_pair_mask)
135 |
136 | neg_rel_masks = [create_rel_mask(*spans, token_count) for spans in neg_rel_spans]
137 | neg_rel_masks3 = [[create_rel_mask3(*sps, token_count) for sps in spans] for spans in neg_rel_spans3]
138 | neg_rel_types = [0] * len(neg_rel_spans)
139 | # neg_rel_types3 = [0] * len(neg_rel_spans3)
140 |
141 | # merge
142 | term_types = pos_term_types + neg_term_types
143 | term_masks = pos_term_masks + neg_term_masks
144 | term_sizes = pos_term_sizes + list(neg_term_sizes)
145 | term_spans = pos_term_spans + list(neg_term_spans)
146 |
147 | rels = pos_rels + neg_rels
148 | rel_types = [r.index for r in pos_rel_types] + neg_rel_types
149 | rel_masks = pos_rel_masks + neg_rel_masks
150 |
151 | rels3 = pos_rels3 + neg_rels3
152 | # rel_types3 = pos_rel_types3 + neg_rel_types3
153 | rel_masks3 = pos_rel_masks3 + neg_rel_masks3
154 | pair_mask = pos_pair_mask + list(neg_pair_mask)
155 |
156 | assert len(term_masks) == len(term_sizes) == len(term_types)
157 | try:
158 | assert len(rels) == len(rel_masks) == len(rel_types) == len(rels3) == len(pair_mask)
159 | except:
160 | print(len(rels))
161 | print(len(rels3))
162 | print(len(pair_mask))
163 |
164 | encodings = torch.tensor(encodings, dtype=torch.long)
165 | # masking of tokens
166 | context_masks = torch.ones(context_size, dtype=torch.bool)
167 |
168 | # also create samples_masks:
169 | # tensors to mask term/relation samples of batch
170 | # since samples are stacked into batches, "padding" terms/relations possibly must be created
171 | # these are later masked during loss computation
172 | if term_masks:
173 | term_types = torch.tensor(term_types, dtype=torch.long)
174 | term_masks = torch.stack(term_masks)
175 | term_sizes = torch.tensor(term_sizes, dtype=torch.long)
176 | term_sample_masks = torch.ones([term_masks.shape[0]], dtype=torch.bool)
177 | term_spans = torch.tensor(term_spans, dtype=torch.long)
178 | else:
179 | # corner case handling (no pos/neg terms)
180 | term_types = torch.zeros([1], dtype=torch.long)
181 | term_masks = torch.zeros([1, token_count], dtype=torch.bool)
182 | term_sizes = torch.zeros([1], dtype=torch.long)
183 | term_sample_masks = torch.zeros([1], dtype=torch.bool)
184 | term_spans = torch.zeros([1, 2], dtype=torch.long)
185 |
186 | if rels:
187 | rels = torch.tensor(rels, dtype=torch.long)
188 | rel_masks = torch.stack(rel_masks)
189 | rel_types = torch.tensor(rel_types, dtype=torch.long)
190 | rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool)
191 | else:
192 | # corner case handling (no pos/neg relations)
193 | rels = torch.zeros([1, 2], dtype=torch.long)
194 | rel_types = torch.zeros([1], dtype=torch.long)
195 | rel_masks = torch.zeros([1, token_count], dtype=torch.bool)
196 | rel_sample_masks = torch.zeros([1], dtype=torch.bool)
197 |
198 | if rels3:
199 | max_tri = max([len(x) for x in rels3])
200 | for idx, r in enumerate(rels3):
201 | r_len = len(r)
202 | if r_len < max_tri:
203 | rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len))
204 | rel_masks3[idx].extend([rel_masks3[idx][0]] * (max_tri - r_len))
205 | rels3 = torch.tensor(rels3, dtype=torch.long)
206 | try:
207 | rel_masks3 = torch.stack([torch.stack(x) for x in rel_masks3])
208 | except:
209 | print(rel_masks3)
210 | rel_sample_masks3 = torch.ones([rels3.shape[0]], dtype=torch.bool)
211 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool)
212 | else:
213 | rels3 = torch.zeros([1, 3], dtype=torch.long)
214 | rel_masks3 = torch.zeros([1, token_count], dtype=torch.bool)
215 | rel_sample_masks3 = torch.zeros([1], dtype=torch.bool)
216 | pair_mask = torch.tensor(pair_mask, dtype=torch.bool)
217 |
218 | # relation types to one-hot encoding
219 | rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
220 | rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
221 | rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation
222 |
223 | simple_graph = None
224 | graph = None
225 | try:
226 | simple_graph = torch.tensor(get_simple_graph(token_count, doc.dep), dtype=torch.long) # only the relation
227 | except:
228 | print(context_size)
229 | print(token_count)
230 | print(encodings)
231 | print(doc.dep)
232 | print(doc.dep_label_indices)
233 | try:
234 | graph = torch.tensor(get_graph(token_count, doc.dep, doc.dep_label_indices),
235 | dtype=torch.long) # relation and the type of relation
236 | except:
237 | print(context_size)
238 | print(token_count)
239 | print(encodings)
240 | print(doc.dep)
241 | print(doc.dep_label_indices)
242 |
243 | pos = torch.tensor(get_pos(token_count, doc.pos_indices), dtype=torch.long)
244 |
245 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks,
246 | term_sizes=term_sizes, term_types=term_types, term_spans=term_spans,
247 | rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
248 | rels3=rels3, rel_sample_masks3=rel_sample_masks3, rel_masks3=rel_masks3,
249 | pair_mask=pair_mask,
250 | term_sample_masks=term_sample_masks, rel_sample_masks=rel_sample_masks,
251 | simple_graph=simple_graph, graph=graph, pos=pos, pieces2word=pieces2word)
252 |
253 |
254 | def create_eval_sample(doc, max_span_size: int):
255 | encodings = doc.encoding
256 | token_count = len(doc.tokens)
257 | context_size = len(encodings)
258 |
259 | pieces2word = torch.zeros((token_count, context_size), dtype=torch.bool)
260 | start = 0
261 | for i, token in enumerate(doc.tokens):
262 | pieces = list(range(token.sub_token_start, token.sub_token_end))
263 | pieces2word[i, pieces[0]: pieces[-1] + 1] = 1
264 | start += len(pieces)
265 |
266 | # create term candidates
267 | term_spans = []
268 | term_masks = []
269 | term_sizes = []
270 |
271 | for size in range(1, max_span_size + 1):
272 | for i in range(0, (token_count - size) + 1):
273 | span = doc.tokens[i:i + size].span
274 | term_spans.append(span)
275 | term_masks.append(create_term_mask(*span, token_count))
276 | term_sizes.append(size)
277 |
278 | # create tensors
279 | # token indices
280 | # _encoding = encodings
281 | # encodings = torch.zeros(context_size, dtype=torch.long)
282 | # encodings[:len(_encoding)] = torch.tensor(_encoding, dtype=torch.long)
283 | encodings = torch.tensor(encodings, dtype=torch.long)
284 |
285 | # masking of tokens
286 | context_masks = torch.ones(context_size, dtype=torch.bool)
287 | # context_masks[:len(_encoding)] = 1
288 |
289 | # terms
290 | if term_masks:
291 | term_masks = torch.stack(term_masks)
292 | term_sizes = torch.tensor(term_sizes, dtype=torch.long)
293 | term_spans = torch.tensor(term_spans, dtype=torch.long)
294 |
295 | # tensors to mask term samples of batch
297 | # since samples are stacked into batches, "padding" terms may need to be created
297 | # these are later masked during evaluation
298 | term_sample_masks = torch.tensor([1] * term_masks.shape[0], dtype=torch.bool)
299 | else:
300 | # corner case handling (no terms)
301 | term_masks = torch.zeros([1, token_count], dtype=torch.bool)
302 | term_sizes = torch.zeros([1], dtype=torch.long)
303 | term_spans = torch.zeros([1, 2], dtype=torch.long)
304 | term_sample_masks = torch.zeros([1], dtype=torch.bool)
305 |
306 | simple_graph = torch.tensor(get_simple_graph(token_count, doc.dep), dtype=torch.long) # only the relation
307 | graph = torch.tensor(get_graph(token_count, doc.dep, doc.dep_label_indices),
308 | dtype=torch.long) # relation and the type of relation
309 | pos = torch.tensor(get_pos(token_count, doc.pos_indices), dtype=torch.long)
310 |
311 | return dict(encodings=encodings, context_masks=context_masks, term_masks=term_masks,
312 | term_sizes=term_sizes, term_spans=term_spans, term_sample_masks=term_sample_masks,
313 | simple_graph=simple_graph, graph=graph, pos=pos, pieces2word=pieces2word)
314 |
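# --- Hedged illustration (annotation, not part of the original file) ---
# pieces2word maps each word to the BERT wordpieces it was split into, so the
# wordpiece-level BERT output can later be pooled back to word level. For an
# assumed tokenization in which the word 'waiter' splits into ['wait', '##er']
# (special tokens ignored for brevity), the matrix looks like this:
def _demo_pieces2word():
    import torch
    token_count, context_size = 2, 3            # 2 words -> 3 wordpieces
    sub_token_spans = [(0, 1), (1, 3)]          # word i -> wordpiece range [start, end)
    pieces2word = torch.zeros((token_count, context_size), dtype=torch.bool)
    for i, (start, end) in enumerate(sub_token_spans):
        pieces2word[i, start:end] = 1
    # tensor([[ True, False, False],
    #         [False,  True,  True]])
    return pieces2word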
315 |
316 | def create_term_mask(start, end, context_size):
317 | mask = torch.zeros(context_size, dtype=torch.bool)
318 | mask[start:end] = 1
319 | return mask
320 |
321 |
322 | def create_rel_mask(s1, s2, context_size):
323 | start = s1[1] if s1[1] < s2[0] else s2[1]
324 | end = s2[0] if s1[1] < s2[0] else s1[0]
325 | mask = create_term_mask(start, end, context_size)
326 | return mask
327 |
328 |
329 | def create_rel_mask3(s1, s2, s3, context_size):
330 | mask = torch.zeros(context_size, dtype=torch.bool)
331 | start = min(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1])
332 | end = max(s1[0], s1[1], s2[0], s2[1], s3[0], s3[1])
333 | mask[start:end] = 1
334 | return mask
335 |
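# --- Hedged usage sketch (annotation, not part of the original file) ---
# Assumed toy call of the three mask helpers above for a 9-token sentence with
# spans s1 = (1, 3), s2 = (5, 7) and s3 = (0, 1): create_term_mask marks the
# tokens inside a span, create_rel_mask marks the context *between* two spans,
# and create_rel_mask3 marks every token covered by the three spans.
def _demo_masks():
    s1, s2, s3 = (1, 3), (5, 7), (0, 1)
    term_mask = create_term_mask(*s1, 9)         # tokens 1-2 -> [0,1,1,0,0,0,0,0,0]
    rel_mask = create_rel_mask(s1, s2, 9)        # tokens 3-4 -> [0,0,0,1,1,0,0,0,0]
    rel_mask3 = create_rel_mask3(s1, s2, s3, 9)  # tokens 0-6 -> [1,1,1,1,1,1,1,0,0]
    return term_mask, rel_mask, rel_mask3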
336 |
337 | def collate_fn_padding(batch):
338 | padded_batch = dict()
339 | keys = batch[0].keys()
340 |
341 | for key in keys:
342 | samples = [s[key] for s in batch]
343 | if not batch[0][key].shape:
344 | padded_batch[key] = torch.stack(samples)
345 | else:
346 | padded_batch[key] = util.padded_stack([s[key] for s in batch])
347 |
348 | return padded_batch
349 |
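# --- Hedged usage sketch (annotation, not part of the original file) ---
# collate_fn_padding stacks per-sample tensors into one batch: 0-dim entries
# are stacked directly, everything else goes through util.padded_stack, which
# (as assumed here) zero-pads each tensor to the largest shape in the batch
# before stacking, e.g. two encodings of length 3 and 4 become a (2, 4) batch.
def _demo_collate_fn_padding():
    import torch
    batch = [
        dict(encodings=torch.tensor([101, 7592, 102]), term_sizes=torch.tensor([1, 2])),
        dict(encodings=torch.tensor([101, 7592, 2088, 102]), term_sizes=torch.tensor([1])),
    ]
    padded = collate_fn_padding(batch)
    # padded['encodings'].shape == (2, 4); the shorter sample is zero-padded.
    return padded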
350 |
351 | def get_graph(seq_len, feature_data, feature2id):
352 | """
353 | Create a table T where t_{i,j} = r, the dependency relation label index between word i and word j; the diagonal is set to 1 (self-loop).
354 | :param seq_len: number of tokens in the sentence
355 | :param feature_data: dependency head of each token (1-based); '0' means the head is ROOT
356 | :param feature2id: dependency label index of each token
357 | :return: seq_len x seq_len table (list of lists) of label indices
358 | """
359 | assert len(feature2id) == len(feature_data) == seq_len
360 | ret = [[0] * seq_len for _ in range(seq_len)]
361 | for i, item in enumerate(feature_data):
362 | # the head word is ROOT, so this token only has a self-loop edge
363 | if int(item) == 0:
364 | ret[i][i] = 1
365 | continue
366 | ret[i][int(item) - 1] = feature2id[i]
367 | ret[int(item) - 1][i] = feature2id[i]
368 | ret[i][i] = 1
369 | return ret
370 |
371 |
372 | def get_simple_graph(seq_len, feature_data):
373 | """
374 | Create a table T where t_{i,j} = 1 if there is a dependency edge between word i and word j (self-loops included).
375 | :param seq_len: number of tokens in the sentence
376 | :param feature_data: dependency head of each token (1-based); '0' means the head is ROOT
377 | :return: seq_len x seq_len binary adjacency table (list of lists)
378 | """
379 | assert len(feature_data) == seq_len
380 | ret = [[0] * seq_len for _ in range(seq_len)]
381 | for i, item in enumerate(feature_data):
382 | if int(item) == 0:
383 | ret[i][i] = 1
384 | continue
385 | ret[i][int(item) - 1] = 1
386 | ret[int(item) - 1][i] = 1
387 | ret[i][i] = 1
388 | return ret
389 |
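# --- Hedged worked example (annotation, not part of the original file) ---
# Assumed toy call of get_graph/get_simple_graph for a 3-token sentence whose
# dependency heads are [2, 0, 2] (1-based, '0' = ROOT) with per-token label
# indices [4, 1, 7]: get_simple_graph yields a 0/1 adjacency table with
# self-loops, get_graph stores the label index on every head<->dependent edge.
def _demo_dependency_graphs():
    heads = [2, 0, 2]     # tokens 0 and 2 attach to token 1; token 1 is ROOT
    labels = [4, 1, 7]    # dependency-label index of each token
    simple = get_simple_graph(3, heads)   # [[1, 1, 0], [1, 1, 1], [0, 1, 1]]
    full = get_graph(3, heads, labels)    # [[1, 4, 0], [4, 1, 7], [0, 7, 1]]
    return simple, full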
390 |
391 | def get_pos(seq_len, pos_indices):
392 | assert len(pos_indices) == seq_len
393 | ret = [0] * seq_len
394 | for i, item in enumerate(pos_indices):
395 | ret[i] = pos_indices[i]
396 | return ret
397 |
--------------------------------------------------------------------------------
/SynFue/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 |
5 | import torch
6 | from torch.nn import DataParallel
7 | from torch.optim import Optimizer
8 | import transformers
9 | from torch.utils.data import DataLoader
10 | from transformers import AdamW, BertConfig
11 | from transformers import BertTokenizer
12 |
13 | from SynFue import models
14 | from SynFue import sampling
15 | from SynFue import util
16 | from SynFue.terms import Dataset
17 | from SynFue.evaluator import Evaluator
18 | from SynFue.input_reader import JsonInputReader, BaseInputReader
19 | from SynFue.loss import SynFueLoss, Loss
20 | from tqdm import tqdm
21 | from SynFue.base_trainer import BaseTrainer
22 |
23 | from torchsummary import summary
24 |
25 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
26 |
27 |
28 | class SynFueTrainer(BaseTrainer):
29 | """ Joint term and relation extraction training and evaluation """
30 |
31 | def __init__(self, args: argparse.Namespace):
32 | super().__init__(args)
33 |
34 | # byte-pair encoding
35 | self._tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
36 | do_lower_case=args.lowercase,
37 | cache_dir=args.cache_path)
38 |
39 | # path to export predictions to
40 | self._predictions_path = os.path.join(self._log_path, 'predictions_%s_epoch_%s.json')
41 |
42 | # path to export relation extraction examples to
43 | self._examples_path = os.path.join(self._log_path, 'examples_%s_%s_epoch_%s.html')
44 |
45 | def train(self, train_path: str, valid_path: str, types_path: str, input_reader_cls: BaseInputReader):
46 | args = self.args
47 | train_label, valid_label = 'train', 'valid'
48 |
49 | self._logger.info("Datasets: %s, %s" % (train_path, valid_path))
50 | self._logger.info("Model type: %s" % args.model_type)
51 |
52 | # create log csv files
53 | self._init_train_logging(train_label)
54 | self._init_eval_logging(valid_label)
55 |
56 | # read datasets
57 | input_reader = input_reader_cls(types_path, self._tokenizer, args.neg_term_count,
58 | args.neg_relation_count, args.max_span_size, self._logger)
59 | input_reader.read({train_label: train_path, valid_label: valid_path})
60 | self._log_datasets(input_reader)
61 |
62 | train_dataset = input_reader.get_dataset(train_label)
63 | train_sample_count = train_dataset.document_count
64 | updates_epoch = train_sample_count // args.train_batch_size
65 | updates_total = updates_epoch * args.epochs
66 |
67 | validation_dataset = input_reader.get_dataset(valid_label)
68 |
69 | self._logger.info("Updates per epoch: %s" % updates_epoch)
70 | self._logger.info("Updates total: %s" % updates_total)
71 |
72 | # create model
73 | model_class = models.get_model(self.args.model_type)
74 |
75 | # print('Training:')
76 | # print('model_path: ', self.args.model_path)
77 | # print('cache_path: ', self.args.cache_path)
78 |
79 | # load model
80 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path)
81 | util.check_version(config, model_class, self.args.model_path)
82 |
83 | config.model_version = model_class.VERSION
84 | model = model_class.from_pretrained(self.args.model_path,
85 | config=config,
86 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
87 | relation_types=input_reader.relation_type_count - 1,
88 | term_types=input_reader.term_type_count,
89 | max_pairs=self.args.max_pairs,
90 | prop_drop=self.args.prop_drop,
91 | size_embedding=self.args.size_embedding,
92 | freeze_transformer=self.args.freeze_transformer,
93 | args=self.args,
94 | beta=self.args.beta,
95 | alpha=self.args.alpha,
96 | sigma=self.args.sigma)
97 |
98 | model.to(self._device)
99 | # summary(model)
100 | # create optimizer
101 | optimizer_params = self._get_optimizer_params(model)
102 | optimizer = AdamW(optimizer_params, lr=args.lr, weight_decay=args.weight_decay, correct_bias=False)
103 | # create scheduler
104 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
105 | num_warmup_steps=args.lr_warmup * updates_total,
106 | num_training_steps=updates_total)
107 | # create loss function
108 | rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
109 | term_criterion = torch.nn.CrossEntropyLoss(reduction='none')
110 | compute_loss = SynFueLoss(rel_criterion, term_criterion, model, optimizer, scheduler, args.max_grad_norm)
111 |
112 | # eval validation set
113 | if args.init_eval:
114 | self._eval(model, validation_dataset, input_reader, 0, updates_epoch)
115 |
116 | # train
117 | best_f1 = 0.0
118 | for epoch in range(args.epochs):
119 | # train epoch
120 | self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch)
121 | # rel_nec_eval = self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
122 | # eval validation sets
123 | if not args.final_eval or (epoch == args.epochs - 1):
124 | rel_nec_eval = self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
125 | if best_f1 < rel_nec_eval[-1]:
126 | # save final model
127 | best_f1 = rel_nec_eval[-1]
128 | extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0)
129 | global_iteration = args.epochs * updates_epoch
130 | self._save_model(self._save_path, model, self._tokenizer, global_iteration,
131 | optimizer=optimizer if self.args.save_optimizer else None, save_as_best=True,
132 | extra=extra, include_iteration=False)
133 |
134 | self._logger.info("Logged in: %s" % self._log_path)
135 | self._logger.info("Saved in: %s" % self._save_path)
136 | self._close_summary_writer()
137 |
138 | def eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader):
139 | args = self.args
140 | dataset_label = 'test'
141 |
142 | self._logger.info("Dataset: %s" % dataset_path)
143 | self._logger.info("Model: %s" % args.model_type)
144 |
145 | # create log csv files
146 | self._init_eval_logging(dataset_label)
147 |
148 | # read datasets
149 | input_reader = input_reader_cls(types_path, self._tokenizer,
150 | max_span_size=args.max_span_size, logger=self._logger)
151 | input_reader.read({dataset_label: dataset_path})
152 | self._log_datasets(input_reader)
153 |
154 | # create model
155 | model_class = models.get_model(self.args.model_type)
156 |
157 | # print('Eval:')
158 | # print('model_path: ', self.args.model_path)
159 | # print('cache_path: ', self.args.cache_path)
160 |
161 | config = BertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path)
162 |
163 | util.check_version(config, model_class, self.args.model_path)
164 |
165 | model = model_class.from_pretrained(self.args.model_path,
166 | config=config,
167 | cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
168 | relation_types=input_reader.relation_type_count - 1,
169 | term_types=input_reader.term_type_count,
170 | max_pairs=self.args.max_pairs,
171 | prop_drop=self.args.prop_drop,
172 | size_embedding=self.args.size_embedding,
173 | freeze_transformer=self.args.freeze_transformer,
174 | args=self.args,
175 | beta=self.args.beta,
176 | alpha=self.args.alpha,
177 | sigma=self.args.sigma)
178 | # model = torch.load(os.path.join(self.args.model_path, 'pytorch_model.bin'))
179 |
180 | model.to(self._device)
181 | # summary(model)
182 | # evaluate
183 | self._eval(model, input_reader.get_dataset(dataset_label), input_reader)
184 |
185 | self._logger.info("Logged in: %s" % self._log_path)
186 | self._close_summary_writer()
187 |
188 | def _train_epoch(self, model: torch.nn.Module, compute_loss: Loss, optimizer: Optimizer, dataset: Dataset,
189 | updates_epoch: int, epoch: int):
190 | self._logger.info("Train epoch: %s" % epoch)
191 |
192 | # create data loader
193 | dataset.switch_mode(Dataset.TRAIN_MODE)
194 | data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, drop_last=True,
195 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding)
196 |
197 | model.zero_grad()
198 |
199 | iteration = 0
200 | total = dataset.document_count // self.args.train_batch_size
201 | for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch):
202 | model.train()
203 | batch = util.to_device(batch, self._device)
204 |
205 | # forward step
206 | term_logits, rel_logits = model(encodings=batch['encodings'], context_masks=batch['context_masks'],
207 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'],
208 | term_spans=batch['term_spans'], term_types=batch['term_types'],
209 | relations=batch['rels'], rel_masks=batch['rel_masks'],
210 | simple_graph=batch['simple_graph'], graph=batch['graph'],
211 | relations3=batch['rels3'], rel_masks3=batch['rel_masks3'],
212 | pair_mask=batch['pair_mask'], pos=batch['pos'],
213 | pieces2word=batch['pieces2word'])
214 |
215 | # compute loss and optimize parameters
216 | batch_loss = compute_loss.compute(term_logits=term_logits, rel_logits=rel_logits,
217 | rel_types=batch['rel_types'], term_types=batch['term_types'],
218 | term_sample_masks=batch['term_sample_masks'],
219 | rel_sample_masks=batch['rel_sample_masks'])
220 |
221 | # logging
222 | iteration += 1
223 | global_iteration = epoch * updates_epoch + iteration
224 |
225 | if global_iteration % self.args.train_log_iter == 0:
226 | self._log_train(optimizer, batch_loss, epoch, iteration, global_iteration, dataset.label)
227 |
228 | return iteration
229 |
230 | def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader,
231 | epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
232 | self._logger.info("Evaluate: %s" % dataset.label)
233 |
234 | if isinstance(model, DataParallel):
235 | # currently no multi GPU support during evaluation
236 | model = model.module
237 |
238 | # create evaluator
239 | evaluator = Evaluator(dataset, input_reader, self._tokenizer,
240 | self.args.rel_filter_threshold, self.args.no_overlapping, self._predictions_path,
241 | self._examples_path, self.args.example_count, epoch, dataset.label)
242 |
243 | # create data loader
244 | dataset.switch_mode(Dataset.EVAL_MODE)
245 | data_loader = DataLoader(dataset, batch_size=self.args.eval_batch_size, shuffle=False, drop_last=False,
246 | num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding)
247 |
248 | with torch.no_grad():
249 | model.eval()
250 |
251 | # iterate batches
252 | total = math.ceil(dataset.document_count / self.args.eval_batch_size)
253 | for batch in tqdm(data_loader, total=total, desc='Evaluate epoch %s' % epoch):
254 | # move batch to selected device
255 | batch = util.to_device(batch, self._device)
256 |
257 | # run model (forward pass)
258 | result = model(encodings=batch['encodings'], context_masks=batch['context_masks'],
259 | term_masks=batch['term_masks'], term_sizes=batch['term_sizes'],
260 | term_spans=batch['term_spans'], term_sample_masks=batch['term_sample_masks'],
261 | evaluate=True, simple_graph=batch['simple_graph'], graph=batch['graph'],
262 | pos=batch['pos'], pieces2word=batch['pieces2word']) # pos=batch['pos']
263 | term_clf, rel_clf, rels = result
264 |
265 | # evaluate batch
266 | evaluator.eval_batch(term_clf, rel_clf, rels, batch)
267 |
268 | global_iteration = epoch * updates_epoch + iteration
269 | ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
270 | self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
271 | epoch, iteration, global_iteration, dataset.label)
272 |
273 | if self.args.store_predictions and not self.args.no_overlapping:
274 | evaluator.store_predictions()
275 |
276 | if self.args.store_examples:
277 | evaluator.store_examples()
278 | return rel_nec_eval
279 |
280 | def _get_optimizer_params(self, model):
281 | param_optimizer = list(model.named_parameters())
282 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
283 | optimizer_params = [
284 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
285 | 'weight_decay': self.args.weight_decay},
286 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
287 |
288 | return optimizer_params
289 |
290 | def _log_train(self, optimizer: Optimizer, loss: float, epoch: int,
291 | iteration: int, global_iteration: int, label: str):
292 | # average loss
293 | avg_loss = loss / self.args.train_batch_size
294 | # get current learning rate
295 | lr = self._get_lr(optimizer)[0]
296 |
297 | # log to tensorboard
298 | self._log_tensorboard(label, 'loss', loss, global_iteration)
299 | self._log_tensorboard(label, 'loss_avg', avg_loss, global_iteration)
300 | self._log_tensorboard(label, 'lr', lr, global_iteration)
301 |
302 | # log to csv
303 | self._log_csv(label, 'loss', loss, epoch, iteration, global_iteration)
304 | self._log_csv(label, 'loss_avg', avg_loss, epoch, iteration, global_iteration)
305 | self._log_csv(label, 'lr', lr, epoch, iteration, global_iteration)
306 |
307 | def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: float,
308 | ner_prec_macro: float, ner_rec_macro: float, ner_f1_macro: float,
309 |
310 | rel_prec_micro: float, rel_rec_micro: float, rel_f1_micro: float,
311 | rel_prec_macro: float, rel_rec_macro: float, rel_f1_macro: float,
312 |
313 | rel_nec_prec_micro: float, rel_nec_rec_micro: float, rel_nec_f1_micro: float,
314 | rel_nec_prec_macro: float, rel_nec_rec_macro: float, rel_nec_f1_macro: float,
315 | epoch: int, iteration: int, global_iteration: int, label: str):
316 |
317 | # log to tensorboard
318 | self._log_tensorboard(label, 'eval/ner_prec_micro', ner_prec_micro, global_iteration)
319 | self._log_tensorboard(label, 'eval/ner_recall_micro', ner_rec_micro, global_iteration)
320 | self._log_tensorboard(label, 'eval/ner_f1_micro', ner_f1_micro, global_iteration)
321 | self._log_tensorboard(label, 'eval/ner_prec_macro', ner_prec_macro, global_iteration)
322 | self._log_tensorboard(label, 'eval/ner_recall_macro', ner_rec_macro, global_iteration)
323 | self._log_tensorboard(label, 'eval/ner_f1_macro', ner_f1_macro, global_iteration)
324 |
325 | self._log_tensorboard(label, 'eval/rel_prec_micro', rel_prec_micro, global_iteration)
326 | self._log_tensorboard(label, 'eval/rel_recall_micro', rel_rec_micro, global_iteration)
327 | self._log_tensorboard(label, 'eval/rel_f1_micro', rel_f1_micro, global_iteration)
328 | self._log_tensorboard(label, 'eval/rel_prec_macro', rel_prec_macro, global_iteration)
329 | self._log_tensorboard(label, 'eval/rel_recall_macro', rel_rec_macro, global_iteration)
330 | self._log_tensorboard(label, 'eval/rel_f1_macro', rel_f1_macro, global_iteration)
331 |
332 | self._log_tensorboard(label, 'eval/rel_nec_prec_micro', rel_nec_prec_micro, global_iteration)
333 | self._log_tensorboard(label, 'eval/rel_nec_recall_micro', rel_nec_rec_micro, global_iteration)
334 | self._log_tensorboard(label, 'eval/rel_nec_f1_micro', rel_nec_f1_micro, global_iteration)
335 | self._log_tensorboard(label, 'eval/rel_nec_prec_macro', rel_nec_prec_macro, global_iteration)
336 | self._log_tensorboard(label, 'eval/rel_nec_recall_macro', rel_nec_rec_macro, global_iteration)
337 | self._log_tensorboard(label, 'eval/rel_nec_f1_macro', rel_nec_f1_macro, global_iteration)
338 |
339 | # log to csv
340 | self._log_csv(label, 'eval', ner_prec_micro, ner_rec_micro, ner_f1_micro,
341 | ner_prec_macro, ner_rec_macro, ner_f1_macro,
342 |
343 | rel_prec_micro, rel_rec_micro, rel_f1_micro,
344 | rel_prec_macro, rel_rec_macro, rel_f1_macro,
345 |
346 | rel_nec_prec_micro, rel_nec_rec_micro, rel_nec_f1_micro,
347 | rel_nec_prec_macro, rel_nec_rec_macro, rel_nec_f1_macro,
348 | epoch, iteration, global_iteration)
349 |
350 | def _log_datasets(self, input_reader):
351 | self._logger.info("Relation type count: %s" % input_reader.relation_type_count)
352 | self._logger.info("Term type count: %s" % input_reader.term_type_count)
353 |
354 | self._logger.info("Terms:")
355 | for e in input_reader.term_types.values():
356 | self._logger.info(e.verbose_name + '=' + str(e.index))
357 |
358 | self._logger.info("Relations:")
359 | for r in input_reader.relation_types.values():
360 | self._logger.info(r.verbose_name + '=' + str(r.index))
361 |
362 | for k, d in input_reader.datasets.items():
363 | self._logger.info('Dataset: %s' % k)
364 | self._logger.info("Document count: %s" % d.document_count)
365 | self._logger.info("Relation count: %s" % d.relation_count)
366 | self._logger.info("Term count: %s" % d.term_count)
367 |
368 | self._logger.info("Context size: %s" % input_reader.context_size)
369 |
370 | def _init_train_logging(self, label):
371 | self._add_dataset_logging(label,
372 | data={'lr': ['lr', 'epoch', 'iteration', 'global_iteration'],
373 | 'loss': ['loss', 'epoch', 'iteration', 'global_iteration'],
374 | 'loss_avg': ['loss_avg', 'epoch', 'iteration', 'global_iteration']})
375 |
376 | def _init_eval_logging(self, label):
377 | self._add_dataset_logging(label,
378 | data={'eval': ['ner_prec_micro', 'ner_rec_micro', 'ner_f1_micro',
379 | 'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro',
380 | 'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro',
381 | 'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro',
382 | 'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro',
383 | 'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro',
384 | 'epoch', 'iteration', 'global_iteration']})
385 |
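# --- Hedged sketch (annotation, not part of the original file) ---
# _get_optimizer_params above builds the usual two AdamW parameter groups:
# weight decay for most weights, none for biases and LayerNorm parameters.
# The same grouping as a standalone helper (torch.optim.AdamW is used here
# instead of transformers.AdamW purely for illustration):
def _demo_optimizer_param_groups(model: torch.nn.Module, lr: float = 5e-5, weight_decay: float = 0.01):
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    named = list(model.named_parameters())
    groups = [
        {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr)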
--------------------------------------------------------------------------------
/SynFue/evaluator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import warnings
4 | from typing import List, Tuple, Dict
5 |
6 | import torch
7 | from sklearn.metrics import precision_recall_fscore_support as prfs
8 | from transformers import BertTokenizer
9 |
10 | from SynFue import util
11 | from SynFue.terms import Document, Dataset, TermType
12 | from SynFue.input_reader import JsonInputReader
13 | from SynFue.opt import jinja2
14 |
15 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
16 |
17 |
18 | class Evaluator:
19 | def __init__(self, dataset: Dataset, input_reader: JsonInputReader, text_encoder: BertTokenizer,
20 | rel_filter_threshold: float, no_overlapping: bool,
21 | predictions_path: str, examples_path: str, example_count: int, epoch: int, dataset_label: str):
22 | self._text_encoder = text_encoder
23 | self._input_reader = input_reader
24 | self._dataset = dataset
25 | self._rel_filter_threshold = rel_filter_threshold
26 | self._no_overlapping = no_overlapping
27 |
28 | self._epoch = epoch
29 | self._dataset_label = dataset_label
30 |
31 | self._predictions_path = predictions_path
32 |
33 | self._examples_path = examples_path
34 | self._example_count = example_count
35 |
36 | # relations
37 | self._gt_relations = [] # ground truth
38 | self._pred_relations = [] # prediction
39 |
40 | # terms
41 | self._gt_terms = [] # ground truth
42 | self._pred_terms = [] # prediction
43 |
44 | self._pseudo_term_type = TermType('Term', 1, 'Term', 'Term') # for span only evaluation
45 |
46 | self._convert_gt(self._dataset.documents)
47 |
48 | def eval_batch(self, batch_term_clf: torch.tensor, batch_rel_clf: torch.tensor,
49 | batch_rels: torch.tensor, batch: dict):
50 | batch_size = batch_rel_clf.shape[0]
51 | rel_class_count = batch_rel_clf.shape[2]
52 | # get maximum activation (index of predicted term type)
53 | batch_term_types = batch_term_clf.argmax(dim=-1)
54 | # apply term sample mask
55 | batch_term_types *= batch['term_sample_masks'].long()
56 |
57 | batch_rel_clf = batch_rel_clf.view(batch_size, -1)
58 |
59 | # apply threshold to relations
60 | if self._rel_filter_threshold > 0:
61 | batch_rel_clf[batch_rel_clf < self._rel_filter_threshold] = 0
62 |
63 | for i in range(batch_size):
64 | # get model predictions for sample
65 | rel_clf = batch_rel_clf[i]
66 | term_types = batch_term_types[i]
67 |
68 | # get predicted relation labels and corresponding term pairs
69 | rel_nonzero = rel_clf.nonzero().view(-1)
70 | rel_scores = rel_clf[rel_nonzero]
71 |
72 | rel_types = (rel_nonzero % rel_class_count) + 1 # model does not predict None class (+1)
73 | rel_indices = rel_nonzero // rel_class_count
74 |
75 | rels = batch_rels[i][rel_indices]
76 |
77 | # get masks of terms in relation
78 | rel_term_spans = batch['term_spans'][i][rels].long()
79 |
80 | # get predicted term types
81 | rel_term_types = torch.zeros([rels.shape[0], 2])
82 | if rels.shape[0] != 0:
83 | rel_term_types = torch.stack([term_types[rels[j]] for j in range(rels.shape[0])])
84 |
85 | # convert predicted relations for evaluation
86 | sample_pred_relations = self._convert_pred_relations(rel_types, rel_term_spans,
87 | rel_term_types, rel_scores)
88 |
89 | # get terms that are not classified as 'None'
90 | valid_term_indices = term_types.nonzero().view(-1)
91 | valid_term_types = term_types[valid_term_indices]
92 | valid_term_spans = batch['term_spans'][i][valid_term_indices]
93 | valid_term_scores = torch.gather(batch_term_clf[i][valid_term_indices], 1,
94 | valid_term_types.unsqueeze(1)).view(-1)
95 | sample_pred_terms = self._convert_pred_terms(valid_term_types, valid_term_spans,
96 | valid_term_scores)
97 |
98 | if self._no_overlapping:
99 | sample_pred_terms, sample_pred_relations = self._remove_overlapping(sample_pred_terms,
100 | sample_pred_relations)
101 |
102 | self._pred_terms.append(sample_pred_terms)
103 | self._pred_relations.append(sample_pred_relations)
104 |
105 | def compute_scores(self):
106 | print("Evaluation")
107 |
108 | print("")
109 | print("--- Terms (named term recognition (NTR)) ---")
110 | print("An term is considered correct if the term type and span is predicted correctly")
111 | print("")
112 | gt, pred = self._convert_by_setting(self._gt_terms, self._pred_terms, include_term_types=True)
113 | ner_eval = self._score(gt, pred, print_results=True)
114 |
115 | print("")
116 | print("--- Relations ---")
117 | print("")
118 | print("Without named term classification (NTC)")
119 | print("A relation is considered correct if the relation type and the spans of the two "
120 | "related terms are predicted correctly (term type is not considered)")
121 | print("")
122 | gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=False)
123 | rel_eval = self._score(gt, pred, print_results=True)
124 |
125 | print("")
126 | print("With named term classification (NTC)")
127 | print("A relation is considered correct if the relation type and the two "
128 | "related terms are predicted correctly (in span and term type)")
129 | print("")
130 | gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_term_types=True)
131 | rel_nec_eval = self._score(gt, pred, print_results=True)
132 |
133 | return ner_eval, rel_eval, rel_nec_eval
134 |
135 | def store_predictions(self):
136 | predictions = []
137 |
138 | for i, doc in enumerate(self._dataset.documents):
139 | tokens = doc.tokens
140 | pred_terms = self._pred_terms[i]
141 | pred_relations = self._pred_relations[i]
142 |
143 | # convert terms
144 | converted_terms = []
145 | for term in pred_terms:
146 | term_span = term[:2]
147 | span_tokens = util.get_span_tokens(tokens, term_span)
148 | term_type = term[2].identifier
149 | # if term_type == 'None':
150 | # continue
151 | converted_term = dict(type=term_type, start=span_tokens[0].index, end=span_tokens[-1].index + 1)
152 | converted_terms.append(converted_term)
153 | converted_terms = sorted(converted_terms, key=lambda e: e['start'])
154 |
155 | # print('converted_terms: ', converted_terms)
156 |
157 | # convert relations
158 | converted_relations = []
159 | for relation in pred_relations:
160 | head, tail = relation[:2]
161 | head_span, head_type = head[:2], head[2].identifier
162 | tail_span, tail_type = tail[:2], tail[2].identifier
163 | head_span_tokens = util.get_span_tokens(tokens, head_span)
164 | tail_span_tokens = util.get_span_tokens(tokens, tail_span)
165 | relation_type = relation[2].identifier
166 |
167 | converted_head = dict(type=head_type, start=head_span_tokens[0].index,
168 | end=head_span_tokens[-1].index + 1)
169 | converted_tail = dict(type=tail_type, start=tail_span_tokens[0].index,
170 | end=tail_span_tokens[-1].index + 1)
171 |
172 | # print(converted_tail)
173 | head_idx = converted_terms.index(converted_head)
174 | tail_idx = converted_terms.index(converted_tail)
175 |
176 | converted_relation = dict(type=relation_type, head=head_idx, tail=tail_idx)
177 | converted_relations.append(converted_relation)
178 | converted_relations = sorted(converted_relations, key=lambda r: r['head'])
179 |
180 | doc_predictions = dict(tokens=[t.phrase for t in tokens], terms=converted_terms,
181 | relations=converted_relations)
182 | predictions.append(doc_predictions)
183 |
184 | # store as json
185 | label, epoch = self._dataset_label, self._epoch
186 | with open(self._predictions_path % (label, epoch), 'w') as predictions_file:
187 | json.dump(predictions, predictions_file)
188 |
189 | def store_examples(self):
190 | if jinja2 is None:
191 | warnings.warn("Examples cannot be stored since Jinja2 is not installed.")
192 | return
193 |
194 | term_examples = []
195 | rel_examples = []
196 | rel_examples_nec = []
197 |
198 | for i, doc in enumerate(self._dataset.documents):
199 | # terms
200 | term_example = self._convert_example(doc, self._gt_terms[i], self._pred_terms[i],
201 | include_term_types=True, to_html=self._term_to_html)
202 | term_examples.append(term_example)
203 |
204 | # relations
205 | # without term types
206 | rel_example = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i],
207 | include_term_types=False, to_html=self._rel_to_html)
208 | rel_examples.append(rel_example)
209 |
210 | # with term types
211 | rel_example_nec = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i],
212 | include_term_types=True, to_html=self._rel_to_html)
213 | rel_examples_nec.append(rel_example_nec)
214 |
215 | label, epoch = self._dataset_label, self._epoch
216 |
217 | # terms
218 | self._store_examples(term_examples[:self._example_count],
219 | file_path=self._examples_path % ('terms', label, epoch),
220 | template='term_examples.html')
221 |
222 | self._store_examples(sorted(term_examples[:self._example_count],
223 | key=lambda k: k['length']),
224 | file_path=self._examples_path % ('terms_sorted', label, epoch),
225 | template='term_examples.html')
226 |
227 | # relations
228 | # without term types
229 | self._store_examples(rel_examples[:self._example_count],
230 | file_path=self._examples_path % ('rel', label, epoch),
231 | template='relation_examples.html')
232 |
233 | self._store_examples(sorted(rel_examples[:self._example_count],
234 | key=lambda k: k['length']),
235 | file_path=self._examples_path % ('rel_sorted', label, epoch),
236 | template='relation_examples.html')
237 |
238 | # with term types
239 | self._store_examples(rel_examples_nec[:self._example_count],
240 | file_path=self._examples_path % ('rel_nec', label, epoch),
241 | template='relation_examples.html')
242 |
243 | self._store_examples(sorted(rel_examples_nec[:self._example_count],
244 | key=lambda k: k['length']),
245 | file_path=self._examples_path % ('rel_nec_sorted', label, epoch),
246 | template='relation_examples.html')
247 |
248 | def _convert_gt(self, docs: List[Document]):
249 | for doc in docs:
250 | gt_relations = doc.relations
251 | gt_terms = doc.terms
252 |
253 | # convert ground truth relations and terms for precision/recall/f1 evaluation
254 | sample_gt_terms = [term.as_tuple() for term in gt_terms]
255 | sample_gt_relations = [rel.as_tuple() for rel in gt_relations]
256 |
257 | if self._no_overlapping:
258 | sample_gt_terms, sample_gt_relations = self._remove_overlapping(sample_gt_terms,
259 | sample_gt_relations)
260 |
261 | self._gt_terms.append(sample_gt_terms)
262 | self._gt_relations.append(sample_gt_relations)
263 |
264 | def _convert_pred_terms(self, pred_types: torch.tensor, pred_spans: torch.tensor, pred_scores: torch.tensor):
265 | converted_preds = []
266 |
267 | for i in range(pred_types.shape[0]):
268 | label_idx = pred_types[i].item()
269 | term_type = self._input_reader.get_term_type(label_idx)
270 |
271 | start, end = pred_spans[i].tolist()
272 | score = pred_scores[i].item()
273 |
274 | converted_pred = (start, end, term_type, score)
275 | converted_preds.append(converted_pred)
276 |
277 | return converted_preds
278 |
279 | def _convert_pred_relations(self, pred_rel_types: torch.tensor, pred_term_spans: torch.tensor,
280 | pred_term_types: torch.tensor, pred_scores: torch.tensor):
281 | converted_rels = []
282 | check = set()
283 |
284 | for i in range(pred_rel_types.shape[0]):
285 | label_idx = pred_rel_types[i].item()
286 | pred_rel_type = self._input_reader.get_relation_type(label_idx)
287 | pred_head_type_idx, pred_tail_type_idx = pred_term_types[i][0].item(), pred_term_types[i][1].item()
288 | pred_head_type = self._input_reader.get_term_type(pred_head_type_idx)
289 | pred_tail_type = self._input_reader.get_term_type(pred_tail_type_idx)
290 | score = pred_scores[i].item()
291 |
292 | spans = pred_term_spans[i]
293 | head_start, head_end = spans[0].tolist()
294 | tail_start, tail_end = spans[1].tolist()
295 |
296 | converted_rel = ((head_start, head_end, pred_head_type),
297 | (tail_start, tail_end, pred_tail_type), pred_rel_type)
298 | converted_rel = self._adjust_rel(converted_rel)
299 |
300 | if converted_rel not in check:
301 | check.add(converted_rel)
302 | converted_rels.append(tuple(list(converted_rel) + [score]))
303 |
304 | return converted_rels
305 |
306 | def _remove_overlapping(self, terms, relations):
307 | non_overlapping_terms = []
308 | non_overlapping_relations = []
309 |
310 | for term in terms:
311 | if not self._is_overlapping(term, terms):
312 | non_overlapping_terms.append(term)
313 |
314 | for rel in relations:
315 | e1, e2 = rel[0], rel[1]
316 | if not self._check_overlap(e1, e2):
317 | non_overlapping_relations.append(rel)
318 |
319 | return non_overlapping_terms, non_overlapping_relations
320 |
321 | def _is_overlapping(self, e1, terms):
322 | for e2 in terms:
323 | if self._check_overlap(e1, e2):
324 | return True
325 |
326 | return False
327 |
328 | def _check_overlap(self, e1, e2):
329 | if e1 == e2 or e1[1] <= e2[0] or e2[1] <= e1[0]:
330 | return False
331 | else:
332 | return True
333 |
334 | def _adjust_rel(self, rel: Tuple):
335 | adjusted_rel = rel
336 | if rel[-1].symmetric:
337 | head, tail = rel[:2]
338 | if tail[0] < head[0]:
339 | adjusted_rel = tail, head, rel[-1]
340 |
341 | return adjusted_rel
342 |
343 | def _convert_by_setting(self, gt: List[List[Tuple]], pred: List[List[Tuple]],
344 | include_term_types: bool = True, include_score: bool = False):
345 | assert len(gt) == len(pred)
346 |
347 | # either include or remove term types based on setting
348 | def convert(t):
349 | if not include_term_types:
350 | # remove term type and score for evaluation
351 | if type(t[0]) == int: # term
352 | c = [t[0], t[1], self._pseudo_term_type]
353 | else: # relation
354 | c = [(t[0][0], t[0][1], self._pseudo_term_type),
355 | (t[1][0], t[1][1], self._pseudo_term_type), t[2]]
356 | else:
357 | c = list(t[:3])
358 |
359 | if include_score and len(t) > 3:
360 | # include prediction scores
361 | c.append(t[3])
362 |
363 | return tuple(c)
364 |
365 | converted_gt, converted_pred = [], []
366 |
367 | for sample_gt, sample_pred in zip(gt, pred):
368 | converted_gt.append([convert(t) for t in sample_gt])
369 | converted_pred.append([convert(t) for t in sample_pred])
370 |
371 | return converted_gt, converted_pred
372 |
373 | def _score(self, gt: List[List[Tuple]], pred: List[List[Tuple]], print_results: bool = False):
374 | assert len(gt) == len(pred)
375 |
376 | gt_flat = []
377 | pred_flat = []
378 | types = set()
379 |
380 | for (sample_gt, sample_pred) in zip(gt, pred):
381 | union = set()
382 | union.update(sample_gt)
383 | union.update(sample_pred)
384 |
385 | for s in union:
386 | if s in sample_gt:
387 | t = s[2]
388 | gt_flat.append(t.index)
389 | types.add(t)
390 | else:
391 | gt_flat.append(0)
392 |
393 | if s in sample_pred:
394 | t = s[2]
395 | pred_flat.append(t.index)
396 | types.add(t)
397 | else:
398 | pred_flat.append(0)
399 |
400 | metrics = self._compute_metrics(gt_flat, pred_flat, types, print_results)
401 | return metrics
402 |
403 | def _compute_metrics(self, gt_all, pred_all, types, print_results: bool = False):
404 | labels = [t.index for t in types]
405 | per_type = prfs(gt_all, pred_all, labels=labels, average=None)
406 | micro = prfs(gt_all, pred_all, labels=labels, average='micro')[:-1]
407 | macro = prfs(gt_all, pred_all, labels=labels, average='macro')[:-1]
408 | total_support = sum(per_type[-1])
409 |
410 | if print_results:
411 | self._print_results(per_type, list(micro) + [total_support], list(macro) + [total_support], types)
412 |
413 | return [m * 100 for m in micro + macro]
414 |
415 | def _print_results(self, per_type: List, micro: List, macro: List, types: List):
416 | columns = ('type', 'precision', 'recall', 'f1-score', 'support')
417 |
418 | row_fmt = "%20s" + (" %12s" * (len(columns) - 1))
419 | results = [row_fmt % columns, '\n']
420 |
421 | metrics_per_type = []
422 | for i, t in enumerate(types):
423 | metrics = []
424 | for j in range(len(per_type)):
425 | metrics.append(per_type[j][i])
426 | metrics_per_type.append(metrics)
427 |
428 | for m, t in zip(metrics_per_type, types):
429 | results.append(row_fmt % self._get_row(m, t.short_name))
430 | results.append('\n')
431 |
432 | results.append('\n')
433 |
434 | # micro
435 | results.append(row_fmt % self._get_row(micro, 'micro'))
436 | results.append('\n')
437 |
438 | # macro
439 | results.append(row_fmt % self._get_row(macro, 'macro'))
440 |
441 | results_str = ''.join(results)
442 | print(results_str)
443 |
444 | def _get_row(self, data, label):
445 | row = [label]
446 | for i in range(len(data) - 1):
447 | row.append("%.2f" % (data[i] * 100))
448 | row.append(data[3])
449 | return tuple(row)
450 |
451 | def _convert_example(self, doc: Document, gt: List[Tuple], pred: List[Tuple],
452 | include_term_types: bool, to_html):
453 | encoding = doc.encoding
454 |
455 | gt, pred = self._convert_by_setting([gt], [pred], include_term_types=include_term_types, include_score=True)
456 | gt, pred = gt[0], pred[0]
457 |
458 | # get micro precision/recall/f1 scores
459 | if gt or pred:
460 | pred_s = [p[:3] for p in pred] # remove score
461 | precision, recall, f1 = self._score([gt], [pred_s])[:3]
462 | else:
463 | # corner case: no ground truth and no predictions
464 | precision, recall, f1 = [100] * 3
465 |
466 | scores = [p[-1] for p in pred]
467 | pred = [p[:-1] for p in pred]
468 | union = set(gt + pred)
469 |
470 | # true positives
471 | tp = []
472 | # false negatives
473 | fn = []
474 | # false positives
475 | fp = []
476 |
477 | for s in union:
478 | type_verbose = s[2].verbose_name
479 |
480 | if s in gt:
481 | if s in pred:
482 | score = scores[pred.index(s)]
483 | tp.append((to_html(s, encoding), type_verbose, score))
484 | else:
485 | fn.append((to_html(s, encoding), type_verbose, -1))
486 | else:
487 | score = scores[pred.index(s)]
488 | fp.append((to_html(s, encoding), type_verbose, score))
489 |
490 | tp = sorted(tp, key=lambda p: p[-1], reverse=True)
491 | fp = sorted(fp, key=lambda p: p[-1], reverse=True)
492 |
493 | text = self._prettify(self._text_encoder.decode(encoding))
494 | return dict(text=text, tp=tp, fn=fn, fp=fp, precision=precision, recall=recall, f1=f1, length=len(doc.tokens))
495 |
496 | def _term_to_html(self, term: Tuple, encoding: List[int]):
497 | start, end = term[:2]
498 | term_type = term[2].verbose_name
499 |
500 | tag_start = ' <span class="term">'
501 | tag_start += '<span class="type">%s</span>' % term_type
502 |
503 | ctx_before = self._text_encoder.decode(encoding[:start])
504 | e1 = self._text_encoder.decode(encoding[start:end])
505 | ctx_after = self._text_encoder.decode(encoding[end:])
506 |
507 | html = ctx_before + tag_start + e1 + '</span> ' + ctx_after
508 | html = self._prettify(html)
509 |
510 | return html
511 |
512 | def _rel_to_html(self, relation: Tuple, encoding: List[int]):
513 | head, tail = relation[:2]
514 | head_tag = ' <span class="head"><span class="type">%s</span>'
515 | tail_tag = ' <span class="tail"><span class="type">%s</span>'
516 |
517 | if head[0] < tail[0]:
518 | e1, e2 = head, tail
519 | e1_tag, e2_tag = head_tag % head[2].verbose_name, tail_tag % tail[2].verbose_name
520 | else:
521 | e1, e2 = tail, head
522 | e1_tag, e2_tag = tail_tag % tail[2].verbose_name, head_tag % head[2].verbose_name
523 |
524 | segments = [encoding[:e1[0]], encoding[e1[0]:e1[1]], encoding[e1[1]:e2[0]],
525 | encoding[e2[0]:e2[1]], encoding[e2[1]:]]
526 |
527 | ctx_before = self._text_encoder.decode(segments[0])
528 | e1 = self._text_encoder.decode(segments[1])
529 | ctx_between = self._text_encoder.decode(segments[2])
530 | e2 = self._text_encoder.decode(segments[3])
531 | ctx_after = self._text_encoder.decode(segments[4])
532 |
533 | html = (ctx_before + e1_tag + e1 + '</span> '
534 | + ctx_between + e2_tag + e2 + '</span> ' + ctx_after)
535 | html = self._prettify(html)
536 |
537 | return html
538 |
539 | def _prettify(self, text: str):
540 | text = text.replace('_start_', '').replace('_classify_', '').replace('<unk>', '').replace('⁇', '')
541 | text = text.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '')
542 | return text
543 |
544 | def _store_examples(self, examples: List[Dict], file_path: str, template: str):
545 | template_path = os.path.join(SCRIPT_PATH, 'templates', template)
546 |
547 | # read template
548 | with open(os.path.join(SCRIPT_PATH, template_path)) as f:
549 | template = jinja2.Template(f.read())
550 |
551 | # write to disc
552 | template.stream(examples=examples).dump(file_path)
553 |
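# --- Hedged worked example (annotation, not part of the original file) ---
# _check_overlap above treats two (start, end, ...) tuples as overlapping only
# if they share at least one token and are not identical; _remove_overlapping
# then drops terms that overlap another term and relations whose two arguments
# overlap each other. The predicate, restated on plain (start, end) spans:
def _demo_span_overlap():
    def overlaps(e1, e2):
        return not (e1 == e2 or e1[1] <= e2[0] or e2[1] <= e1[0])
    assert overlaps((0, 3), (2, 5))      # share token 2
    assert not overlaps((0, 3), (3, 5))  # adjacent spans share no token
    assert not overlaps((1, 4), (1, 4))  # identical spans are not counted as overlapping
    return overlaps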
--------------------------------------------------------------------------------
/SynFue/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn as nn
4 | from transformers import BertConfig
5 | from transformers import BertModel
6 | from transformers import BertPreTrainedModel
7 |
8 | from SynFue import sampling
9 | from SynFue import util
10 | from SynFue import Encoder
11 | from SynFue import cross_attn
12 |
13 |
14 | def get_token(h: torch.tensor, x: torch.tensor, token: int):
15 | """ Get specific token embedding (e.g. [CLS]) """
16 | emb_size = h.shape[-1]
17 |
18 | token_h = h.view(-1, emb_size)
19 | flat = x.contiguous().view(-1)
20 |
21 | # get contextualized embedding of given token
22 | token_h = token_h[flat == token, :]
23 |
24 | return token_h
25 |
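# --- Hedged usage sketch (annotation, not part of the original file) ---
# get_token picks, for every sequence in a batch, the hidden state at the
# positions whose input id equals `token` (typically the [CLS] id). Assumed
# toy call with a batch of 2 sequences of length 3 and hidden size 4:
def _demo_get_token():
    h = torch.randn(2, 3, 4)                      # [batch, seq_len, hidden]
    x = torch.tensor([[101, 7, 9], [101, 5, 2]])  # 101 assumed to be the [CLS] id
    cls_h = get_token(h, x, token=101)
    # cls_h.shape == (2, 4): one [CLS] embedding per sequence
    return cls_h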
26 |
27 | def get_head_tail_rep(h, head_tail_index):
28 | """
29 |
30 | :param h: torch.tensor [batch size, seq_len, feat_dim]
31 | :param head_tail_index: [batch size, term_num, 2]
32 | :return: torch.tensor [batch size, term_num, feat_dim * 2]
33 | """
34 | res = []
35 | batch_size = head_tail_index.size(0)
36 | term_num = head_tail_index.size(1)
37 | for b in range(batch_size):
38 | temp = []
39 | for t in range(term_num):
40 | temp.append(torch.index_select(h[b], 0, head_tail_index[b][t]).view(-1))
41 | res.append(torch.stack(temp, dim=0))
42 | res = torch.stack(res)
43 | return res
44 |
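# --- Hedged usage sketch (annotation, not part of the original file) ---
# get_head_tail_rep gathers, for every term, the hidden states at its head and
# tail token indices and concatenates them. Assumed toy call with one sample,
# 4 tokens, feature dim 3 and two terms with (head, tail) indices (0, 1), (2, 3):
def _demo_get_head_tail_rep():
    h = torch.randn(1, 4, 3)                             # [batch, seq_len, feat_dim]
    head_tail_index = torch.tensor([[[0, 1], [2, 3]]])   # [batch, term_num, 2]
    reps = get_head_tail_rep(h, head_tail_index)
    # reps.shape == (1, 2, 6): concat(head vector, tail vector) per term
    return reps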
45 |
46 | class SynFueBERT(BertPreTrainedModel):
47 | """ Span-based model to jointly extract terms and relations """
48 |
49 | VERSION = '1.2'
50 |
51 | def __init__(self, config: BertConfig, cls_token: int, relation_types: int, term_types: int,
52 | size_embedding: int, prop_drop: float, freeze_transformer: bool, args, max_pairs: int = 100,
53 | beta: float = 0.3, alpha: float = 1.0, sigma: float = 1.0):
54 | super(SynFueBERT, self).__init__(config)
55 |
56 | # BERT model
57 | self.bert = BertModel(config)
58 | self.bert_dropout = nn.Dropout(args.bert_dropout)
59 | self.SynFue = Encoder.SynFueEncoder(opt=args)
60 | self.cc = cross_attn.CA_module(config.hidden_size, config.hidden_size, 1, dropout=1.0)
61 |
62 | # layers
63 | self.rel_classifier = nn.Linear(config.hidden_size * 4 + size_embedding * 2, relation_types)
64 | self.rel_classifier3 = nn.Linear(config.hidden_size * 3 + size_embedding * 3, relation_types)
65 | self.term_linear = nn.Linear(config.hidden_size * 7 + size_embedding, config.hidden_size)
66 | self.term_classifier = nn.Linear(config.hidden_size, term_types)
67 | self.dep_linear = nn.Linear(config.hidden_size, relation_types)
68 | self.size_embeddings = nn.Embedding(100, size_embedding)
69 | self.dropout = nn.Dropout(prop_drop)
70 |
71 | self._cls_token = cls_token
72 | self._relation_types = relation_types
73 | self._term_types = term_types
74 | self._max_pairs = max_pairs
75 | self._beta = beta
76 | self._alpha = alpha
77 | self._sigma = sigma
78 |
79 | # weight initialization
80 | self.init_weights()
81 |
82 | if freeze_transformer:
83 | print("Freeze transformer weights")
84 |
85 | # freeze all transformer weights
86 | for param in self.bert.parameters():
87 | param.requires_grad = False
88 |
89 | # def init_weights(self):
90 | # for name, param in self.SynFue.named_parameters():
91 | # if name.find('weight') != -1:
92 | # torch.nn.init.normal_(param, 0, 1)
93 | # elif name.find('bias') != -1:
94 | # torch.nn.init.constant_(param, 0)
95 | # else:
96 | # torch.nn.init.xavier_normal_(param)
97 | # for name, param in self.cc.named_parameters():
98 | # if name.find('weight') != -1:
99 | # torch.nn.init.normal_(param, 0, 1)
100 | # elif name.find('bias') != -1:
101 | # torch.nn.init.constant_(param, 0)
102 | # else:
103 | # torch.nn.init.xavier_normal_(param)
104 |
105 | def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
106 | term_sizes: torch.tensor, term_spans: torch.tensor, term_types: torch.tensor,
107 | relations: torch.tensor, rel_masks: torch.tensor,
108 | simple_graph: torch.tensor, graph: torch.tensor,
109 | relations3: torch.tensor, rel_masks3: torch.tensor, pair_mask: torch.tensor,
110 | pos: torch.tensor = None, pieces2word: torch.tensor = None):
111 | """
112 |
113 | :param encodings: [B, L']
114 | :param context_masks: [B, L']
115 | :param term_masks: [B, max_term, L]
116 | :param term_sizes: [B, max_term]
117 | :param term_spans: [B, max_term, 2]
118 | :param term_types: [B, max_term]
119 | :param relations: [B, max_relation, 2]
120 | :param rel_masks: [B, max_relation, L]
121 | :param simple_graph: [B, L, L]
122 | :param graph: [B, L, L]
123 | :param relations3: [B, max_relation, 3]
124 | :param rel_masks3:[B, max_relation, max_relation, L]
125 | :param pair_mask: [B, max_relation]
126 | :param pos: [B, L]
127 | :param pieces2word: [B, L, L']
128 | :return: term_clf [B, max_term, term_types], rel_clf [B, max_relation, relation_types]
129 | """
130 | # get contextualized token embeddings from last transformer layer
131 | word_reps, cls_output = self._bert_encoder(input_ids=encodings, pieces2word=pieces2word)
132 | h, dep_output = self.SynFue(word_reps=word_reps, simple_graph=simple_graph, graph=graph, pos=pos)
133 |
134 | batch_size = encodings.shape[0]
135 |
136 | # classify terms
137 | size_embeddings = self.size_embeddings(term_sizes) # embed term candidate sizes
138 | term_clf, term_reps = self._classify_terms(h, term_masks, size_embeddings, cls_output)
139 |
140 | # classify relations
141 | h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) # B, max_rel, L, H
142 | rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
143 | self.rel_classifier.weight.device) # B, max_rel, relation_types
144 |
145 | # get span representation
146 | # dep_output = [B, L, L, H] -> [batch size, span num, span num, feat_dim]
147 | span_repr, mapping_list = self.get_span_repr(term_spans, term_types, dep_output)
148 | # apply cross-attention to compute the syntax-aware inter-span representation
149 | cross_attn_span = self.cc(span_repr) # batch size, span_num, span_num, feat_dim
150 |
151 | # obtain relation logits
152 | # chunk processing to reduce memory usage
153 | for i in range(0, relations.shape[1], self._max_pairs):
154 | # classify relation candidates
155 | chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(cross_attn_span,
156 | term_reps,
157 | size_embeddings,
158 | relations, rel_masks,
159 | h_large, i,
160 | relations3, rel_masks3,
161 | pair_mask, mapping_list)
162 | # apply sigmoid
163 | chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
164 | chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
165 | rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
166 |
167 | max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
168 | min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
169 | epsilon = torch.full_like(rel_clf, 1e-18)  # avoids division by zero
170 | rel_clf = torch.div(rel_clf - min_clf + epsilon, max_clf - min_clf + epsilon)  # min-max normalize scores to [0, 1]
171 |
172 | return term_clf, rel_clf
173 |
174 | def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor,
175 | term_sizes: torch.tensor, term_spans: torch.tensor, term_sample_masks: torch.tensor,
176 | simple_graph: torch.tensor, graph: torch.tensor, pos: torch.tensor = None,
177 | pieces2word: torch.tensor = None):
178 | # get contextualized token embeddings from last transformer layer
179 | # context_masks = context_masks.float()
180 | word_reps, cls_output = self._bert_encoder(input_ids=encodings, pieces2word=pieces2word)
181 | h, dep_output = self.SynFue(word_reps=word_reps, simple_graph=simple_graph, graph=graph, pos=pos)
182 |
183 | batch_size = encodings.shape[0]
184 | ctx_size = term_masks.shape[-1]
185 |
186 | # classify terms
187 | size_embeddings = self.size_embeddings(term_sizes) # embed term candidate sizes
188 | term_clf, term_reps = self._classify_terms(h, term_masks, size_embeddings, cls_output)
189 |
190 | # ignore term candidates that do not constitute an actual term for relations (based on classifier)
191 | relations, rel_masks, rel_sample_masks, relations3, rel_masks3, \
192 | rel_sample_masks3, pair_mask, span_repr, mapping_list = self._filter_spans(term_clf, term_spans,
193 | term_sample_masks,
194 | ctx_size, dep_output)
195 |
196 | rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
197 | # h = self.rel_bert(input_ids=encodings, attention_mask=context_masks)[0]
198 | h_large = h.unsqueeze(1).repeat(1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
199 | rel_clf = torch.zeros([batch_size, relations.shape[1], self._relation_types]).to(
200 | self.rel_classifier.weight.device)
201 |
202 | # get span representation
203 | cross_attn_span = self.cc(span_repr) # batch size, span_num, span_num, feat_dim
204 |
205 | # obtain relation logits
206 | # chunk processing to reduce memory usage
207 | for i in range(0, relations.shape[1], self._max_pairs):
208 | # classify relation candidates
209 | chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(cross_attn_span,
210 | term_reps,
211 | size_embeddings,
212 | relations, rel_masks,
213 | h_large, i,
214 | relations3, rel_masks3,
215 | pair_mask, mapping_list)
216 | # apply sigmoid
217 | chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
218 | chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
219 | rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
220 |
221 | max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
222 | min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
223 | epsilon = torch.full_like(rel_clf, 1e-18)  # avoids division by zero
224 | rel_clf = torch.div(rel_clf - min_clf + epsilon, max_clf - min_clf + epsilon)  # min-max normalize scores to [0, 1]
225 |
226 | rel_clf = rel_clf * rel_sample_masks # mask
227 |
228 | # apply softmax
229 | term_clf = torch.softmax(term_clf, dim=2)
230 |
231 | return term_clf, rel_clf, relations
232 |
233 | def _classify_terms(self, h, term_masks, size_embeddings, cls_output):
234 | """
235 |
236 | :param h: [b, L, H*2]
237 | :param term_masks: [B, max_term, L]
238 | :param size_embeddings: [b, max_term, size_embedding]
239 | :param cls_output: [B, H]
240 | :return:
241 | term_clf: [B, max_term, term_types]
242 | term_spans_pool: [B, max_term, H]
243 | """
244 | # max pool term candidate spans
245 | # m = (term_masks.unsqueeze(-1) == 0).float() * (-1e30)
246 | # term_spans_pool = m + h.unsqueeze(1).repeat(1, term_masks.shape[1], 1, 1)
247 | # term_spans_pool = term_spans_pool.max(dim=2)[0]
248 |
249 | min_value = torch.min(h).item()
250 | _h = h.unsqueeze(1).expand(-1, term_masks.size(1), -1, -1)
251 | _h = torch.masked_fill(_h, term_masks.eq(0).unsqueeze(-1), min_value)
252 | term_spans_pool, _ = torch.max(_h, dim=2)
253 |
254 | # get cls token as candidate context representation
255 | # term_ctx = get_token(h, encodings, self._cls_token)
256 |
257 | # get head and tail token representation
258 | m = term_masks.to(dtype=torch.long)
259 | k = torch.tensor(np.arange(0, term_masks.size(-1)), dtype=torch.long)
260 | k = k.unsqueeze(0).unsqueeze(0).repeat(term_masks.size(0), term_masks.size(1), 1).to(m.device)
261 | mk = torch.mul(m, k) # element-wise multiply
262 | mk_max = torch.argmax(mk, dim=-1, keepdim=True)
263 | mk_min = torch.argmin(mk, dim=-1, keepdim=True)
264 | mk = torch.cat([mk_min, mk_max], dim=-1)
265 | head_tail_rep = get_head_tail_rep(h, mk) # [batch size, term_num, bert_dim*2)
266 |
267 | # create candidate representations including context, max pooled span and size embedding, head and tail
268 | term_repr = torch.cat([cls_output.unsqueeze(1).repeat(1, term_spans_pool.shape[1], 1),
269 | term_spans_pool, size_embeddings, head_tail_rep], dim=2)
270 | term_repr = self.dropout(term_repr)
271 | term_repr = self.term_linear(term_repr)
272 |
273 | # classify term candidates
274 | term_clf = self.term_classifier(term_repr)
275 |
276 | return term_clf, term_repr
277 |
278 | def _classify_relations(self, spans_matrix, term_spans_repr, size_embeddings, relations, rel_masks,
279 | h, chunk_start, relations3, rel_masks3, pair_mask, rel_to_span):
280 | """
281 |
282 | :param spans_matrix: [B, max_term, max_term, H]
283 | :param term_spans_repr: [B, max_term, H]
284 | :param size_embeddings: [B, max_term, term_size]
285 | :param relations: [B, max_rel, 2]
286 | :param rel_masks: [B, max_rel, L]
287 | :param h: [B, max_rel, L, H*2]
288 | :param chunk_start:
289 | :param relations3: [B, max_rel, 3]
290 | :param rel_masks3: [B, max_rel, max_rel, L]
291 | :param pair_mask:
292 | :param rel_to_span:
293 | :return:
294 | """
295 | batch_size = relations.shape[0]
296 | feat_dim = spans_matrix.size(-1)
297 |
298 | # create chunks if necessary
299 | if relations.shape[1] > self._max_pairs:
300 | relations = relations[:, chunk_start:chunk_start + self._max_pairs]
301 | rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
302 | h = h[:, :relations.shape[1], :]
303 |
304 | def get_span_idx(mapping_list, idx1, idx2):
305 | for x in mapping_list:
306 | if idx1 == x[0][0] and idx2 == x[0][1]:
307 | return x[1][0], x[1][1]
308 |
309 | batch_dep_score = []
310 | for i in range(batch_size):
311 | _rel = relations[i]
312 | dep_score_list = []
313 | r_2_s = rel_to_span[i]
314 | for r in _rel:
315 | i1, i2 = r[0].item(), r[1].item()
316 | idx1, idx2 = get_span_idx(r_2_s, i1, i2)
317 | try:
318 | feat = spans_matrix[i][idx1][idx2]
319 | except Exception:
320 | print('Out of boundary', spans_matrix.size(), i, i1, i2)
321 | feat = torch.zeros(feat_dim, device=spans_matrix.device)  # keep the fallback tensor on the model device
322 | dep_score = self.dep_linear(feat).item()
323 | dep_score_list.append([dep_score])
324 | batch_dep_score.append(dep_score_list)
325 |
326 | batch_dep_score = torch.sigmoid(
327 | torch.tensor(batch_dep_score).to(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
328 |
329 | # get pairs of term candidate representations
330 | term_pairs = util.batch_index(term_spans_repr, relations)
331 | term_pairs = term_pairs.view(batch_size, term_pairs.shape[1], -1)
332 |
333 | # get corresponding size embeddings
334 | size_pair_embeddings = util.batch_index(size_embeddings, relations)
335 | size_pair_embeddings = size_pair_embeddings.view(batch_size, size_pair_embeddings.shape[1], -1)
336 |
337 | # relation context (context between term candidate pair)
338 | # mask tokens that lie outside the context between the term candidate pair
339 | m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
340 | rel_ctx = m + h
341 | # max pooling
342 | rel_ctx = rel_ctx.max(dim=2)[0]
343 | # if no tokens lie between the two candidates (adjacent or overlapping spans), zero the context vector
344 | rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0
345 |
346 | # create relation candidate representations including context, max pooled term candidate pairs
347 | # and corresponding size embeddings
348 | rel_repr = torch.cat([rel_ctx, term_pairs, size_pair_embeddings], dim=2)
349 | rel_repr = self.dropout(rel_repr)
350 | # classify relation candidates
351 | chunk_rel_logits = self.rel_classifier(rel_repr)
352 |
353 | if relations3.shape[1] > self._max_pairs:
354 | relations3 = relations3[:, chunk_start:chunk_start + self._max_pairs]
355 | # rel_masks3 = rel_masks3[:, chunk_start:chunk_start + self._max_pairs]
356 |
357 | p_num = relations3.size(1)
358 | p_tris = relations3.size(2)
359 |
360 | relations3 = relations3.view(batch_size, -1, 3)
361 |
362 | # get triplet candidate representations
363 | term_pairs3 = util.batch_index(term_spans_repr, relations3)
364 | term_pairs3 = term_pairs3.view(batch_size, term_pairs3.shape[1], -1)
365 |
366 | size_pair_embeddings3 = util.batch_index(size_embeddings, relations3)
367 | size_pair_embeddings3 = size_pair_embeddings3.view(batch_size, size_pair_embeddings3.shape[1], -1)
368 |
369 | rel_repr = torch.cat([term_pairs3, size_pair_embeddings3], dim=2)
370 | rel_repr = self.dropout(rel_repr)
371 | # classify relation candidates
372 | chunk_rel_logits3 = self.rel_classifier3(rel_repr)
373 |
374 | chunk_rel_clf3 = chunk_rel_logits3.view(batch_size, p_num, p_tris, -1)
375 | chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)  # per-triplet scores
376 |
377 | chunk_rel_clf3 = torch.sum(chunk_rel_clf3, dim=2)  # aggregate over the triplets of each pair
378 | chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)  # squash the aggregated score back to (0, 1)
379 |
380 | return chunk_rel_logits, chunk_rel_clf3, batch_dep_score
381 |
382 | def _filter_spans(self, term_clf, term_spans, term_sample_masks, ctx_size, token_repr):
383 | """
384 | According to the results of term type detection, we first filter out invalid terms, i.e., keep only aspect
385 | terms and opinion terms, and then construct term pairs and term triplets for the subsequent relation detection.
386 | :param term_clf: [B, max_term, term_types]
387 | :param term_spans: [B, max_term, 2]
388 | :param term_sample_masks:
389 | :param ctx_size: L
390 | :param token_repr: [B, L, L, H]
391 | :return:
392 | """
393 | batch_size = term_clf.shape[0]
394 | feat_dim = token_repr.size(-1)
395 | term_logits_max = term_clf.argmax(dim=-1) * term_sample_masks.long() # get term type (including none)
396 | batch_relations = []
397 | batch_rel_masks = []
398 | batch_rel_sample_masks = []
399 |
400 | batch_relations3 = []
401 | batch_rel_masks3 = []
402 | batch_rel_sample_masks3 = []
403 | batch_pair_mask = []
404 |
405 | batch_span_repr = []
406 | batch_rel_to_span = []
407 |
408 | for i in range(batch_size):
409 | rels = []
410 | rel_masks = []
411 | sample_masks = []
412 | rels3 = []
413 | rel_masks3 = []
414 | sample_masks3 = []
415 |
416 | span_repr = []
417 | rel_to_span = []
418 |
419 | # get spans classified as terms
420 | non_zero_indices = (term_logits_max[i] != 0).nonzero().view(-1)
421 | non_zero_spans = term_spans[i][non_zero_indices].tolist()
422 | non_zero_indices = non_zero_indices.tolist()
423 |
424 | # create relations and masks
425 | pair_mask = []
426 | for idx1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
427 | temp = []
428 | for idx2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)):
429 | if i1 != i2:
430 | rels.append((i1, i2))
431 | rel_masks.append(sampling.create_rel_mask(s1, s2, ctx_size))
432 | sample_masks.append(1)
433 | p_rels3 = []
434 | p_masks3 = []
435 | for i3, s3 in zip(non_zero_indices, non_zero_spans):
436 | if i1 != i2 and i1 != i3 and i2 != i3:
437 | p_rels3.append((i1, i2, i3))
438 | p_masks3.append(sampling.create_rel_mask3(s1, s2, s3, ctx_size))
439 | sample_masks3.append(1)
440 | if len(p_rels3) > 0:
441 | rels3.append(p_rels3)
442 | rel_masks3.append(p_masks3)
443 | pair_mask.append(1)
444 | else:
445 | rels3.append([(i1, i2, 0)])
446 | rel_masks3.append([sampling.create_rel_mask3(s1, s2, (0, 0), ctx_size)])
447 | pair_mask.append(0)
448 | rel_to_span.append([[i1, i2], [idx1, idx2]])
449 | feat = \
450 | torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view(-1, feat_dim),
451 | dim=0)[0]
452 | temp.append(feat)
453 | span_repr.append(temp)
454 |
455 | if not rels:
456 | # case: fewer than two spans classified as terms
457 | batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long))
458 | batch_rel_masks.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
459 | batch_rel_sample_masks.append(torch.tensor([0], dtype=torch.bool))
460 | batch_span_repr.append(torch.tensor([[[0] * feat_dim]], dtype=torch.float))
461 | batch_rel_to_span.append([[[0, 0], [0, 0]]])
462 | else:
463 | # case: at least two spans classified as terms
464 | batch_relations.append(torch.tensor(rels, dtype=torch.long))
465 | batch_rel_masks.append(torch.stack(rel_masks))
466 | batch_rel_sample_masks.append(torch.tensor(sample_masks, dtype=torch.bool))
467 | batch_span_repr.append(torch.stack([torch.stack(x) for x in span_repr]))
468 | batch_rel_to_span.append(rel_to_span)
469 |
470 | if not rels3:
471 | batch_relations3.append(torch.tensor([[[0, 0, 0]]], dtype=torch.long))
472 | batch_rel_masks3.append(torch.tensor([[0] * ctx_size], dtype=torch.bool))
473 | batch_rel_sample_masks3.append(torch.tensor([0], dtype=torch.bool))
474 | batch_pair_mask.append(torch.tensor([0], dtype=torch.bool))
475 |
476 | else:
477 | max_tri = max([len(x) for x in rels3])
478 | # print(max_tri)
479 | for idx, r in enumerate(rels3):
480 | r_len = len(r)
481 | if r_len < max_tri:
482 | rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len))
483 | rel_masks3[idx].extend(
484 | [rel_masks3[idx][0]] * (max_tri - r_len))
485 | batch_relations3.append(torch.tensor(rels3, dtype=torch.long))
486 | batch_rel_masks3.append(torch.stack([torch.stack(x) for x in rel_masks3]))
487 | batch_rel_sample_masks3.append(torch.tensor(sample_masks3, dtype=torch.bool))
488 | batch_pair_mask.append(torch.tensor(pair_mask, dtype=torch.bool))
489 |
490 | # stack
491 | device = self.rel_classifier.weight.device
492 | batch_relations = util.padded_stack(batch_relations).to(device)
493 | batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
494 | batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(device)
495 | batch_span_repr = util.padded_stack(batch_span_repr).to(device)
496 |
497 | batch_relations3 = util.padded_stack(batch_relations3).to(device)
498 | batch_rel_masks3 = util.padded_stack(batch_rel_masks3).to(device)
499 | batch_rel_sample_masks3 = util.padded_stack(batch_rel_sample_masks3).to(device)
500 | batch_pair_mask = util.padded_stack(batch_pair_mask).to(device)
501 |
502 | return batch_relations, batch_rel_masks, batch_rel_sample_masks, \
503 | batch_relations3, batch_rel_masks3, batch_rel_sample_masks3, \
504 | batch_pair_mask, batch_span_repr, batch_rel_to_span
505 |
506 | def get_span_repr(self, term_spans, term_types, token_repr):
507 | """
508 | Build a span-pair representation matrix over all positive (non-none) term spans.
509 | :param term_spans: [B, span_num, 2]
510 | :param term_types: [B, span_num]
511 | :param token_repr: [B, L, L, feat_dim]
512 | :return:
513 | batch_span_repr: [B, span_num, span_num, feat_dim]
514 | batch_mapping_list: a list of length B; it stores the correspondence between the index pair (x1, x2) in the
515 | span matrix and the index pair (i1, i2) in the positive term span list.
516 | """
517 | batch_size = term_spans.size(0)
518 | feat_dim = token_repr.size(-1)
519 | batch_span_repr = []
520 | batch_mapping_list = []
521 | for i in range(batch_size):
522 | span_repr = []
523 | mapping_list = []
524 | # get target spans as aspect term or opinion term
525 | non_zero_indices = (term_types[i] != 0).nonzero().view(-1)
526 | non_zero_spans = term_spans[i][non_zero_indices].tolist()
527 | non_zero_indices = non_zero_indices.tolist()
528 | for x1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
529 | temp = []
530 | for x2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)):
531 | feat = \
532 | torch.max(token_repr[i, s1[0]: s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view(-1, feat_dim),
533 | dim=0)[0]
534 | temp.append(feat)
535 | mapping_list.append([[i1, i2], [x1, x2]])
536 |
537 | span_repr.append(torch.stack(temp))
538 | batch_span_repr.append(torch.stack(span_repr))
539 | batch_mapping_list.append(mapping_list)
540 |
541 | device = self.rel_classifier.weight.device
542 | batch_span_repr = util.padded_stack(batch_span_repr).to(device)
543 |
544 | return batch_span_repr, batch_mapping_list
545 |
546 | def _bert_encoder(self, input_ids, pieces2word):
547 | """
548 | Encode word pieces with BERT and max-pool them back to word-level representations.
549 | :param input_ids: [B, L'], where L' is the word-piece length and generally differs from the word-level length L
550 | :param pieces2word: [B, L, L']
551 | :return:
552 | word_reps: [B, L, H]
553 | pooler_output: [B, H]
554 | """
555 | # sequence_output, pooled_output = self.bert(input_ids)
556 | bert_output = self.bert(input_ids=input_ids, attention_mask=input_ids.ne(0).float())
557 | sequence_output, pooler_output = bert_output[0], bert_output[1]
558 | bert_embs = self.bert_dropout(sequence_output)
559 |
560 | length = pieces2word.size(1)
561 | min_value = torch.min(bert_embs).item()
562 |
563 | # Max pooling word representations from pieces
564 | _bert_embs = bert_embs.unsqueeze(1).expand(-1, length, -1, -1)
565 | _bert_embs = torch.masked_fill(_bert_embs, pieces2word.eq(0).unsqueeze(-1), min_value)
566 | word_reps, _ = torch.max(_bert_embs, dim=2)
567 | return word_reps, pooler_output
568 |
569 | def forward(self, *args, evaluate=False, **kwargs):
570 | if not evaluate:
571 | return self._forward_train(*args, **kwargs)
572 | else:
573 | return self._forward_eval(*args, **kwargs)
574 |
575 |
576 | # Model access
577 |
578 | _MODELS = {
579 | 'synfue': SynFueBERT,
580 | }
581 |
582 |
583 | def get_model(name):
584 | return _MODELS[name]
585 |
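# Example (illustrative): get_model('synfue') returns the SynFueBERT class registered in _MODELS
# above; an unknown model name raises a KeyError.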
--------------------------------------------------------------------------------