├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── CasRel_fastNLP.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── config.py
├── data
│   ├── .DS_Store
│   ├── NYT
│   │   ├── README.md
│   │   ├── build_data.py
│   │   ├── raw_NYT
│   │   │   ├── README.md
│   │   │   └── generate.py
│   │   ├── test_split_by_num
│   │   │   ├── README.md
│   │   │   └── split_by_num.py
│   │   └── test_split_by_type
│   │       ├── README.md
│   │       └── generate_triples.py
│   └── WebNLG
│       ├── README.md
│       ├── build_data.py
│       ├── raw_WebNLG
│       │   ├── README.md
│       │   └── generate.py
│       ├── test_split_by_num
│       │   ├── README.md
│       │   └── split_by_num.py
│       └── test_split_by_type
│           ├── README.md
│           └── generate_triples.py
├── data_loader.py
├── model.py
├── pretrained_bert_models
│   └── .DS_Store
├── run.py
├── saved_logs
│   ├── .DS_Store
│   └── NYT
│       └── .DS_Store
├── saved_weights
│   ├── .DS_Store
│   └── NYT
│       └── .DS_Store
├── test.py
├── train.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/CasRel_fastNLP.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CasRel_fastNLP
2 | This repo is a [fastNLP](https://fastnlp.readthedocs.io/zh/latest/) reimplementation of the ACL 2020 paper [**"A Novel Cascade Binary Tagging Framework for Relational Triple Extraction"**](https://www.aclweb.org/anthology/2020.acl-main.136.pdf). The [original code](https://github.com/weizhepei/CasRel) was written in Keras.
3 |
4 | ## Requirements
5 |
6 | - Python 3.8
7 | - Pytorch 1.7
8 | - fastNLP 0.6.0
9 | - keras-bert 0.86.0
10 | - numpy 1.19.1
11 | - transformers 4.0.0
12 |
13 | Other dependencies are described in the [fastNLP docs](https://fastnlp.readthedocs.io/zh/latest/user/installation.html).
14 |
15 | ## Datasets
16 |
17 | - [NYT](https://github.com/weizhepei/CasRel/tree/master/data/NYT)
18 | - [WebNLG](https://github.com/weizhepei/CasRel/tree/master/data/WebNLG)
19 |
20 | ## Usage
21 |
22 | 1. **Build the dataset in the form of triples** (DONE)
23 | Take the NYT dataset for example:
24 |
25 | a) Switch to the corresponding directory and download the dataset
26 |
27 | ```
28 | cd CasRel_fastNLP/data/NYT/raw_NYT
29 | ```
30 |
31 | b) Follow the instructions in that directory's README, then run
32 |
33 | ```
34 | python generate.py
35 | ```
36 |
37 | c) Finally, build the dataset in the form of triples
38 |
39 | ```
40 | cd CasRel_fastNLP/data/NYT
41 | python build_data.py
42 | ```
43 |
44 | This will convert the preprocessed dataset into the triple format used by our model and generate `train_triples.json`, `dev_triples.json`, `test_triples.json` and `rel2id.json`. Then split the test set by overlapping type and by triple number for in-depth analysis of the different overlapping-triple scenarios (run `python generate_triples.py` and `python split_by_num.py` under the corresponding folders). The format of the triple files is sketched after the statistics below.
45 |
46 | - NYT:
47 | - Train: 56195, dev: 4999, test: 5000
48 | - normal : EPO : SEO = 3266 : 978 : 1297
49 | - WebNLG:
50 | - Train: 5019, dev: 703, test: 500
51 | - normal : EPO : SEO = 182 : 16 : 318
52 |
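For reference, each `*_triples.json` file written by `build_data.py` is a JSON list of examples, each holding a `text` string and a `triple_list` of `[subject, relation, object]` entries. A minimal sketch for inspecting one example (the path assumes the NYT build above):

```
import json

with open('data/NYT/train_triples.json') as f:
    examples = json.load(f)

print(examples[0]['text'])         # the raw sentence
print(examples[0]['triple_list'])  # e.g. [["subject mention", "relation label", "object mention"], ...]
```
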
53 | 2. **Specify the experimental settings** (DONE)
54 |
55 | By default, we use the following settings in train.py:
56 |
57 | ```
58 | {
59 | "model_name": "CasRel",
60 | "dataset": "NYT",
61 | "bert_model_name": "bert-base-cased",
62 | "lr": 1e-5,
63 | "nulti-gpu": False,
64 | "batch_size": 6,
65 | "max_epoch": 100,
66 | "test_epoch": 5,
67 | "max_len": 100,
68 | "period": 50,
69 | }
70 | ```
71 |
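The settings above correspond to the command-line flags defined via argparse in `train.py`, so any of them can be overridden per run. For example (the values shown simply match the settings listed above):

```
python train.py --dataset NYT --bert_model_name bert-base-cased --lr 1e-5 --batch_size 6 --max_epoch 100 --test_epoch 5 --max_len 100 --period 50
```
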
72 | 3. **Train and select the model**
73 | - [ ] Currently trained on CPU; plan to move to GPU later.
74 | - [ ] Currently using a custom training loop; plan to switch to fastNLP's Trainer class.
75 |
76 | 4. **Evaluate on the test set**
77 |
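For evaluation, `test.py` rebuilds the model from the same settings, loads the saved weights named by `config.py`, and reports F1/precision/recall on the test split. A typical invocation (flag names per `test.py`'s argparse; the run settings must match the ones used for training so the weight file name resolves) would be:

```
python test.py --dataset NYT --lr 1e-5 --batch_size 6
```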
78 |
79 |
80 | ## Results
81 |
82 |
83 |
84 | ## References
85 |
86 | [1] https://github.com/weizhepei/CasRel
87 |
88 | [2] https://github.com/longlongman/CasRel-pytorch-reimplement
89 |
90 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | class Config(object):
2 | def __init__(self, args):
3 | self.args = args
4 |
5 | # dataset
6 | self.dataset = args.dataset
7 |
8 | # train hyper parameters
9 | self.multi_gpu = args.multi_gpu
10 | self.learning_rate = args.lr
11 | self.batch_size = args.batch_size
12 | self.max_epoch = args.max_epoch
13 | self.max_len = args.max_len
14 |
15 | # path and name
16 | self.bert_model_name = args.bert_model_name
17 | self.train_path = 'data/' + self.dataset + '/train_triples.json'
18 | self.dev_path = 'data/' + self.dataset + '/dev_triples.json'
19 | self.test_path = 'data/' + self.dataset + '/test_triples.json' # overall test
20 | self.rel_dict_path = 'data/' + self.dataset + '/rel2id.json'
21 | self.save_weights_dir = 'saved_weights/' + self.dataset + '/'
22 | self.save_logs_dir = 'saved_logs/' + self.dataset + '/'
23 | self.result_dir = 'results/' + self.dataset
24 | self.weights_save_name = args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(
25 | self.learning_rate) + "_BS_" + str(self.batch_size)
26 | self.log_save_name = 'LOG_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(
27 | self.learning_rate) + "_BS_" + str(self.batch_size)
28 | self.result_save_name = 'RESULT_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(
29 | self.learning_rate) + "_BS_" + str(self.batch_size) + ".json"
30 |
31 | # log setting
32 | self.period = args.period
33 | self.test_epoch = args.test_epoch
34 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/data/.DS_Store
--------------------------------------------------------------------------------
/data/NYT/README.md:
--------------------------------------------------------------------------------
1 | First, download and process the data in the raw_NYT/ directory.
2 | 
3 | Then run build_data.py to generate the triple files.
4 |
--------------------------------------------------------------------------------
/data/NYT/build_data.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 | rel_set = set()
9 |
10 |
11 | train_data = []
12 | dev_data = []
13 | test_data = []
14 |
15 | with open('train.json') as f:
16 | for l in tqdm(f):
17 | a = json.loads(l)
18 | if not a['relationMentions']:
19 | continue
20 | line = {
21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
23 | }
24 | if not line['triple_list']:
25 | continue
26 | train_data.append(line)
27 | for rm in a['relationMentions']:
28 | if rm['label'] != 'None':
29 | rel_set.add(rm['label'])
30 |
31 |
32 | with open('dev.json') as f:
33 | for l in tqdm(f):
34 | a = json.loads(l)
35 | if not a['relationMentions']:
36 | continue
37 | line = {
38 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
39 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
40 | }
41 | if not line['triple_list']:
42 | continue
43 | dev_data.append(line)
44 | for rm in a['relationMentions']:
45 | if rm['label'] != 'None':
46 | rel_set.add(rm['label'])
47 |
48 | cnt = 0
49 | with open('test.json') as f:
50 | for l in tqdm(f):
51 | a = json.loads(l)
52 | if not a['relationMentions']:
53 | continue
54 | line = {
55 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
56 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
57 | }
58 | if not line['triple_list']:
59 | continue
60 | cnt += len(line['triple_list'])
61 | test_data.append(line)
62 |
63 | print(f'test triples:{cnt}')
64 |
65 |
66 | id2predicate = {i:j for i,j in enumerate(sorted(rel_set))}
67 | predicate2id = {j:i for i,j in id2predicate.items()}
68 |
69 |
70 | with codecs.open('rel2id.json', 'w', encoding='utf-8') as f:
71 | json.dump([id2predicate, predicate2id], f, indent=4, ensure_ascii=False)
72 |
73 |
74 | with codecs.open('train_triples.json', 'w', encoding='utf-8') as f:
75 | json.dump(train_data, f, indent=4, ensure_ascii=False)
76 |
77 |
78 | with codecs.open('dev_triples.json', 'w', encoding='utf-8') as f:
79 | json.dump(dev_data, f, indent=4, ensure_ascii=False)
80 |
81 |
82 | with codecs.open('test_triples.json', 'w', encoding='utf-8') as f:
83 | json.dump(test_data, f, indent=4, ensure_ascii=False)
84 |
--------------------------------------------------------------------------------
/data/NYT/raw_NYT/README.md:
--------------------------------------------------------------------------------
1 | Download CopyR's NYT data at https://drive.google.com/open?id=10f24s9gM7NdyO3z5OqQxJgYud4NnCJg3
2 |
3 | Run generate.py to convert the numerical mentions into string mentions. Move the converted files to NYT/
4 |
5 | Also, the test data should be split by overlapping types, i.e., Normal, SEO, and EPO. Move the split files to NYT/test_split_by_type/
6 |
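generate.py only processes the split named in its `__main__` block (currently valid.json), so it has to be run once per split. A hypothetical driver that reuses its load_data for all three raw files (assuming they are named train.json, valid.json and test.json) could look like:

```
import json
from generate import load_data

with open('relations2id.json') as f1, open('words2id.json') as f2:
    rel_dict = {v: k for k, v in json.load(f1).items()}
    word_dict = {v: k for k, v in json.load(f2).items()}

for split in ('train', 'valid', 'test'):
    load_data('%s.json' % split, word_dict, rel_dict,
              'new_%s.json' % split, 'new_%s_normal.json' % split,
              'new_%s_epo.json' % split, 'new_%s_seo.json' % split)
```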
--------------------------------------------------------------------------------
/data/NYT/raw_NYT/generate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
4 | def is_normal_triple(triples, is_relation_first=False):
5 | entities = set()
6 | for i, e in enumerate(triples):
7 | key = 0 if is_relation_first else 2
8 | if i % 3 != key:
9 | entities.add(e)
10 | return len(entities) == 2 * int(len(triples) / 3)
11 |
12 | def is_multi_label(triples, is_relation_first=False):
13 | if is_normal_triple(triples, is_relation_first):
14 | return False
15 | if is_relation_first:
16 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))]
17 | else:
18 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))]
19 | # if is multi label, then, at least one entity pair appeared more than once
20 | return len(entity_pair) != len(set(entity_pair))
21 |
22 | def is_over_lapping(triples, is_relation_first=False):
23 | if is_normal_triple(triples, is_relation_first):
24 | return False
25 | if is_relation_first:
26 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))]
27 | else:
28 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))]
29 | # remove the same entity_pair, then, if one entity appear more than once, it's overlapping
30 | entity_pair = set(entity_pair)
31 | entities = []
32 | for pair in entity_pair:
33 | entities.extend(pair)
34 | entities = set(entities)
35 | return len(entities) != 2 * len(entity_pair)
36 |
37 | def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file):
38 | with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, open(epo_file, 'w') as f4, open(seo_file, 'w') as f5:
39 | cnt_normal = 0
40 | cnt_epo = 0
41 | cnt_seo = 0
42 | lines = f1.readlines()
43 | for line in lines:
44 | line = json.loads(line)
45 | print(len(line))
46 | lengths, sents, spos = line[0], line[1], line[2]
47 | print(len(spos))
48 | print(len(sents))
49 | for i in range(len(sents)):
50 | new_line = dict()
51 | #print(sents[i])
52 | #print(spos[i])
53 | tokens = [word_dict[i] for i in sents[i]]
54 | sent = ' '.join(tokens)
55 | new_line['sentText'] = sent
56 | triples = np.reshape(spos[i], (-1,3))
57 | relationMentions = []
58 | for triple in triples:
59 | rel = dict()
60 | rel['em1Text'] = tokens[triple[0]]
61 | rel['em2Text'] = tokens[triple[1]]
62 | rel['label'] = rel_dict[triple[2]]
63 | relationMentions.append(rel)
64 | new_line['relationMentions'] = relationMentions
65 | f2.write(json.dumps(new_line) + '\n')
66 | if is_normal_triple(spos[i]):
67 | f3.write(json.dumps(new_line) + '\n')
68 | if is_multi_label(spos[i]):
69 | f4.write(json.dumps(new_line) + '\n')
70 | if is_over_lapping(spos[i]):
71 | f5.write(json.dumps(new_line) + '\n')
72 |
73 | if __name__ == '__main__':
74 | file_name = 'valid.json'
75 | output = 'new_valid.json'
76 | output_normal = 'new_valid_normal.json'
77 | output_epo = 'new_valid_epo.json'
78 | output_seo = 'new_valid_seo.json'
79 | with open('relations2id.json', 'r') as f1, open('words2id.json', 'r') as f2:
80 | rel2id = json.load(f1)
81 | words2id = json.load(f2)
82 | rel_dict = {j:i for i,j in rel2id.items()}
83 | word_dict = {j:i for i,j in words2id.items()}
84 | load_data(file_name, word_dict, rel_dict, output, output_normal, output_epo, output_seo)
85 |
--------------------------------------------------------------------------------
/data/NYT/test_split_by_num/README.md:
--------------------------------------------------------------------------------
1 | Copy ../test.json here, and run split_by_num.py to split the test data by the number of triples per sentence.
2 |
--------------------------------------------------------------------------------
/data/NYT/test_split_by_num/split_by_num.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 |
9 | test_1 = []
10 | test_2 = []
11 | test_3 = []
12 | test_4 = []
13 | test_other = []
14 |
15 | with open('test.json') as f:
16 | for l in tqdm(f):
17 | a = json.loads(l)
18 | if not a['relationMentions']:
19 | continue
20 | line = {
21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
23 | }
24 | if not line['triple_list']:
25 | continue
26 | spo_num = len(line['triple_list'])
27 | if spo_num == 1:
28 | test_1.append(line)
29 | elif spo_num == 2:
30 | test_2.append(line)
31 | elif spo_num == 3:
32 | test_3.append(line)
33 | elif spo_num == 4:
34 | test_4.append(line)
35 | else:
36 | test_other.append(line)
37 |
38 |
39 | with codecs.open('test_triples_1.json', 'w', encoding='utf-8') as f:
40 | json.dump(test_1, f, indent=4, ensure_ascii=False)
41 |
42 | with codecs.open('test_triples_2.json', 'w', encoding='utf-8') as f:
43 | json.dump(test_2, f, indent=4, ensure_ascii=False)
44 |
45 | with codecs.open('test_triples_3.json', 'w', encoding='utf-8') as f:
46 | json.dump(test_3, f, indent=4, ensure_ascii=False)
47 |
48 | with codecs.open('test_triples_4.json', 'w', encoding='utf-8') as f:
49 | json.dump(test_4, f, indent=4, ensure_ascii=False)
50 |
51 | with codecs.open('test_triples_5.json', 'w', encoding='utf-8') as f:
52 | json.dump(test_other, f, indent=4, ensure_ascii=False)
53 |
--------------------------------------------------------------------------------
/data/NYT/test_split_by_type/README.md:
--------------------------------------------------------------------------------
1 | Run generate_triples.py to get triple files.
2 |
--------------------------------------------------------------------------------
/data/NYT/test_split_by_type/generate_triples.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 |
9 | test_normal = []
10 | test_epo = []
11 | test_seo = []
12 |
13 | with open('test_normal.json') as f:
14 | for l in tqdm(f):
15 | a = json.loads(l)
16 | if not a['relationMentions']:
17 | continue
18 | line = {
19 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
20 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
21 | }
22 | if not line['triple_list']:
23 | continue
24 | spo_num = len(line['triple_list'])
25 | test_normal.append(line)
26 |
27 | with open('test_epo.json') as f:
28 | for l in tqdm(f):
29 | a = json.loads(l)
30 | if not a['relationMentions']:
31 | continue
32 | line = {
33 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
34 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
35 | }
36 | if not line['triple_list']:
37 | continue
38 | spo_num = len(line['triple_list'])
39 | test_epo.append(line)
40 |
41 | with open('test_seo.json') as f:
42 | for l in tqdm(f):
43 | a = json.loads(l)
44 | if not a['relationMentions']:
45 | continue
46 | line = {
47 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
48 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
49 | }
50 | if not line['triple_list']:
51 | continue
52 | spo_num = len(line['triple_list'])
53 | test_seo.append(line)
54 |
55 | with codecs.open('test_triples_normal.json', 'w', encoding='utf-8') as f:
56 | json.dump(test_normal, f, indent=4, ensure_ascii=False)
57 |
58 | with codecs.open('test_triples_epo.json', 'w', encoding='utf-8') as f:
59 | json.dump(test_epo, f, indent=4, ensure_ascii=False)
60 |
61 | with codecs.open('test_triples_seo.json', 'w', encoding='utf-8') as f:
62 | json.dump(test_seo, f, indent=4, ensure_ascii=False)
63 |
64 |
--------------------------------------------------------------------------------
/data/WebNLG/README.md:
--------------------------------------------------------------------------------
1 | First, download and process the raw WebNLG data in dir raw_WebNLG/
2 |
3 | Then run build_data.py to get triple files.
4 |
5 | For WebNLG, please train the model for at least 10 epochs.
6 |
7 |
--------------------------------------------------------------------------------
/data/WebNLG/build_data.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 | rel_set = set()
9 |
10 |
11 | train_data = []
12 | dev_data = []
13 | test_data = []
14 |
15 | with open('train.json') as f:
16 | for l in tqdm(f):
17 | a = json.loads(l)
18 | if not a['relationMentions']:
19 | continue
20 | line = {
21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
23 | }
24 | if not line['triple_list']:
25 | continue
26 | train_data.append(line)
27 | for rm in a['relationMentions']:
28 | if rm['label'] != 'None':
29 | rel_set.add(rm['label'])
30 |
31 |
32 | with open('dev.json') as f:
33 | for l in tqdm(f):
34 | a = json.loads(l)
35 | if not a['relationMentions']:
36 | continue
37 | line = {
38 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
39 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
40 | }
41 | if not line['triple_list']:
42 | continue
43 | dev_data.append(line)
44 | for rm in a['relationMentions']:
45 | if rm['label'] != 'None':
46 | rel_set.add(rm['label'])
47 |
48 |
49 | with open('test.json') as f:
50 | for l in tqdm(f):
51 | a = json.loads(l)
52 | if not a['relationMentions']:
53 | continue
54 | line = {
55 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
56 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
57 | }
58 | if not line['triple_list']:
59 | continue
60 | test_data.append(line)
61 |
62 |
63 |
64 | id2rel = {i:j for i,j in enumerate(sorted(rel_set))}
65 | rel2id = {j:i for i,j in id2rel.items()}
66 |
67 |
68 | with codecs.open('rel2id.json', 'w', encoding='utf-8') as f:
69 | json.dump([id2rel, rel2id], f, indent=4, ensure_ascii=False)
70 |
71 |
72 | with codecs.open('train_triples.json', 'w', encoding='utf-8') as f:
73 | json.dump(train_data, f, indent=4, ensure_ascii=False)
74 |
75 |
76 | with codecs.open('dev_triples.json', 'w', encoding='utf-8') as f:
77 | json.dump(dev_data, f, indent=4, ensure_ascii=False)
78 |
79 |
80 | with codecs.open('test_triples.json', 'w', encoding='utf-8') as f:
81 | json.dump(test_data, f, indent=4, ensure_ascii=False)
82 |
--------------------------------------------------------------------------------
/data/WebNLG/raw_WebNLG/README.md:
--------------------------------------------------------------------------------
1 | Download CopyR's WebNLG dataset at https://drive.google.com/open?id=1zISxYa-8ROe2Zv8iRc82jY9QsQrfY1Vj
2 | 
3 | Unzip it; all we need is webnlg/pre_processed_data/*
4 | 
5 | Run generate.py to transform the numerical mentions into string mentions. Move the transformed files to ../
6 |
7 | dev.json -> new_dev.json (then rename it test.json)
8 | train.json -> new_train.json (then rename it train.json)
9 | valid.json -> new_valid.json (then rename it dev.json)
10 |
11 | Also, the test data will be split by sentence types, i.e., Normal, SEO, and EPO. Then move the split files to ../test_split_by_type/
12 |
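A small helper for the rename-and-move step above (a sketch, assuming generate.py has already been run for each split in this directory; files are copied rather than moved so the intermediate outputs stay put):

```
import shutil

for src, dst in [('new_dev.json', '../test.json'),
                 ('new_train.json', '../train.json'),
                 ('new_valid.json', '../dev.json')]:
    shutil.copy(src, dst)
```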
13 |
--------------------------------------------------------------------------------
/data/WebNLG/raw_WebNLG/generate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
4 | def is_normal_triple(triples, is_relation_first=False):
5 | entities = set()
6 | for i, e in enumerate(triples):
7 | key = 0 if is_relation_first else 2
8 | if i % 3 != key:
9 | entities.add(e)
10 | return len(entities) == 2 * int(len(triples) / 3)
11 |
12 | def is_multi_label(triples, is_relation_first=False):
13 | if is_normal_triple(triples, is_relation_first):
14 | return False
15 | if is_relation_first:
16 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))]
17 | else:
18 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))]
19 | # if is multi label, then, at least one entity pair appeared more than once
20 | return len(entity_pair) != len(set(entity_pair))
21 |
22 | def is_over_lapping(triples, is_relation_first=False):
23 | if is_normal_triple(triples, is_relation_first):
24 | return False
25 | if is_relation_first:
26 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))]
27 | else:
28 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))]
29 | # remove the same entity_pair, then, if one entity appear more than once, it's overlapping
30 | entity_pair = set(entity_pair)
31 | entities = []
32 | for pair in entity_pair:
33 | entities.extend(pair)
34 | entities = set(entities)
35 | return len(entities) != 2 * len(entity_pair)
36 |
37 | def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file):
38 | with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, open(epo_file, 'w') as f4, open(seo_file, 'w') as f5:
39 | cnt_normal = 0
40 | cnt_epo = 0
41 | cnt_seo = 0
42 | lines = f1.readlines()
43 | for line in lines:
44 | line = json.loads(line)
45 | print(len(line))
46 | sents, spos = line[0], line[1]
47 | print(len(spos))
48 | print(len(sents))
49 | for i in range(len(sents)):
50 | new_line = dict()
51 | #print(sents[i])
52 | #print(spos[i])
53 | tokens = [word_dict[i] for i in sents[i]]
54 | sent = ' '.join(tokens)
55 | new_line['sentText'] = sent
56 | triples = np.reshape(spos[i], (-1,3))
57 | relationMentions = []
58 | for triple in triples:
59 | rel = dict()
60 | rel['em1Text'] = tokens[triple[0]]
61 | rel['em2Text'] = tokens[triple[1]]
62 | rel['label'] = rel_dict[triple[2]]
63 | relationMentions.append(rel)
64 | new_line['relationMentions'] = relationMentions
65 | f2.write(json.dumps(new_line) + '\n')
66 | if is_normal_triple(spos[i]):
67 | f3.write(json.dumps(new_line) + '\n')
68 | if is_multi_label(spos[i]):
69 | f4.write(json.dumps(new_line) + '\n')
70 | if is_over_lapping(spos[i]):
71 | f5.write(json.dumps(new_line) + '\n')
72 |
73 | if __name__ == '__main__':
74 | file_name = 'dev.json'
75 | output = 'new_dev.json'
76 | output_normal = 'new_dev_normal.json'
77 | output_epo = 'new_dev_epo.json'
78 | output_seo = 'new_dev_seo.json'
79 | with open('relations2id.json', 'r') as f1, open('words2id.json', 'r') as f2:
80 | rel2id = json.load(f1)
81 | words2id = json.load(f2)
82 | rel_dict = {j:i for i,j in rel2id.items()}
83 | word_dict = {j:i for i,j in words2id.items()}
84 | load_data(file_name, word_dict, rel_dict, output, output_normal, output_epo, output_seo)
85 |
--------------------------------------------------------------------------------
/data/WebNLG/test_split_by_num/README.md:
--------------------------------------------------------------------------------
1 | Copy ../test.json here.
2 |
3 | Run split_by_num.py to split the test dataset and store the split files in the form of triples.
4 |
--------------------------------------------------------------------------------
/data/WebNLG/test_split_by_num/split_by_num.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 |
9 | test_1 = []
10 | test_2 = []
11 | test_3 = []
12 | test_4 = []
13 | test_other = []
14 |
15 | with open('test.json') as f:
16 | for l in tqdm(f):
17 | a = json.loads(l)
18 | if not a['relationMentions']:
19 | continue
20 | line = {
21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
23 | }
24 | if not line['triple_list']:
25 | continue
26 | spo_num = len(line['triple_list'])
27 | if spo_num == 1:
28 | test_1.append(line)
29 | elif spo_num == 2:
30 | test_2.append(line)
31 | elif spo_num == 3:
32 | test_3.append(line)
33 | elif spo_num == 4:
34 | test_4.append(line)
35 | else:
36 | test_other.append(line)
37 |
38 |
39 | with codecs.open('test_triples_1.json', 'w', encoding='utf-8') as f:
40 | json.dump(test_1, f, indent=4, ensure_ascii=False)
41 |
42 | with codecs.open('test_triples_2.json', 'w', encoding='utf-8') as f:
43 | json.dump(test_2, f, indent=4, ensure_ascii=False)
44 |
45 | with codecs.open('test_triples_3.json', 'w', encoding='utf-8') as f:
46 | json.dump(test_3, f, indent=4, ensure_ascii=False)
47 |
48 | with codecs.open('test_triples_4.json', 'w', encoding='utf-8') as f:
49 | json.dump(test_4, f, indent=4, ensure_ascii=False)
50 |
51 | with codecs.open('test_triples_5.json', 'w', encoding='utf-8') as f:
52 | json.dump(test_other, f, indent=4, ensure_ascii=False)
53 |
--------------------------------------------------------------------------------
/data/WebNLG/test_split_by_type/README.md:
--------------------------------------------------------------------------------
1 | Move the split files from raw_WebNLG/ here.
2 |
3 | Then run generate_triples.py to get triple files.
4 |
--------------------------------------------------------------------------------
/data/WebNLG/test_split_by_type/generate_triples.py:
--------------------------------------------------------------------------------
1 | #! -*- coding:utf-8 -*-
2 |
3 |
4 | import json
5 | from tqdm import tqdm
6 | import codecs
7 |
8 |
9 | test_normal = []
10 | test_epo = []
11 | test_seo = []
12 |
13 | with open('test_normal.json') as f:
14 | for l in tqdm(f):
15 | a = json.loads(l)
16 | if not a['relationMentions']:
17 | continue
18 | line = {
19 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
20 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
21 | }
22 | if not line['triple_list']:
23 | continue
24 | spo_num = len(line['triple_list'])
25 | test_normal.append(line)
26 |
27 | with open('test_epo.json') as f:
28 | for l in tqdm(f):
29 | a = json.loads(l)
30 | if not a['relationMentions']:
31 | continue
32 | line = {
33 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
34 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
35 | }
36 | if not line['triple_list']:
37 | continue
38 | spo_num = len(line['triple_list'])
39 | test_epo.append(line)
40 |
41 | with open('test_seo.json') as f:
42 | for l in tqdm(f):
43 | a = json.loads(l)
44 | if not a['relationMentions']:
45 | continue
46 | line = {
47 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'),
48 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None']
49 | }
50 | if not line['triple_list']:
51 | continue
52 | spo_num = len(line['triple_list'])
53 | test_seo.append(line)
54 |
55 | with codecs.open('test_triples_normal.json', 'w', encoding='utf-8') as f:
56 | json.dump(test_normal, f, indent=4, ensure_ascii=False)
57 |
58 | with codecs.open('test_triples_epo.json', 'w', encoding='utf-8') as f:
59 | json.dump(test_epo, f, indent=4, ensure_ascii=False)
60 |
61 | with codecs.open('test_triples_seo.json', 'w', encoding='utf-8') as f:
62 | json.dump(test_seo, f, indent=4, ensure_ascii=False)
63 |
64 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import json
2 | from random import choice
3 | from fastNLP import RandomSampler, DataSetIter, DataSet, TorchLoaderIter
4 | from fastNLP.io import JsonLoader
5 | from fastNLP import Vocabulary
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader, Dataset
9 | from utils import get_tokenizer
10 |
11 | BERT_MAX_LEN = 512
12 |
13 | tokenizer = get_tokenizer('data/vocab.txt')
14 |
15 |
16 | def load_data(train_path, dev_path, test_path, rel_dict_path):
17 | """
18 | Load dataset from origin dataset using fastNLP
19 | :param train_path: train data path
20 | :param dev_path: dev data path
21 | :param test_path: test data path
22 | :param rel_dict_path: relation dictionary path
23 | :return: data_bundle(contain train, dev, test data), rel_vocab(vocabulary of relations), num_rels(number of relations)
24 | """
25 | paths = {'train': train_path, 'dev': dev_path, "test": test_path}
26 | loader = JsonLoader({"text": "text", "triple_list": "triple_list"})
27 | data_bundle = loader.load(paths)
28 | print(data_bundle)
29 |
30 | id2rel, rel2id = json.load(open(rel_dict_path))
31 | rel_vocab = Vocabulary(unknown=None, padding=None)
32 | rel_vocab.add_word_lst(list(id2rel.values()))
33 | print(rel_vocab)
34 | num_rels = len(id2rel)
35 | print("number of relations: " + str(num_rels))
36 |
37 | return data_bundle, rel_vocab, num_rels
38 |
39 |
40 | def find_head_idx(source, target):
41 | """
42 | Find the index of the head of target in source
43 | :param source: tokenizer list
44 | :param target: target subject or object
45 | :return: the index of the head of target in source, if not found, return -1
46 | """
47 | target_len = len(target)
48 | for i in range(len(source)):
49 | if source[i: i + target_len] == target:
50 | return i
51 | return -1
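
# Illustration with hypothetical token lists:
#   find_head_idx(['new', 'york', 'city'], ['york', 'city']) -> 1   (sub-sequence starts at index 1)
#   find_head_idx(['new', 'york', 'city'], ['boston'])       -> -1  (not found)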
52 |
53 |
54 | # def preprocess_data(config, dataset, rel_vocab, num_rels, is_test):
55 | # token_ids_list = []
56 | # masks_list = []
57 | # text_len_list = []
58 | # sub_heads_list = []
59 | # sub_tails_list = []
60 | # sub_head_list = []
61 | # sub_tail_list = []
62 | # obj_heads_list = []
63 | # obj_tails_list = []
64 | # triple_list = []
65 | # tokens_list = []
66 | # for item in range(len(dataset)):
67 | # # get tokenizer list
68 | # ins_json_data = dataset[item]
69 | # text = ins_json_data['text']
70 | # text = ' '.join(text.split()[:config.max_len])
71 | # tokens = tokenizer.tokenize(text)
72 | # if len(tokens) > BERT_MAX_LEN:
73 | # tokens = tokens[: BERT_MAX_LEN]
74 | # text_len = len(tokens)
75 | #
76 | # if not is_test:
77 | # # build subject to relation_object map, s2ro_map[(sub_head, sub_tail)] = (obj_head, obj_tail, rel_idx)
78 | # s2ro_map = {}
79 | # for triple in ins_json_data['triple_list']:
80 | # triple = (tokenizer.tokenize(triple[0])[1:-1], triple[1], tokenizer.tokenize(triple[2])[1:-1])
81 | # sub_head_idx = find_head_idx(tokens, triple[0])
82 | # obj_head_idx = find_head_idx(tokens, triple[2])
83 | # if sub_head_idx != -1 and obj_head_idx != -1:
84 | # sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
85 | # if sub not in s2ro_map:
86 | # s2ro_map[sub] = []
87 | # s2ro_map[sub].append(
88 | # (obj_head_idx, obj_head_idx + len(triple[2]) - 1, rel_vocab.to_index(triple[1])))
89 | #
90 | # if s2ro_map:
91 | # token_ids, segment_ids = tokenizer.encode(first=text)
92 | # masks = segment_ids
93 | # if len(token_ids) > text_len:
94 | # token_ids = token_ids[:text_len]
95 | # masks = masks[:text_len]
96 | # token_ids = np.array(token_ids)
97 | # masks = np.array(masks) + 1
98 | # # sub_heads[i]: if index i is the head of any subjects in text
99 | # sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
100 | # for s in s2ro_map:
101 | # sub_heads[s[0]] = 1
102 | # sub_tails[s[1]] = 1
103 | # # randomly select one subject in text and set sub_head and sub_tail
104 | # sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
105 | # sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
106 | # sub_head[sub_head_idx] = 1
107 | # sub_tail[sub_tail_idx] = 1
108 | # obj_heads, obj_tails = np.zeros((text_len, num_rels)), np.zeros((text_len, num_rels))
109 | # for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
110 | # obj_heads[ro[0]][ro[2]] = 1
111 | # obj_tails[ro[1]][ro[2]] = 1
112 | # token_ids_list.append(token_ids)
113 | # masks_list.append(masks)
114 | # text_len_list.append(text_len)
115 | # sub_heads_list.append(sub_heads)
116 | # sub_tails_list.append(sub_tails)
117 | # sub_head_list.append(sub_head)
118 | # sub_tail_list.append(sub_tail)
119 | # obj_heads_list.append(obj_heads)
120 | # obj_tails_list.append(obj_tails)
121 | # triple_list.append(ins_json_data['triple_list'])
122 | # tokens_list.append(tokens)
123 | # else:
124 | # token_ids, segment_ids = tokenizer.encode(first=text)
125 | # masks = segment_ids
126 | # if len(token_ids) > text_len:
127 | # token_ids = token_ids[:text_len]
128 | # masks = masks[:text_len]
129 | # token_ids = np.array(token_ids)
130 | # masks = np.array(masks) + 1
131 | # # initialize these variant with 0
132 | # sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
133 | # sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
134 | # obj_heads, obj_tails = np.zeros((text_len, num_rels)), np.zeros((text_len, num_rels))
135 | # token_ids_list.append(token_ids)
136 | # masks_list.append(masks)
137 | # text_len_list.append(text_len)
138 | # sub_heads_list.append(sub_heads)
139 | # sub_tails_list.append(sub_tails)
140 | # sub_head_list.append(sub_head)
141 | # sub_tail_list.append(sub_tail)
142 | # obj_heads_list.append(obj_heads)
143 | # obj_tails_list.append(obj_tails)
144 | # triple_list.append(ins_json_data['triple_list'])
145 | # tokens_list.append(tokens)
146 | #
147 | # data_dict = {'token_ids': token_ids_list,
148 | # 'marks': masks_list,
149 | # 'text_len': text_len_list,
150 | # 'sub_heads': sub_heads_list,
151 | # 'sub_tails': sub_tails_list,
152 | # 'sub_head': sub_head_list,
153 | # 'sub_tail': sub_tail_list,
154 | # 'obj_heads': obj_heads_list,
155 | # 'obj_tails': obj_tails_list,
156 | # 'triple_list': triple_list,
157 | # 'tokens': tokens_list}
158 | # process_dataset = DataSet(data_dict)
159 | # process_dataset.set_input('token_ids', 'marks', 'text_len', 'sub_head', 'sub_tail', 'tokens')
160 | # process_dataset.set_target('sub_heads', 'sub_tails', 'obj_heads', 'obj_tails', 'triple_list')
161 | # return process_dataset
162 |
163 |
164 | class MyDataset(Dataset):
165 | def __init__(self, config, dataset, rel_vocab, num_rels, is_test):
166 | self.config = config
167 | self.dataset = dataset
168 | self.rel_vocab = rel_vocab
169 | self.num_rels = num_rels
170 | self.is_test = is_test
171 | self.tokenizer = tokenizer
172 |
173 | def __getitem__(self, item):
174 | """
175 |         Read one example so that data can be accessed as MyDataset[item]
176 |         :param item: index of the example
177 |         :return: the processed features and labels of the item-th example in the dataset
178 | """
179 | # get tokenizer list
180 | ins_json_data = self.dataset[item]
181 | text = ins_json_data['text']
182 | text = ' '.join(text.split()[:self.config.max_len])
183 | tokens = self.tokenizer.tokenize(text)
184 | if len(tokens) > BERT_MAX_LEN:
185 | tokens = tokens[: BERT_MAX_LEN]
186 | text_len = len(tokens)
187 |
188 | if not self.is_test:
189 | # build subject to relation_object map, s2ro_map[(sub_head, sub_tail)] = (obj_head, obj_tail, rel_idx)
190 | s2ro_map = {}
191 | for triple in ins_json_data['triple_list']:
192 | triple = (self.tokenizer.tokenize(triple[0])[1:-1], triple[1], self.tokenizer.tokenize(triple[2])[1:-1])
193 | sub_head_idx = find_head_idx(tokens, triple[0])
194 | obj_head_idx = find_head_idx(tokens, triple[2])
195 | if sub_head_idx != -1 and obj_head_idx != -1:
196 | sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
197 | if sub not in s2ro_map:
198 | s2ro_map[sub] = []
199 | s2ro_map[sub].append(
200 | (obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel_vocab.to_index(triple[1])))
201 |
202 | if s2ro_map:
203 | token_ids, segment_ids = self.tokenizer.encode(first=text)
204 | masks = segment_ids
205 | if len(token_ids) > text_len:
206 | token_ids = token_ids[:text_len]
207 | masks = masks[:text_len]
208 | token_ids = np.array(token_ids)
209 | masks = np.array(masks) + 1
210 | # sub_heads[i]: if index i is the head of any subjects in text
211 | sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
212 | for s in s2ro_map:
213 | sub_heads[s[0]] = 1
214 | sub_tails[s[1]] = 1
215 | # randomly select one subject in text and set sub_head and sub_tail
216 | sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
217 | sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
218 | sub_head[sub_head_idx] = 1
219 | sub_tail[sub_tail_idx] = 1
220 | obj_heads, obj_tails = np.zeros((text_len, self.num_rels)), np.zeros((text_len, self.num_rels))
221 | for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
222 | obj_heads[ro[0]][ro[2]] = 1
223 | obj_tails[ro[1]][ro[2]] = 1
224 | return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, \
225 | ins_json_data['triple_list'], tokens
226 | else:
227 | return None
228 | else:
229 | token_ids, segment_ids = self.tokenizer.encode(first=text)
230 | masks = segment_ids
231 | if len(token_ids) > text_len:
232 | token_ids = token_ids[:text_len]
233 | masks = masks[:text_len]
234 | token_ids = np.array(token_ids)
235 | masks = np.array(masks) + 1
236 |             # initialize these variables with 0
237 | sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
238 | sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
239 | obj_heads, obj_tails = np.zeros((text_len, self.num_rels)), np.zeros((text_len, self.num_rels))
240 | return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, \
241 | ins_json_data['triple_list'], tokens
242 |
243 | def __len__(self):
244 | return len(self.dataset)
245 |
246 |
247 | def my_collate_fn(batch):
248 | """
249 | Merge data in one batch
250 |     :param batch: a list of samples returned by MyDataset.__getitem__
251 |     :return: a pair of dictionaries (model inputs, training targets)
252 | """
253 | batch = list(filter(lambda x: x is not None, batch))
254 | batch.sort(key=lambda x: x[2], reverse=True)
255 | token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, triples, tokens = zip(
256 | *batch)
257 | cur_batch = len(batch)
258 | max_text_len = max(text_len)
259 | batch_token_ids = torch.LongTensor(cur_batch, max_text_len).zero_()
260 | batch_masks = torch.LongTensor(cur_batch, max_text_len).zero_()
261 | batch_sub_heads = torch.Tensor(cur_batch, max_text_len).zero_()
262 | batch_sub_tails = torch.Tensor(cur_batch, max_text_len).zero_()
263 | batch_sub_head = torch.Tensor(cur_batch, max_text_len).zero_()
264 | batch_sub_tail = torch.Tensor(cur_batch, max_text_len).zero_()
265 |     batch_obj_heads = torch.Tensor(cur_batch, max_text_len, 24).zero_()  # NOTE: 24 is hard-coded to NYT's relation count
266 |     batch_obj_tails = torch.Tensor(cur_batch, max_text_len, 24).zero_()  # NOTE: 24 is hard-coded to NYT's relation count
267 |
268 | for i in range(cur_batch):
269 | batch_token_ids[i, :text_len[i]].copy_(torch.from_numpy(token_ids[i]))
270 | batch_masks[i, :text_len[i]].copy_(torch.from_numpy(masks[i]))
271 | batch_sub_heads[i, :text_len[i]].copy_(torch.from_numpy(sub_heads[i]))
272 | batch_sub_tails[i, :text_len[i]].copy_(torch.from_numpy(sub_tails[i]))
273 | batch_sub_head[i, :text_len[i]].copy_(torch.from_numpy(sub_head[i]))
274 | batch_sub_tail[i, :text_len[i]].copy_(torch.from_numpy(sub_tail[i]))
275 | batch_obj_heads[i, :text_len[i], :].copy_(torch.from_numpy(obj_heads[i]))
276 | batch_obj_tails[i, :text_len[i], :].copy_(torch.from_numpy(obj_tails[i]))
277 |
278 | return {'token_ids': batch_token_ids,
279 | 'mask': batch_masks,
280 | 'sub_head': batch_sub_head,
281 | 'sub_tail': batch_sub_tail,
282 | 'tokens': tokens}, \
283 | {'sub_heads': batch_sub_heads,
284 | 'sub_tails': batch_sub_tails,
285 | 'obj_heads': batch_obj_heads,
286 | 'obj_tails': batch_obj_tails,
287 | 'triples': triples,
288 | }
289 |
290 |
291 | def get_data_iter(config, dataset, rel_vocab, num_rels, is_test=False, num_workers=0, collate_fn=my_collate_fn):
292 | """
293 | Build a data Iterator that combines a dataset and a sampler, and provides single- or multi-process iterators
294 | over the dataset.
295 | :param config: configuration
296 | :param dataset: certain dataset in data bundle processed by fastNLP
297 | :param rel_vocab: vocabulary of relations
298 | :param num_rels: the number of relations
299 |     :param is_test: if True, build an evaluation iterator with batch_size 1; otherwise batch by config.batch_size
300 | :param num_workers: how many subprocesses to use for data loading
301 | :param collate_fn: merges a list of samples to form a mini-batch
302 | :return: a dataloader
303 | """
304 | dataset = MyDataset(config, dataset, rel_vocab, num_rels, is_test)
305 | # dataset = preprocess_data(config, dataset, rel_vocab, num_rels, is_test)
306 | # print(dataset)
307 | if not is_test:
308 | sampler = RandomSampler()
309 | data_iter = TorchLoaderIter(dataset=dataset,
310 | collate_fn=collate_fn,
311 | batch_size=config.batch_size,
312 | num_workers=num_workers,
313 | pin_memory=True)
314 | # data_iter = DataSetIter(dataset=dataset,
315 | # batch_size=config.batch_size,
316 | # num_workers=num_workers,
317 | # sampler=sampler,
318 | # pin_memory=True)
319 | else:
320 | data_iter = TorchLoaderIter(dataset=dataset,
321 | collate_fn=collate_fn,
322 | batch_size=1,
323 | num_workers=num_workers,
324 | pin_memory=True)
325 | # data_iter = DataSetIter(dataset=dataset,
326 | # batch_size=1,
327 | # num_workers=num_workers,
328 | # pin_memory=True)
329 | return data_iter
330 |
331 | # class DataPreFetcher(object):
332 | # def __init__(self, loader):
333 | # self.loader = iter(loader)
334 | # # self.stream = torch.cuda.Stream()
335 | # self.preload()
336 | #
337 | # def preload(self):
338 | # try:
339 | # self.next_data = next(self.loader)
340 | # except StopIteration:
341 | # self.next_data = None
342 | # return
343 | # # with torch.cuda.stream(self.stream):
344 | # # for k, v in self.next_data.items():
345 | # # if isinstance(v, torch.Tensor):
346 | # # self.next_data[k] = self.next_data[k].cuda(non_blocking=True)
347 | #
348 | # def next(self):
349 | # # torch.cuda.current_stream().wait_stream(self.stream)
350 | # data = self.next_data
351 | # self.preload()
352 | # return data
353 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from fastNLP import Callback
7 | from transformers import BertModel
8 | from utils import metric
9 |
10 |
11 | class CasRel(nn.Module):
12 | def __init__(self, config, num_rels):
13 | super().__init__()
14 | self.config = config
15 | self.bert_encoder = BertModel.from_pretrained(self.config.bert_model_name)
16 | self.bert_dim = 768 # the size of hidden state
17 | self.sub_heads_linear = nn.Linear(self.bert_dim, 1)
18 | self.sub_tails_linear = nn.Linear(self.bert_dim, 1)
19 | self.obj_heads_linear = nn.Linear(self.bert_dim, num_rels)
20 | self.obj_tails_linear = nn.Linear(self.bert_dim, num_rels)
21 |
22 | def get_encoded_text(self, token_ids, mask):
23 | # [batch_size, seq_len, bert_dim(768)]
24 | encoded_text = self.bert_encoder(token_ids, attention_mask=mask)[0]
25 | return encoded_text
26 |
27 | def get_subs(self, encoded_text):
28 | """
29 | Subject Taggers
30 |         :param encoded_text: BERT-encoded representation of the input sentence
31 |         :return: predicted subject heads, predicted subject tails
32 | """
33 | # [batch_size, seq_len, 1]
34 | pred_sub_heads = self.sub_heads_linear(encoded_text)
35 | pred_sub_heads = torch.sigmoid(pred_sub_heads)
36 | # [batch_size, seq_len, 1]
37 | pred_sub_tails = self.sub_tails_linear(encoded_text)
38 | pred_sub_tails = torch.sigmoid(pred_sub_tails)
39 | return pred_sub_heads, pred_sub_tails
40 |
41 | def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
42 | """
43 | Relation-specific Object Taggers
44 |         :param sub_head_mapping: one-hot vector marking the chosen subject's head position, [batch_size, 1, seq_len]
45 |         :param sub_tail_mapping: one-hot vector marking the chosen subject's tail position, [batch_size, 1, seq_len]
46 |         :param encoded_text: BERT-encoded representation of the input sentence
47 |         :return: predicted object heads, predicted object tails (one score per relation)
48 | """
49 | # [batch_size, 1, bert_dim]
50 | sub_head = torch.matmul(sub_head_mapping, encoded_text)
51 | # [batch_size, 1, bert_dim]
52 | sub_tail = torch.matmul(sub_tail_mapping, encoded_text)
53 |         # [batch_size, 1, bert_dim]
54 | sub = (sub_head + sub_tail) / 2
55 | # [batch_size, seq_len, bert_dim]
56 | encoded_text = encoded_text + sub
57 | # [batch_size, seq_len, rel_num]
58 | pred_obj_heads = self.obj_heads_linear(encoded_text)
59 | pred_obj_heads = torch.sigmoid(pred_obj_heads)
60 | pred_obj_tails = self.obj_tails_linear(encoded_text)
61 | pred_obj_tails = torch.sigmoid(pred_obj_tails)
62 | return pred_obj_heads, pred_obj_tails
63 |
64 | def forward(self, data):
65 | # [batch_size, seq_len]
66 | token_ids = data['token_ids']
67 | # [batch_size, seq_len]
68 | mask = data['mask']
69 | # [batch_size, seq_len, bert_dim(768)]
70 | encoded_text = self.get_encoded_text(token_ids, mask)
71 | # [batch_size, seq_len, 1]
72 | pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
73 | # [batch_size, 1, seq_len]
74 | sub_head_mapping = data['sub_head'].unsqueeze(1)
75 | # [batch_size, 1, seq_len]
76 | sub_tail_mapping = data['sub_tail'].unsqueeze(1)
77 | # [batch_size, seq_len, rel_num]
78 | pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping,
79 | encoded_text)
80 | return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails
81 |
82 |
83 | class MyCallBack(Callback):
84 | def __init__(self, data_iter, rel_vocab, config):
85 | super().__init__()
86 | self.loss_sum = 0
87 | self.global_step = 0
88 |
89 | self.data_iter = data_iter
90 | self.rel_vocab = rel_vocab
91 | self.config = config
92 |
93 | def logging(self, s, print_=True, log_=True):
94 | if print_:
95 | print(s)
96 | if log_:
97 | with open(os.path.join(self.config.save_logs_dir, self.config.log_save_name), 'a+') as f_log:
98 | f_log.write(s + '\n')
99 |
100 | # define the loss function
101 | def loss(self, pred, gold, mask):
102 | pred = pred.squeeze(-1)
103 | los = F.binary_cross_entropy(pred, gold, reduction='none')
104 | if los.shape != mask.shape:
105 | mask = mask.unsqueeze(-1)
106 | los = torch.sum(los * mask) / torch.sum(mask)
107 | return los
108 |
109 | def on_train_begin(self):
110 | self.best_f1_score = 0
111 | self.best_precision = 0
112 | self.best_recall = 0
113 |
114 | self.best_epoch = 0
115 | self.init_time = time.time()
116 | self.start_time = time.time()
117 | print("-" * 5 + "Initializing the model" + "-" * 5)
118 |
119 | def on_epoch_begin(self):
120 | self.eval_start_time = time.time()
121 |
122 | def on_batch_begin(self, batch_x, batch_y, indices):
123 | pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = self.model(batch_x)
124 |         sub_heads_loss = self.loss(pred_sub_heads, batch_y['sub_heads'], batch_x['mask'])
125 |         sub_tails_loss = self.loss(pred_sub_tails, batch_y['sub_tails'], batch_x['mask'])
126 |         obj_heads_loss = self.loss(pred_obj_heads, batch_y['obj_heads'], batch_x['mask'])
127 |         obj_tails_loss = self.loss(pred_obj_tails, batch_y['obj_tails'], batch_x['mask'])
128 | total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)
129 |
130 | self.optimizer.zero_grad()
131 | total_loss.backward()
132 | self.optimizer.step()
133 |
134 | self.loss_sum += total_loss.item()
135 |
136 | def on_epoch_end(self):
137 | precision, recall, f1_score = metric(self.data_iter, self.rel_vocab, self.config, self.model)
138 | self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
139 | format(self.epoch, time.time() - self.eval_start_time, f1_score, precision, recall))
140 | if f1_score > self.best_f1_score:
141 | self.best_f1_score = f1_score
142 | self.best_epoch = self.epoch
143 | self.best_precision = precision
144 | self.best_recall = recall
145 | self.logging("Saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
146 | format(self.best_epoch, self.best_f1_score, precision, recall))
147 | # save the best model
148 | path = os.path.join(self.config.save_weights_dir, self.config.weights_save_name)
149 | torch.save(self.model.state_dict(), path)
150 |
151 | def on_train_end(self):
152 | self.logging("-" * 5 + "Finish training" + "-" * 5)
153 | self.logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2}, total time: {:5.2f}s".
154 | format(self.best_epoch, self.best_f1_score, self.best_precision, self.best_recall,
155 | time.time() - self.init_time))
156 |
--------------------------------------------------------------------------------
/pretrained_bert_models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/pretrained_bert_models/.DS_Store
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from fastNLP import Trainer, Adam, SpanFPreRecMetric
2 | from data_loader import load_data
3 | from model import CasRel
4 | from utils import get_tokenizer
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser(description='Model Controller')
8 | parser.add_argument('--train', action='store_true', help='to train the HBT model: python run.py --train')
9 | parser.add_argument('--dataset', default='WebNLG', type=str,
10 | help='specify the dataset from ["NYT","WebNLG"]')
11 | args = parser.parse_args()
12 |
13 | if __name__ == '__main__':
14 | # pre-trained bert model name
15 |     bert_model_name = 'bert-base-cased'
16 |
17 | vocab_path = 'data/vocab.txt'
18 | # load dataset
19 | # dataset = args.dataset
20 | dataset = "NYT"
21 | train_path = 'data/' + dataset + '/train_triples.json'
22 | dev_path = 'data/' + dataset + '/dev_triples.json'
23 | test_path = 'data/' + dataset + '/test_triples.json' # overall test
24 | # test_path = 'data/' + dataset + '/test_split_by_num/test_triples_5.json' # ['1','2','3','4','5']
25 | # test_path = 'data/' + dataset + '/test_split_by_type/test_triples_seo.json' # ['normal', 'seo', 'epo']
26 | rel_dict_path = 'data/' + dataset + '/rel2id.json'
27 | save_weights_path = 'saved_weights/' + dataset + '/best_model.weights'
28 |     save_logs_path = 'saved_logs/' + dataset + '/log'
29 |
30 | # parameters
31 | LR = 1e-5
32 | tokenizer = get_tokenizer(vocab_path)
33 | data_bundle, rel_vocab, num_rels = load_data(train_path, dev_path, test_path, rel_dict_path)
34 |     model = CasRel(bert_model_name, num_rels)  # TODO: CasRel.__init__ expects a Config object (see config.py), not a bare model name
35 |
36 | if args.train:
37 | BATCH_SIZE = 6
38 | EPOCH = 100
39 | MAX_LEN = 100
40 | STEPS = len(data_bundle.get_dataset('train')) // BATCH_SIZE
41 | metric = SpanFPreRecMetric()
42 | optimizer = Adam(lr=LR)
43 | trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, batch_size=BATCH_SIZE,
44 | metrics=metric)
45 |
--------------------------------------------------------------------------------
/saved_logs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_logs/.DS_Store
--------------------------------------------------------------------------------
/saved_logs/NYT/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_logs/NYT/.DS_Store
--------------------------------------------------------------------------------
/saved_weights/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_weights/.DS_Store
--------------------------------------------------------------------------------
/saved_weights/NYT/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_weights/NYT/.DS_Store
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
5 | from fastNLP import Trainer
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import config
12 | from data_loader import load_data, get_data_iter
13 | from model import CasRel, MyCallBack
14 | from utils import metric
15 |
16 | seed = 1234
17 | torch.manual_seed(seed)
18 | torch.cuda.manual_seed(seed)
19 | np.random.seed(seed)
20 | random.seed(seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 |
24 | parser = argparse.ArgumentParser(description='Model Controller')
25 | parser.add_argument('--model_name', type=str, default='CasRel', help='name of the model')
26 | parser.add_argument('--dataset', type=str, default='NYT', help='specify the dataset from ["NYT","WebNLG"]')
27 | parser.add_argument('--bert_model_name', type=str, default='bert-base-cased', help='the name of pretrained bert model')
28 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
29 | parser.add_argument('--multi_gpu', type=bool, default=False, help='if use multiple gpu')
30 | parser.add_argument('--batch_size', type=int, default=6)
31 | parser.add_argument('--max_epoch', type=int, default=200)
32 | parser.add_argument('--test_epoch', type=int, default=1)
33 | parser.add_argument('--max_len', type=int, default=150)
34 | parser.add_argument('--period', type=int, default=50)
35 | args = parser.parse_args()
36 |
37 | con = config.Config(args)
38 |
39 | # get the data and dataloader
40 | print("-" * 5 + "Start processing data" + "-" * 5)
41 | data_bundle, rel_vocab, num_rels = load_data(con.train_path, con.dev_path, con.test_path, con.rel_dict_path)
42 | print("Test process data:")
43 | test_data_iter = get_data_iter(con, data_bundle.get_dataset('test'), rel_vocab, num_rels, is_test=True)
44 | print("-" * 5 + "Data processing done" + "-" * 5)
45 |
46 | # check the checkpoint dir
47 | if not os.path.exists(con.save_weights_dir):
48 | os.mkdir(con.save_weights_dir)
49 | # check the log dir
50 | if not os.path.exists(con.save_logs_dir):
51 | os.mkdir(con.save_logs_dir)
52 |
53 |
54 | def test():
55 | model = CasRel(con, num_rels)
56 | path = os.path.join(con.save_weights_dir, con.weights_save_name)
57 | model.load_state_dict(torch.load(path))
58 | model.cuda()
59 | model.eval()
60 | precision, recall, f1_score = metric(test_data_iter, rel_vocab, con, model, True)
61 | print("f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".format(f1_score, precision, recall))
62 |
63 |
64 | test()
--------------------------------------------------------------------------------
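As a reading aid for the numbers `test()` prints: they are computed by `utils.metric()` (shown further below), micro-averaged over exact (subject, relation, object) matches. A minimal sketch with hypothetical counts, using the same formulas:

```python
# Hypothetical counts standing in for the totals accumulated by utils.metric();
# the formulas are the ones used there (micro-averaged, exact triple match).
correct_num, predict_num, gold_num = 850, 1000, 950

precision = correct_num / (predict_num + 1e-10)   # correct / predicted
recall = correct_num / (gold_num + 1e-10)         # correct / gold
f1_score = 2 * precision * recall / (precision + recall + 1e-10)

print("f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".format(f1_score, precision, recall))
# -> f1: 0.87, precision: 0.85, recall: 0.89
```

The `1e-10` terms only guard against division by zero when no triples are predicted or no gold triples exist.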
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
5 | from fastNLP import Trainer
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import config
12 | from data_loader import load_data, get_data_iter
13 | from model import CasRel, MyCallBack
14 | from utils import metric
15 |
16 | seed = 1234
17 | torch.manual_seed(seed)
18 | torch.cuda.manual_seed(seed)
19 | np.random.seed(seed)
20 | random.seed(seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 |
24 | parser = argparse.ArgumentParser(description='Model Controller')
25 | parser.add_argument('--model_name', type=str, default='CasRel', help='name of the model')
26 | parser.add_argument('--dataset', type=str, default='NYT', help='specify the dataset from ["NYT","WebNLG"]')
27 | parser.add_argument('--bert_model_name', type=str, default='bert-base-cased', help='name of the pretrained BERT model')
28 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
29 | parser.add_argument('--multi_gpu', action='store_true', help='whether to use multiple GPUs')
30 | parser.add_argument('--batch_size', type=int, default=6)
31 | parser.add_argument('--max_epoch', type=int, default=300)
32 | parser.add_argument('--test_epoch', type=int, default=5)
33 | parser.add_argument('--max_len', type=int, default=150)
34 | parser.add_argument('--period', type=int, default=50)
35 | args = parser.parse_args()
36 |
37 | con = config.Config(args)
38 |
39 | # get the data and dataloader
40 | print("-" * 5 + "Start processing data" + "-" * 5)
41 | data_bundle, rel_vocab, num_rels = load_data(con.train_path, con.dev_path, con.test_path, con.rel_dict_path)
42 | print("Train process data:")
43 | train_data_iter = get_data_iter(con, data_bundle.get_dataset('train'), rel_vocab, num_rels)
44 | print("Dev process data:")
45 | dev_data_iter = get_data_iter(con, data_bundle.get_dataset('dev'), rel_vocab, num_rels, is_test=True)
46 | print("-" * 5 + "Data processing done" + "-" * 5)
47 |
48 | # check the checkpoint dir
49 | if not os.path.exists(con.save_weights_dir):
50 | os.mkdir(con.save_weights_dir)
51 | # check the log dir
52 | if not os.path.exists(con.save_logs_dir):
53 | os.mkdir(con.save_logs_dir)
54 |
55 |
56 | # define the loss function
57 | def loss(pred, gold, mask):
58 | pred = pred.squeeze(-1)
59 | los = F.binary_cross_entropy(pred, gold, reduction='none')
60 | if los.shape != mask.shape:
61 | mask = mask.unsqueeze(-1)
62 | los = torch.sum(los * mask) / torch.sum(mask)
63 | return los
64 |
65 |
66 | def logging(s, print_=True, log_=True):
67 | if print_:
68 | print(s)
69 | if log_:
70 | with open(os.path.join(con.save_logs_dir, con.log_save_name), 'a+') as f_log:
71 | f_log.write(s + '\n')
72 |
73 |
74 | def train():
75 | # initialize the model
76 | print("-" * 5 + "Initializing the model" + "-" * 5)
77 | model = CasRel(con, num_rels)
78 | # model.cuda()
79 |
80 | # whether use multi GPU
81 | if con.multi_gpu:
82 | model = nn.DataParallel(model)
83 | model.train()
84 |
85 | # define the optimizer
86 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.learning_rate)
87 |
88 | # other
89 | global_step = 0
90 | loss_sum = 0
91 |
92 | best_f1_score = 0.0
93 | best_precision = 0.0
94 | best_recall = 0.0
95 |
96 | best_epoch = 0
97 | init_time = time.time()
98 | start_time = time.time()
99 |
100 | # the training loop
101 | print("-" * 5 + "Start training" + "-" * 5)
102 | for epoch in range(con.max_epoch):
103 | for batch_x, batch_y in train_data_iter:
104 |             # uncomment to smoke-test the loop on a handful of batches
105 |             # if global_step == 20: break
106 | pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(batch_x)
107 | sub_heads_loss = loss(pred_sub_heads, batch_y['sub_heads'], batch_x['mask'])
108 | sub_tails_loss = loss(pred_sub_tails, batch_y['sub_tails'], batch_x['mask'])
109 | obj_heads_loss = loss(pred_obj_heads, batch_y['obj_heads'], batch_x['mask'])
110 | obj_tails_loss = loss(pred_obj_tails, batch_y['obj_tails'], batch_x['mask'])
111 | total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)
112 |
113 | optimizer.zero_grad()
114 | total_loss.backward()
115 | optimizer.step()
116 |
117 | global_step += 1
118 | loss_sum += total_loss.item()
119 |
120 | if global_step % con.period == 0:
121 | cur_loss = loss_sum / con.period
122 | elapsed = time.time() - start_time
123 | logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
124 | format(epoch, global_step, elapsed * 1000 / con.period, cur_loss))
125 | loss_sum = 0
126 | start_time = time.time()
127 |
128 | if (epoch + 1) % con.test_epoch == 0:
129 | eval_start_time = time.time()
130 | model.eval()
131 | # call the test function
132 | precision, recall, f1_score = metric(dev_data_iter, rel_vocab, con, model)
133 | model.train()
134 | logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
135 | format(epoch, time.time() - eval_start_time, f1_score, precision, recall))
136 |
137 | if f1_score > best_f1_score:
138 | best_f1_score = f1_score
139 | best_epoch = epoch
140 | best_precision = precision
141 | best_recall = recall
142 | logging("Saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
143 | format(best_epoch, best_f1_score, precision, recall))
144 | # save the best model
145 | path = os.path.join(con.save_weights_dir, con.weights_save_name)
146 | torch.save(model.state_dict(), path)
147 |
148 | # manually release the unused cache
149 | # torch.cuda.empty_cache()
150 |
151 | logging("-" * 5 + "Finish training" + "-" * 5)
152 |     logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}, total time: {:5.2f}s".
153 | format(best_epoch, best_f1_score, best_precision, best_recall, time.time() - init_time))
154 |
155 |
156 | train()
157 |
158 | # model = CasRel(con, num_rels)
159 | # # model.cuda()
160 | #
161 | # # define the optimizer
162 | # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.learning_rate)
163 | #
164 | # Trainer(train_data_iter, model, optimizer, batch_size=con.batch_size, n_epochs=con.max_epoch, print_every=con.period,
165 | # callbacks=[MyCallBack(dev_data_iter, rel_vocab, con)])
166 |
--------------------------------------------------------------------------------
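The `loss()` defined in train.py above is a masked binary cross-entropy: the per-position BCE is multiplied by the attention mask so that padded tokens neither contribute to nor dilute the average (for the object tags, which carry an extra relation dimension, the mask is broadcast via `unsqueeze`). A self-contained toy sketch; the tensors are invented stand-ins for the model outputs and tag sequences:

```python
import torch
import torch.nn.functional as F

def masked_bce(pred, gold, mask):
    # same logic as loss() in train.py: per-position BCE, averaged over unmasked tokens only
    pred = pred.squeeze(-1)
    los = F.binary_cross_entropy(pred, gold, reduction='none')
    if los.shape != mask.shape:          # e.g. object tags of shape [batch, seq_len, num_rels]
        mask = mask.unsqueeze(-1)        # broadcast the [batch, seq_len] mask over relations
    return torch.sum(los * mask) / torch.sum(mask)

pred = torch.tensor([[0.9, 0.1, 0.8, 0.5]]).unsqueeze(-1)   # sigmoid outputs, shape [1, 4, 1]
gold = torch.tensor([[1.0, 0.0, 1.0, 0.0]])                 # head/tail tags, shape [1, 4]
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])                 # last position is padding

print(masked_bce(pred, gold, mask))   # the padded position is excluded from the average
```

Dividing by `sum(mask)` rather than the tensor size keeps the loss comparable across batches with different amounts of padding.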
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import unicodedata
4 | import codecs
5 | from keras_bert import Tokenizer
6 | import numpy as np
7 | import torch
8 |
9 |
10 | class HBTokenizer(Tokenizer):
11 | def _tokenize(self, text):
12 | if not self._cased:
13 | text = unicodedata.normalize('NFD', text)
14 | text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
15 | text = text.lower()
16 | spaced = ''
17 | for ch in text:
18 | if ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
19 | continue
20 | else:
21 | spaced += ch
22 | tokens = []
23 | for word in spaced.strip().split():
24 | tokens += self._word_piece_tokenize(word)
25 | tokens.append('[unused1]')
26 | return tokens
27 |
28 |
29 | def get_tokenizer(vocab_path):
30 | token_dict = {}
31 | with codecs.open(vocab_path, 'r', 'utf8') as reader:
32 | for line in reader:
33 | token = line.strip()
34 | token_dict[token] = len(token_dict)
35 | return HBTokenizer(token_dict, cased=True)
36 |
37 |
38 | def to_tup(triple_list):
39 | ret = []
40 | for triple in triple_list:
41 | ret.append(tuple(triple))
42 | return ret
43 |
44 |
45 | def metric(data_iter, rel_vocab, config, model, output=False, h_bar=0.5, t_bar=0.5):
46 | if output:
47 | # check the result dir
48 | if not os.path.exists(config.result_dir):
49 | os.mkdir(config.result_dir)
50 | path = os.path.join(config.result_dir, config.result_save_name)
51 | fw = open(path, 'w')
52 |
53 | orders = ['subject', 'relation', 'object']
54 | correct_num, predict_num, gold_num = 0, 0, 0
55 |
56 | for batch_x, batch_y in data_iter:
57 | with torch.no_grad():
58 | token_ids = batch_x['token_ids']
59 | tokens = batch_x['tokens']
60 | mask = batch_x['mask']
61 | encoded_text = model.get_encoded_text(token_ids, mask)
62 | pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
63 | sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], \
64 | np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
65 | subjects = []
66 | for sub_head in sub_heads:
67 | sub_tail = sub_tails[sub_tails >= sub_head]
68 | if len(sub_tail) > 0:
69 | sub_tail = sub_tail[0]
70 | subject = tokens[sub_head: sub_tail]
71 | subjects.append((subject, sub_head, sub_tail))
72 |
73 | if subjects:
74 | triple_list = []
75 | # [subject_num, seq_len, bert_dim]
76 | repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
77 | # [subject_num, 1, seq_len]
78 | sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
79 | sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
80 | for subject_idx, subject in enumerate(subjects):
81 | sub_head_mapping[subject_idx][0][subject[1]] = 1
82 | sub_tail_mapping[subject_idx][0][subject[2]] = 1
83 | sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
84 | sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
85 | pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping,
86 | repeated_encoded_text)
87 | for subject_idx, subject in enumerate(subjects):
88 | sub = subject[0]
89 | sub = ''.join([i.lstrip("##") for i in sub])
90 | sub = ' '.join(sub.split('[unused1]'))
91 | obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(
92 | pred_obj_tails.cpu()[subject_idx] > t_bar)
93 | for obj_head, rel_head in zip(*obj_heads):
94 | for obj_tail, rel_tail in zip(*obj_tails):
95 | if obj_head <= obj_tail and rel_head == rel_tail:
96 | rel = rel_vocab.to_word[int(rel_head)]
97 | obj = tokens[obj_head: obj_tail]
98 | obj = ''.join([i.lstrip("##") for i in obj])
99 | obj = ' '.join(obj.split('[unused1]'))
100 | triple_list.append((sub, rel, obj))
101 | break
102 | triple_set = set()
103 | for s, r, o in triple_list:
104 | triple_set.add((s, r, o))
105 | pred_list = list(triple_set)
106 | else:
107 | pred_list = []
108 |
109 | pred_triples = set(pred_list)
110 | gold_triples = set(to_tup(batch_y['triples'][0]))
111 |
112 | correct_num += len(pred_triples & gold_triples)
113 | predict_num += len(pred_triples)
114 | gold_num += len(gold_triples)
115 |
116 | if output:
117 | result = json.dumps({
118 | # 'text': ' '.join(tokens),
119 | 'triple_list_gold': [
120 | dict(zip(orders, triple)) for triple in gold_triples
121 | ],
122 | 'triple_list_pred': [
123 | dict(zip(orders, triple)) for triple in pred_triples
124 | ],
125 | 'new': [
126 | dict(zip(orders, triple)) for triple in pred_triples - gold_triples
127 | ],
128 | 'lack': [
129 | dict(zip(orders, triple)) for triple in gold_triples - pred_triples
130 | ]
131 | }, ensure_ascii=False)
132 | fw.write(result + '\n')
133 |
134 | print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))
135 |
136 | precision = correct_num / (predict_num + 1e-10)
137 | recall = correct_num / (gold_num + 1e-10)
138 | f1_score = 2 * precision * recall / (precision + recall + 1e-10)
139 | return precision, recall, f1_score
140 |
--------------------------------------------------------------------------------
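For reference, the subject and object spans in `metric()` above are decoded by thresholding the per-token head/tail probabilities against `h_bar` / `t_bar` and pairing each head with the nearest tail at or after it; the tokens inside a span are then rejoined into words by stripping the `##` sub-word prefixes and splitting on the `[unused1]` markers that `HBTokenizer` appends at word boundaries. A toy sketch of the pairing rule, with made-up scores:

```python
import numpy as np

h_bar, t_bar = 0.5, 0.5
pred_heads = np.array([0.9, 0.1, 0.2, 0.7, 0.1, 0.1])   # hypothetical per-token head scores
pred_tails = np.array([0.1, 0.1, 0.8, 0.1, 0.1, 0.6])   # hypothetical per-token tail scores

heads = np.where(pred_heads > h_bar)[0]                  # positions tagged as span starts -> [0, 3]
tails = np.where(pred_tails > t_bar)[0]                  # positions tagged as span ends   -> [2, 5]

spans = []
for head in heads:
    later_tails = tails[tails >= head]                   # tails that do not precede this head
    if len(later_tails) > 0:
        spans.append((int(head), int(later_tails[0])))   # the nearest tail closes the span

print(spans)   # [(0, 2), (3, 5)] -> two candidate spans, as in utils.metric()
```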