├── .DS_Store ├── .idea ├── .gitignore ├── CasRel_fastNLP.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── config.py ├── data ├── .DS_Store ├── NYT │ ├── README.md │ ├── build_data.py │ ├── raw_NYT │ │ ├── README.md │ │ └── generate.py │ ├── test_split_by_num │ │ ├── README.md │ │ └── split_by_num.py │ └── test_split_by_type │ │ ├── README.md │ │ └── generate_triples.py └── WebNLG │ ├── README.md │ ├── build_data.py │ ├── raw_WebNLG │ ├── README.md │ └── generate.py │ ├── test_split_by_num │ ├── README.md │ └── split_by_num.py │ └── test_split_by_type │ ├── README.md │ └── generate_triples.py ├── data_loader.py ├── model.py ├── pretrained_bert_models └── .DS_Store ├── run.py ├── saved_logs ├── .DS_Store └── NYT │ └── .DS_Store ├── saved_weights ├── .DS_Store └── NYT │ └── .DS_Store ├── test.py ├── train.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/CasRel_fastNLP.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CasRel_fastNLP 2 | This repo is [fastNLP](https://fastnlp.readthedocs.io/zh/latest/) reimplementation of the paper: [**"A Novel Cascade Binary Tagging Framework for Relational Triple Extraction"**](https://www.aclweb.org/anthology/2020.acl-main.136.pdf), which was published in ACL2020. The [original code](https://github.com/weizhepei/CasRel) was written in keras. 3 | 4 | ## Requirements 5 | 6 | - Python 3.8 7 | - Pytorch 1.7 8 | - fastNLP 0.6.0 9 | - keras-bert 0.86.0 10 | - numpy 1.19.1 11 | - transformers 4.0.0 12 | 13 | Other dependent packages described in [fastNLP Docs](https://fastnlp.readthedocs.io/zh/latest/user/installation.html). 
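If installing with pip, something like the following should match the pinned versions above (package names as published on PyPI are assumed; pick the PyTorch build that matches your CUDA setup):

```
pip install torch==1.7.0 fastNLP==0.6.0 keras-bert==0.86.0 numpy==1.19.1 transformers==4.0.0
```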
14 | 15 | ## Datasets 16 | 17 | - [NYT](https://github.com/weizhepei/CasRel/tree/master/data/NYT) 18 | - [WebNLG](https://github.com/weizhepei/CasRel/tree/master/data/WebNLG) 19 | 20 | ## Usage 21 | 22 | 1. **Build the dataset in the form of triples** (DONE) 23 | Take the NYT dataset as an example: 24 | 25 | a) Switch to the corresponding directory and download the dataset 26 | 27 | ``` 28 | cd CasRel/data/NYT/raw_NYT 29 | ``` 30 | 31 | b) Follow the instructions in that directory, then run 32 | 33 | ``` 34 | python generate.py 35 | ``` 36 | 37 | c) Finally, build the dataset in the form of triples 38 | 39 | ``` 40 | cd CasRel/data/NYT 41 | python build_data.py 42 | ``` 43 | 44 | This converts the raw numerical dataset into the format required by our model and generates `train.json`, `test.json` and `dev.json`. Then split the test set by overlap type and by triple number for in-depth analysis of the different overlapping-triple scenarios (run `generate_triples.py` and `split_by_num.py` under the corresponding folders). 45 | 46 | - NYT: 47 | - Train: 56195, dev: 4999, test: 5000 48 | - normal : EPO : SEO = 3266 : 978 : 1297 49 | - WebNLG: 50 | - Train: 5019, dev: 703, test: 500 51 | - normal : EPO : SEO = 182 : 16 : 318 52 | 53 | 2. **Specify the experimental settings** (DONE) 54 | 55 | By default, we use the following settings in train.py: 56 | 57 | ``` 58 | { 59 | "model_name": "CasRel", 60 | "dataset": "NYT", 61 | "bert_model_name": "bert-base-cased", 62 | "lr": 1e-5, 63 | "multi_gpu": False, 64 | "batch_size": 6, 65 | "max_epoch": 100, 66 | "test_epoch": 5, 67 | "max_len": 100, 68 | "period": 50, 69 | } 70 | ``` 71 | 72 | 3. **Train and select the model** 73 | - [ ] Currently trained on CPU; plan to move to GPU later. 74 | - [ ] Currently using a custom training loop; plan to switch to the fastNLP Trainer class. 75 | 76 | 4. 
**Evaluate on the test set** 77 | 78 | 79 | 80 | ## Results 81 | 82 | 83 | 84 | ## References 85 | 86 | [1] https://github.com/weizhepei/CasRel 87 | 88 | [2] https://github.com/longlongman/CasRel-pytorch-reimplement 89 | 90 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self, args): 3 | self.args = args 4 | 5 | # dataset 6 | self.dataset = args.dataset 7 | 8 | # train hyper parameters 9 | self.multi_gpu = args.multi_gpu 10 | self.learning_rate = args.lr 11 | self.batch_size = args.batch_size 12 | self.max_epoch = args.max_epoch 13 | self.max_len = args.max_len 14 | 15 | # path and name 16 | self.bert_model_name = args.bert_model_name 17 | self.train_path = 'data/' + self.dataset + '/train_triples.json' 18 | self.dev_path = 'data/' + self.dataset + '/dev_triples.json' 19 | self.test_path = 'data/' + self.dataset + '/test_triples.json' # overall test 20 | self.rel_dict_path = 'data/' + self.dataset + '/rel2id.json' 21 | self.save_weights_dir = 'saved_weights/' + self.dataset + '/' 22 | self.save_logs_dir = 'saved_logs/' + self.dataset + '/' 23 | self.result_dir = 'results/' + self.dataset 24 | self.weights_save_name = args.model_name + '_DATASET_' + self.dataset + "_LR_" + str( 25 | self.learning_rate) + "_BS_" + str(self.batch_size) 26 | self.log_save_name = 'LOG_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str( 27 | self.learning_rate) + "_BS_" + str(self.batch_size) 28 | self.result_save_name = 'RESULT_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str( 29 | self.learning_rate) + "_BS_" + str(self.batch_size) + ".json" 30 | 31 | # log setting 32 | self.period = args.period 33 | self.test_epoch = args.test_epoch 34 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/data/.DS_Store -------------------------------------------------------------------------------- /data/NYT/README.md: -------------------------------------------------------------------------------- 1 | First download and process the data at dir raw_NYT/ 2 | 3 | Then run build_data.py to get triple files. 4 | -------------------------------------------------------------------------------- /data/NYT/build_data.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | rel_set = set() 9 | 10 | 11 | train_data = [] 12 | dev_data = [] 13 | test_data = [] 14 | 15 | with open('train.json') as f: 16 | for l in tqdm(f): 17 | a = json.loads(l) 18 | if not a['relationMentions']: 19 | continue 20 | line = { 21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 23 | } 24 | if not line['triple_list']: 25 | continue 26 | train_data.append(line) 27 | for rm in a['relationMentions']: 28 | if rm['label'] != 'None': 29 | rel_set.add(rm['label']) 30 | 31 | 32 | with open('dev.json') as f: 33 | for l in tqdm(f): 34 | a = json.loads(l) 35 | if not a['relationMentions']: 36 | continue 37 | line = { 38 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 39 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 40 | } 41 | if not line['triple_list']: 42 | continue 43 | dev_data.append(line) 44 | for rm in a['relationMentions']: 45 | if rm['label'] != 'None': 46 | rel_set.add(rm['label']) 47 | 48 | cnt = 0 49 | with open('test.json') as f: 50 | for l in tqdm(f): 51 | a = json.loads(l) 52 | if not a['relationMentions']: 53 | continue 54 | line = { 55 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 56 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 57 | } 58 | if not line['triple_list']: 59 | continue 60 | cnt += len(line['triple_list']) 61 | test_data.append(line) 62 | 63 | print(f'test triples:{cnt}') 64 | 65 | 66 | id2predicate = {i:j for i,j in enumerate(sorted(rel_set))} 67 | predicate2id = {j:i for i,j in id2predicate.items()} 68 | 69 | 70 | with codecs.open('rel2id.json', 'w', encoding='utf-8') as f: 71 | json.dump([id2predicate, predicate2id], f, indent=4, ensure_ascii=False) 72 | 73 | 74 | with codecs.open('train_triples.json', 'w', encoding='utf-8') as f: 75 | json.dump(train_data, f, indent=4, ensure_ascii=False) 76 | 77 | 78 | with codecs.open('dev_triples.json', 'w', encoding='utf-8') as f: 79 | json.dump(dev_data, f, indent=4, ensure_ascii=False) 80 | 81 | 82 | with codecs.open('test_triples.json', 'w', encoding='utf-8') as f: 83 | json.dump(test_data, f, indent=4, ensure_ascii=False) 84 | -------------------------------------------------------------------------------- /data/NYT/raw_NYT/README.md: -------------------------------------------------------------------------------- 1 | Download CopyR's NYT data at https://drive.google.com/open?id=10f24s9gM7NdyO3z5OqQxJgYud4NnCJg3 2 | 3 | Run generate.py to convert the numerical mentions into string mentions. Move the converted files to NYT/ 4 | 5 | Also, the test data should be split by overlapping types, i.e., Normal, SEO, and EPO. 
Move the split files to NYT/test_split_by_type/ 6 | -------------------------------------------------------------------------------- /data/NYT/raw_NYT/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | def is_normal_triple(triples, is_relation_first=False): 5 | entities = set() 6 | for i, e in enumerate(triples): 7 | key = 0 if is_relation_first else 2 8 | if i % 3 != key: 9 | entities.add(e) 10 | return len(entities) == 2 * int(len(triples) / 3) 11 | 12 | def is_multi_label(triples, is_relation_first=False): 13 | if is_normal_triple(triples, is_relation_first): 14 | return False 15 | if is_relation_first: 16 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))] 17 | else: 18 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))] 19 | # if is multi label, then, at least one entity pair appeared more than once 20 | return len(entity_pair) != len(set(entity_pair)) 21 | 22 | def is_over_lapping(triples, is_relation_first=False): 23 | if is_normal_triple(triples, is_relation_first): 24 | return False 25 | if is_relation_first: 26 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))] 27 | else: 28 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))] 29 | # remove the same entity_pair, then, if one entity appear more than once, it's overlapping 30 | entity_pair = set(entity_pair) 31 | entities = [] 32 | for pair in entity_pair: 33 | entities.extend(pair) 34 | entities = set(entities) 35 | return len(entities) != 2 * len(entity_pair) 36 | 37 | def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file): 38 | with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, open(epo_file, 'w') as f4, open(seo_file, 'w') as f5: 39 | cnt_normal = 0 40 | cnt_epo = 0 41 | cnt_seo = 0 42 | lines = f1.readlines() 43 | for line in lines: 44 | line = json.loads(line) 45 | print(len(line)) 46 | lengths, sents, spos = line[0], line[1], line[2] 47 | print(len(spos)) 48 | print(len(sents)) 49 | for i in range(len(sents)): 50 | new_line = dict() 51 | #print(sents[i]) 52 | #print(spos[i]) 53 | tokens = [word_dict[i] for i in sents[i]] 54 | sent = ' '.join(tokens) 55 | new_line['sentText'] = sent 56 | triples = np.reshape(spos[i], (-1,3)) 57 | relationMentions = [] 58 | for triple in triples: 59 | rel = dict() 60 | rel['em1Text'] = tokens[triple[0]] 61 | rel['em2Text'] = tokens[triple[1]] 62 | rel['label'] = rel_dict[triple[2]] 63 | relationMentions.append(rel) 64 | new_line['relationMentions'] = relationMentions 65 | f2.write(json.dumps(new_line) + '\n') 66 | if is_normal_triple(spos[i]): 67 | f3.write(json.dumps(new_line) + '\n') 68 | if is_multi_label(spos[i]): 69 | f4.write(json.dumps(new_line) + '\n') 70 | if is_over_lapping(spos[i]): 71 | f5.write(json.dumps(new_line) + '\n') 72 | 73 | if __name__ == '__main__': 74 | file_name = 'valid.json' 75 | output = 'new_valid.json' 76 | output_normal = 'new_valid_normal.json' 77 | output_epo = 'new_valid_epo.json' 78 | output_seo = 'new_valid_seo.json' 79 | with open('relations2id.json', 'r') as f1, open('words2id.json', 'r') as f2: 80 | rel2id = json.load(f1) 81 | words2id = json.load(f2) 82 | rel_dict = {j:i for i,j in rel2id.items()} 83 | word_dict = {j:i for i,j in words2id.items()} 84 | load_data(file_name, word_dict, rel_dict, output, output_normal, output_epo, output_seo) 85 | 
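For clarity, here is a minimal sketch of how the three helpers in generate.py above classify a flattened `[sub_idx, obj_idx, rel_idx, ...]` triple list. The index values are invented for illustration, and the snippet assumes it is run from this directory so that `generate.py` is importable (its main block is guarded, so importing has no side effects):

```
# Run from data/NYT/raw_NYT/ so that generate.py is on the import path.
from generate import is_normal_triple, is_multi_label, is_over_lapping

# Flattened triples: [sub_idx, obj_idx, rel_idx, sub_idx, obj_idx, rel_idx, ...]
normal = [1, 2, 0, 3, 4, 0]   # two triples, four distinct entity mentions
epo    = [1, 2, 0, 1, 2, 5]   # same entity pair with two relations -> EPO
seo    = [1, 2, 0, 1, 3, 0]   # entity 1 shared by two pairs -> SEO

assert is_normal_triple(normal)
assert is_multi_label(epo)      # such sentences go to the *_epo.json split
assert is_over_lapping(seo)     # such sentences go to the *_seo.json split
```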
-------------------------------------------------------------------------------- /data/NYT/test_split_by_num/README.md: -------------------------------------------------------------------------------- 1 | Copy ../test.json here, and run split_by_num.py to split the test data by triple nums. 2 | -------------------------------------------------------------------------------- /data/NYT/test_split_by_num/split_by_num.py: -------------------------------------------------------------------------------- 1 | #! -*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | 9 | test_1 = [] 10 | test_2 = [] 11 | test_3 = [] 12 | test_4 = [] 13 | test_other = [] 14 | 15 | with open('test.json') as f: 16 | for l in tqdm(f): 17 | a = json.loads(l) 18 | if not a['relationMentions']: 19 | continue 20 | line = { 21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 23 | } 24 | if not line['triple_list']: 25 | continue 26 | spo_num = len(line['triple_list']) 27 | if spo_num == 1: 28 | test_1.append(line) 29 | elif spo_num == 2: 30 | test_2.append(line) 31 | elif spo_num == 3: 32 | test_3.append(line) 33 | elif spo_num == 4: 34 | test_4.append(line) 35 | else: 36 | test_other.append(line) 37 | 38 | 39 | with codecs.open('test_triples_1.json', 'w', encoding='utf-8') as f: 40 | json.dump(test_1, f, indent=4, ensure_ascii=False) 41 | 42 | with codecs.open('test_triples_2.json', 'w', encoding='utf-8') as f: 43 | json.dump(test_2, f, indent=4, ensure_ascii=False) 44 | 45 | with codecs.open('test_triples_3.json', 'w', encoding='utf-8') as f: 46 | json.dump(test_3, f, indent=4, ensure_ascii=False) 47 | 48 | with codecs.open('test_triples_4.json', 'w', encoding='utf-8') as f: 49 | json.dump(test_4, f, indent=4, ensure_ascii=False) 50 | 51 | with codecs.open('test_triples_5.json', 'w', encoding='utf-8') as f: 52 | json.dump(test_other, f, indent=4, ensure_ascii=False) 53 | -------------------------------------------------------------------------------- /data/NYT/test_split_by_type/README.md: -------------------------------------------------------------------------------- 1 | Run generate_triples.py to get triple files. 2 | -------------------------------------------------------------------------------- /data/NYT/test_split_by_type/generate_triples.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | 9 | test_normal = [] 10 | test_epo = [] 11 | test_seo = [] 12 | 13 | with open('test_normal.json') as f: 14 | for l in tqdm(f): 15 | a = json.loads(l) 16 | if not a['relationMentions']: 17 | continue 18 | line = { 19 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 20 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 21 | } 22 | if not line['triple_list']: 23 | continue 24 | spo_num = len(line['triple_list']) 25 | test_normal.append(line) 26 | 27 | with open('test_epo.json') as f: 28 | for l in tqdm(f): 29 | a = json.loads(l) 30 | if not a['relationMentions']: 31 | continue 32 | line = { 33 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 34 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 35 | } 36 | if not line['triple_list']: 37 | continue 38 | spo_num = len(line['triple_list']) 39 | test_epo.append(line) 40 | 41 | with open('test_seo.json') as f: 42 | for l in tqdm(f): 43 | a = json.loads(l) 44 | if not a['relationMentions']: 45 | continue 46 | line = { 47 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 48 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 49 | } 50 | if not line['triple_list']: 51 | continue 52 | spo_num = len(line['triple_list']) 53 | test_seo.append(line) 54 | 55 | with codecs.open('test_triples_normal.json', 'w', encoding='utf-8') as f: 56 | json.dump(test_normal, f, indent=4, ensure_ascii=False) 57 | 58 | with codecs.open('test_triples_epo.json', 'w', encoding='utf-8') as f: 59 | json.dump(test_epo, f, indent=4, ensure_ascii=False) 60 | 61 | with codecs.open('test_triples_seo.json', 'w', encoding='utf-8') as f: 62 | json.dump(test_seo, f, indent=4, ensure_ascii=False) 63 | 64 | -------------------------------------------------------------------------------- /data/WebNLG/README.md: -------------------------------------------------------------------------------- 1 | First, download and process the raw WebNLG data in dir raw_WebNLG/ 2 | 3 | Then run build_data.py to get triple files. 4 | 5 | For WebNLG, please train the model for at least 10 epochs. 6 | 7 | -------------------------------------------------------------------------------- /data/WebNLG/build_data.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | rel_set = set() 9 | 10 | 11 | train_data = [] 12 | dev_data = [] 13 | test_data = [] 14 | 15 | with open('train.json') as f: 16 | for l in tqdm(f): 17 | a = json.loads(l) 18 | if not a['relationMentions']: 19 | continue 20 | line = { 21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 23 | } 24 | if not line['triple_list']: 25 | continue 26 | train_data.append(line) 27 | for rm in a['relationMentions']: 28 | if rm['label'] != 'None': 29 | rel_set.add(rm['label']) 30 | 31 | 32 | with open('dev.json') as f: 33 | for l in tqdm(f): 34 | a = json.loads(l) 35 | if not a['relationMentions']: 36 | continue 37 | line = { 38 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 39 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 40 | } 41 | if not line['triple_list']: 42 | continue 43 | dev_data.append(line) 44 | for rm in a['relationMentions']: 45 | if rm['label'] != 'None': 46 | rel_set.add(rm['label']) 47 | 48 | 49 | with open('test.json') as f: 50 | for l in tqdm(f): 51 | a = json.loads(l) 52 | if not a['relationMentions']: 53 | continue 54 | line = { 55 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 56 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 57 | } 58 | if not line['triple_list']: 59 | continue 60 | test_data.append(line) 61 | 62 | 63 | 64 | id2rel = {i:j for i,j in enumerate(sorted(rel_set))} 65 | rel2id = {j:i for i,j in id2rel.items()} 66 | 67 | 68 | with codecs.open('rel2id.json', 'w', encoding='utf-8') as f: 69 | json.dump([id2rel, rel2id], f, indent=4, ensure_ascii=False) 70 | 71 | 72 | with codecs.open('train_triples.json', 'w', encoding='utf-8') as f: 73 | json.dump(train_data, f, indent=4, ensure_ascii=False) 74 | 75 | 76 | with codecs.open('dev_triples.json', 'w', encoding='utf-8') as f: 77 | json.dump(dev_data, f, indent=4, ensure_ascii=False) 78 | 79 | 80 | with codecs.open('test_triples.json', 'w', encoding='utf-8') as f: 81 | json.dump(test_data, f, indent=4, ensure_ascii=False) 82 | -------------------------------------------------------------------------------- /data/WebNLG/raw_WebNLG/README.md: -------------------------------------------------------------------------------- 1 | Download the CopyR's WebNLG dataset at https://drive.google.com/open?id=1zISxYa-8ROe2Zv8iRc82jY9QsQrfY1Vj 2 | 3 | Unzip it and all we need is webnlg/pre_processed_data/* 4 | 5 | Run generate.py to transfromer the numerical mentions to string mentions. Move the transformed files to ../ 6 | 7 | dev.json -> new_dev.json (then rename it test.json) 8 | train.json -> new_train.json (then rename it train.json) 9 | valid.json -> new_valid.json (then rename it dev.json) 10 | 11 | Also, the test data will be split by sentence types, i.e., Normal, SEO, and EPO. 
Then move the split files to ../test_split_by_type/ 12 | 13 | -------------------------------------------------------------------------------- /data/WebNLG/raw_WebNLG/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | def is_normal_triple(triples, is_relation_first=False): 5 | entities = set() 6 | for i, e in enumerate(triples): 7 | key = 0 if is_relation_first else 2 8 | if i % 3 != key: 9 | entities.add(e) 10 | return len(entities) == 2 * int(len(triples) / 3) 11 | 12 | def is_multi_label(triples, is_relation_first=False): 13 | if is_normal_triple(triples, is_relation_first): 14 | return False 15 | if is_relation_first: 16 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))] 17 | else: 18 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))] 19 | # if is multi label, then, at least one entity pair appeared more than once 20 | return len(entity_pair) != len(set(entity_pair)) 21 | 22 | def is_over_lapping(triples, is_relation_first=False): 23 | if is_normal_triple(triples, is_relation_first): 24 | return False 25 | if is_relation_first: 26 | entity_pair = [tuple(triples[3 * i + 1: 3 * i + 3]) for i in range(int(len(triples) / 3))] 27 | else: 28 | entity_pair = [tuple(triples[3 * i: 3 * i + 2]) for i in range(int(len(triples) / 3))] 29 | # remove the same entity_pair, then, if one entity appear more than once, it's overlapping 30 | entity_pair = set(entity_pair) 31 | entities = [] 32 | for pair in entity_pair: 33 | entities.extend(pair) 34 | entities = set(entities) 35 | return len(entities) != 2 * len(entity_pair) 36 | 37 | def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file): 38 | with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, open(epo_file, 'w') as f4, open(seo_file, 'w') as f5: 39 | cnt_normal = 0 40 | cnt_epo = 0 41 | cnt_seo = 0 42 | lines = f1.readlines() 43 | for line in lines: 44 | line = json.loads(line) 45 | print(len(line)) 46 | sents, spos = line[0], line[1] 47 | print(len(spos)) 48 | print(len(sents)) 49 | for i in range(len(sents)): 50 | new_line = dict() 51 | #print(sents[i]) 52 | #print(spos[i]) 53 | tokens = [word_dict[i] for i in sents[i]] 54 | sent = ' '.join(tokens) 55 | new_line['sentText'] = sent 56 | triples = np.reshape(spos[i], (-1,3)) 57 | relationMentions = [] 58 | for triple in triples: 59 | rel = dict() 60 | rel['em1Text'] = tokens[triple[0]] 61 | rel['em2Text'] = tokens[triple[1]] 62 | rel['label'] = rel_dict[triple[2]] 63 | relationMentions.append(rel) 64 | new_line['relationMentions'] = relationMentions 65 | f2.write(json.dumps(new_line) + '\n') 66 | if is_normal_triple(spos[i]): 67 | f3.write(json.dumps(new_line) + '\n') 68 | if is_multi_label(spos[i]): 69 | f4.write(json.dumps(new_line) + '\n') 70 | if is_over_lapping(spos[i]): 71 | f5.write(json.dumps(new_line) + '\n') 72 | 73 | if __name__ == '__main__': 74 | file_name = 'dev.json' 75 | output = 'new_dev.json' 76 | output_normal = 'new_dev_normal.json' 77 | output_epo = 'new_dev_epo.json' 78 | output_seo = 'new_dev_seo.json' 79 | with open('relations2id.json', 'r') as f1, open('words2id.json', 'r') as f2: 80 | rel2id = json.load(f1) 81 | words2id = json.load(f2) 82 | rel_dict = {j:i for i,j in rel2id.items()} 83 | word_dict = {j:i for i,j in words2id.items()} 84 | load_data(file_name, word_dict, rel_dict, output, output_normal, output_epo, output_seo) 85 | 
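To make the data flow concrete, here is roughly what one record looks like before and after build_data.py. The sentence, entity tokens, and relation label below are invented for illustration; note that generate.py keeps only the single head token of each entity mention:

```
# One line of new_*.json as written by generate.py (illustrative values only):
generated = {
    "sentText": "Allen Forrest was born in Dothan .",
    "relationMentions": [
        {"em1Text": "Forrest", "em2Text": "Dothan", "label": "birthPlace"},
    ],
}

# The matching entry that build_data.py writes to train/dev/test_triples.json:
converted = {
    "text": "Allen Forrest was born in Dothan .",
    "triple_list": [("Forrest", "birthPlace", "Dothan")],
}
```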
-------------------------------------------------------------------------------- /data/WebNLG/test_split_by_num/README.md: -------------------------------------------------------------------------------- 1 | Copy ../test.json to here. 2 | 3 | Run split_by_num.py to split the test dataset and store the split files in the form of triples. 4 | -------------------------------------------------------------------------------- /data/WebNLG/test_split_by_num/split_by_num.py: -------------------------------------------------------------------------------- 1 | #! -*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | 9 | test_1 = [] 10 | test_2 = [] 11 | test_3 = [] 12 | test_4 = [] 13 | test_other = [] 14 | 15 | with open('test.json') as f: 16 | for l in tqdm(f): 17 | a = json.loads(l) 18 | if not a['relationMentions']: 19 | continue 20 | line = { 21 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 22 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 23 | } 24 | if not line['triple_list']: 25 | continue 26 | spo_num = len(line['triple_list']) 27 | if spo_num == 1: 28 | test_1.append(line) 29 | elif spo_num == 2: 30 | test_2.append(line) 31 | elif spo_num == 3: 32 | test_3.append(line) 33 | elif spo_num == 4: 34 | test_4.append(line) 35 | else: 36 | test_other.append(line) 37 | 38 | 39 | with codecs.open('test_triples_1.json', 'w', encoding='utf-8') as f: 40 | json.dump(test_1, f, indent=4, ensure_ascii=False) 41 | 42 | with codecs.open('test_triples_2.json', 'w', encoding='utf-8') as f: 43 | json.dump(test_2, f, indent=4, ensure_ascii=False) 44 | 45 | with codecs.open('test_triples_3.json', 'w', encoding='utf-8') as f: 46 | json.dump(test_3, f, indent=4, ensure_ascii=False) 47 | 48 | with codecs.open('test_triples_4.json', 'w', encoding='utf-8') as f: 49 | json.dump(test_4, f, indent=4, ensure_ascii=False) 50 | 51 | with codecs.open('test_triples_5.json', 'w', encoding='utf-8') as f: 52 | json.dump(test_other, f, indent=4, ensure_ascii=False) 53 | -------------------------------------------------------------------------------- /data/WebNLG/test_split_by_type/README.md: -------------------------------------------------------------------------------- 1 | Move the split files from raw_WebNLG/ here. 2 | 3 | Then run generate_triples.py to get triple files. 4 | -------------------------------------------------------------------------------- /data/WebNLG/test_split_by_type/generate_triples.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding:utf-8 -*- 2 | 3 | 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | 8 | 9 | test_normal = [] 10 | test_epo = [] 11 | test_seo = [] 12 | 13 | with open('test_normal.json') as f: 14 | for l in tqdm(f): 15 | a = json.loads(l) 16 | if not a['relationMentions']: 17 | continue 18 | line = { 19 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 20 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 21 | } 22 | if not line['triple_list']: 23 | continue 24 | spo_num = len(line['triple_list']) 25 | test_normal.append(line) 26 | 27 | with open('test_epo.json') as f: 28 | for l in tqdm(f): 29 | a = json.loads(l) 30 | if not a['relationMentions']: 31 | continue 32 | line = { 33 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 34 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 35 | } 36 | if not line['triple_list']: 37 | continue 38 | spo_num = len(line['triple_list']) 39 | test_epo.append(line) 40 | 41 | with open('test_seo.json') as f: 42 | for l in tqdm(f): 43 | a = json.loads(l) 44 | if not a['relationMentions']: 45 | continue 46 | line = { 47 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 48 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 49 | } 50 | if not line['triple_list']: 51 | continue 52 | spo_num = len(line['triple_list']) 53 | test_seo.append(line) 54 | 55 | with codecs.open('test_triples_normal.json', 'w', encoding='utf-8') as f: 56 | json.dump(test_normal, f, indent=4, ensure_ascii=False) 57 | 58 | with codecs.open('test_triples_epo.json', 'w', encoding='utf-8') as f: 59 | json.dump(test_epo, f, indent=4, ensure_ascii=False) 60 | 61 | with codecs.open('test_triples_seo.json', 'w', encoding='utf-8') as f: 62 | json.dump(test_seo, f, indent=4, ensure_ascii=False) 63 | 64 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from random import choice 3 | from fastNLP import RandomSampler, DataSetIter, DataSet, TorchLoaderIter 4 | from fastNLP.io import JsonLoader 5 | from fastNLP import Vocabulary 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | from utils import get_tokenizer 10 | 11 | BERT_MAX_LEN = 512 12 | 13 | tokenizer = get_tokenizer('data/vocab.txt') 14 | 15 | 16 | def load_data(train_path, dev_path, test_path, rel_dict_path): 17 | """ 18 | Load dataset from origin dataset using fastNLP 19 | :param train_path: train data path 20 | :param dev_path: dev data path 21 | :param test_path: test data path 22 | :param rel_dict_path: relation dictionary path 23 | :return: data_bundle(contain train, dev, test data), rel_vocab(vocabulary of relations), num_rels(number of relations) 24 | """ 25 | paths = {'train': train_path, 'dev': dev_path, "test": test_path} 26 | loader = JsonLoader({"text": "text", "triple_list": "triple_list"}) 27 | data_bundle = loader.load(paths) 28 | print(data_bundle) 29 | 30 | id2rel, rel2id = json.load(open(rel_dict_path)) 31 | rel_vocab = Vocabulary(unknown=None, padding=None) 32 | rel_vocab.add_word_lst(list(id2rel.values())) 33 | print(rel_vocab) 34 | num_rels = len(id2rel) 35 | print("number of relations: " + str(num_rels)) 36 | 37 | return data_bundle, rel_vocab, num_rels 38 | 39 | 40 | def 
find_head_idx(source, target): 41 | """ 42 | Find the index of the head of target in source 43 | :param source: tokenizer list 44 | :param target: target subject or object 45 | :return: the index of the head of target in source, if not found, return -1 46 | """ 47 | target_len = len(target) 48 | for i in range(len(source)): 49 | if source[i: i + target_len] == target: 50 | return i 51 | return -1 52 | 53 | 54 | # def preprocess_data(config, dataset, rel_vocab, num_rels, is_test): 55 | # token_ids_list = [] 56 | # masks_list = [] 57 | # text_len_list = [] 58 | # sub_heads_list = [] 59 | # sub_tails_list = [] 60 | # sub_head_list = [] 61 | # sub_tail_list = [] 62 | # obj_heads_list = [] 63 | # obj_tails_list = [] 64 | # triple_list = [] 65 | # tokens_list = [] 66 | # for item in range(len(dataset)): 67 | # # get tokenizer list 68 | # ins_json_data = dataset[item] 69 | # text = ins_json_data['text'] 70 | # text = ' '.join(text.split()[:config.max_len]) 71 | # tokens = tokenizer.tokenize(text) 72 | # if len(tokens) > BERT_MAX_LEN: 73 | # tokens = tokens[: BERT_MAX_LEN] 74 | # text_len = len(tokens) 75 | # 76 | # if not is_test: 77 | # # build subject to relation_object map, s2ro_map[(sub_head, sub_tail)] = (obj_head, obj_tail, rel_idx) 78 | # s2ro_map = {} 79 | # for triple in ins_json_data['triple_list']: 80 | # triple = (tokenizer.tokenize(triple[0])[1:-1], triple[1], tokenizer.tokenize(triple[2])[1:-1]) 81 | # sub_head_idx = find_head_idx(tokens, triple[0]) 82 | # obj_head_idx = find_head_idx(tokens, triple[2]) 83 | # if sub_head_idx != -1 and obj_head_idx != -1: 84 | # sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1) 85 | # if sub not in s2ro_map: 86 | # s2ro_map[sub] = [] 87 | # s2ro_map[sub].append( 88 | # (obj_head_idx, obj_head_idx + len(triple[2]) - 1, rel_vocab.to_index(triple[1]))) 89 | # 90 | # if s2ro_map: 91 | # token_ids, segment_ids = tokenizer.encode(first=text) 92 | # masks = segment_ids 93 | # if len(token_ids) > text_len: 94 | # token_ids = token_ids[:text_len] 95 | # masks = masks[:text_len] 96 | # token_ids = np.array(token_ids) 97 | # masks = np.array(masks) + 1 98 | # # sub_heads[i]: if index i is the head of any subjects in text 99 | # sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len) 100 | # for s in s2ro_map: 101 | # sub_heads[s[0]] = 1 102 | # sub_tails[s[1]] = 1 103 | # # randomly select one subject in text and set sub_head and sub_tail 104 | # sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys())) 105 | # sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len) 106 | # sub_head[sub_head_idx] = 1 107 | # sub_tail[sub_tail_idx] = 1 108 | # obj_heads, obj_tails = np.zeros((text_len, num_rels)), np.zeros((text_len, num_rels)) 109 | # for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []): 110 | # obj_heads[ro[0]][ro[2]] = 1 111 | # obj_tails[ro[1]][ro[2]] = 1 112 | # token_ids_list.append(token_ids) 113 | # masks_list.append(masks) 114 | # text_len_list.append(text_len) 115 | # sub_heads_list.append(sub_heads) 116 | # sub_tails_list.append(sub_tails) 117 | # sub_head_list.append(sub_head) 118 | # sub_tail_list.append(sub_tail) 119 | # obj_heads_list.append(obj_heads) 120 | # obj_tails_list.append(obj_tails) 121 | # triple_list.append(ins_json_data['triple_list']) 122 | # tokens_list.append(tokens) 123 | # else: 124 | # token_ids, segment_ids = tokenizer.encode(first=text) 125 | # masks = segment_ids 126 | # if len(token_ids) > text_len: 127 | # token_ids = token_ids[:text_len] 128 | # masks = masks[:text_len] 129 | # token_ids = 
np.array(token_ids) 130 | # masks = np.array(masks) + 1 131 | # # initialize these variant with 0 132 | # sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len) 133 | # sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len) 134 | # obj_heads, obj_tails = np.zeros((text_len, num_rels)), np.zeros((text_len, num_rels)) 135 | # token_ids_list.append(token_ids) 136 | # masks_list.append(masks) 137 | # text_len_list.append(text_len) 138 | # sub_heads_list.append(sub_heads) 139 | # sub_tails_list.append(sub_tails) 140 | # sub_head_list.append(sub_head) 141 | # sub_tail_list.append(sub_tail) 142 | # obj_heads_list.append(obj_heads) 143 | # obj_tails_list.append(obj_tails) 144 | # triple_list.append(ins_json_data['triple_list']) 145 | # tokens_list.append(tokens) 146 | # 147 | # data_dict = {'token_ids': token_ids_list, 148 | # 'marks': masks_list, 149 | # 'text_len': text_len_list, 150 | # 'sub_heads': sub_heads_list, 151 | # 'sub_tails': sub_tails_list, 152 | # 'sub_head': sub_head_list, 153 | # 'sub_tail': sub_tail_list, 154 | # 'obj_heads': obj_heads_list, 155 | # 'obj_tails': obj_tails_list, 156 | # 'triple_list': triple_list, 157 | # 'tokens': tokens_list} 158 | # process_dataset = DataSet(data_dict) 159 | # process_dataset.set_input('token_ids', 'marks', 'text_len', 'sub_head', 'sub_tail', 'tokens') 160 | # process_dataset.set_target('sub_heads', 'sub_tails', 'obj_heads', 'obj_tails', 'triple_list') 161 | # return process_dataset 162 | 163 | 164 | class MyDataset(Dataset): 165 | def __init__(self, config, dataset, rel_vocab, num_rels, is_test): 166 | self.config = config 167 | self.dataset = dataset 168 | self.rel_vocab = rel_vocab 169 | self.num_rels = num_rels 170 | self.is_test = is_test 171 | self.tokenizer = tokenizer 172 | 173 | def __getitem__(self, item): 174 | """ 175 | The way of reading data, so that we can get data through MyDataset[i] 176 | :param item: index number 177 | :return: the item st text attribute in dataset 178 | """ 179 | # get tokenizer list 180 | ins_json_data = self.dataset[item] 181 | text = ins_json_data['text'] 182 | text = ' '.join(text.split()[:self.config.max_len]) 183 | tokens = self.tokenizer.tokenize(text) 184 | if len(tokens) > BERT_MAX_LEN: 185 | tokens = tokens[: BERT_MAX_LEN] 186 | text_len = len(tokens) 187 | 188 | if not self.is_test: 189 | # build subject to relation_object map, s2ro_map[(sub_head, sub_tail)] = (obj_head, obj_tail, rel_idx) 190 | s2ro_map = {} 191 | for triple in ins_json_data['triple_list']: 192 | triple = (self.tokenizer.tokenize(triple[0])[1:-1], triple[1], self.tokenizer.tokenize(triple[2])[1:-1]) 193 | sub_head_idx = find_head_idx(tokens, triple[0]) 194 | obj_head_idx = find_head_idx(tokens, triple[2]) 195 | if sub_head_idx != -1 and obj_head_idx != -1: 196 | sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1) 197 | if sub not in s2ro_map: 198 | s2ro_map[sub] = [] 199 | s2ro_map[sub].append( 200 | (obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel_vocab.to_index(triple[1]))) 201 | 202 | if s2ro_map: 203 | token_ids, segment_ids = self.tokenizer.encode(first=text) 204 | masks = segment_ids 205 | if len(token_ids) > text_len: 206 | token_ids = token_ids[:text_len] 207 | masks = masks[:text_len] 208 | token_ids = np.array(token_ids) 209 | masks = np.array(masks) + 1 210 | # sub_heads[i]: if index i is the head of any subjects in text 211 | sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len) 212 | for s in s2ro_map: 213 | sub_heads[s[0]] = 1 214 | sub_tails[s[1]] = 1 215 | # randomly select one 
subject in text and set sub_head and sub_tail 216 | sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys())) 217 | sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len) 218 | sub_head[sub_head_idx] = 1 219 | sub_tail[sub_tail_idx] = 1 220 | obj_heads, obj_tails = np.zeros((text_len, self.num_rels)), np.zeros((text_len, self.num_rels)) 221 | for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []): 222 | obj_heads[ro[0]][ro[2]] = 1 223 | obj_tails[ro[1]][ro[2]] = 1 224 | return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, \ 225 | ins_json_data['triple_list'], tokens 226 | else: 227 | return None 228 | else: 229 | token_ids, segment_ids = self.tokenizer.encode(first=text) 230 | masks = segment_ids 231 | if len(token_ids) > text_len: 232 | token_ids = token_ids[:text_len] 233 | masks = masks[:text_len] 234 | token_ids = np.array(token_ids) 235 | masks = np.array(masks) + 1 236 | # initialize these variant with 0 237 | sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len) 238 | sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len) 239 | obj_heads, obj_tails = np.zeros((text_len, self.num_rels)), np.zeros((text_len, self.num_rels)) 240 | return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, \ 241 | ins_json_data['triple_list'], tokens 242 | 243 | def __len__(self): 244 | return len(self.dataset) 245 | 246 | 247 | def my_collate_fn(batch): 248 | """ 249 | Merge data in one batch 250 | :param batch: the batch size 251 | :return: a dictionary 252 | """ 253 | batch = list(filter(lambda x: x is not None, batch)) 254 | batch.sort(key=lambda x: x[2], reverse=True) 255 | token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, triples, tokens = zip( 256 | *batch) 257 | cur_batch = len(batch) 258 | max_text_len = max(text_len) 259 | batch_token_ids = torch.LongTensor(cur_batch, max_text_len).zero_() 260 | batch_masks = torch.LongTensor(cur_batch, max_text_len).zero_() 261 | batch_sub_heads = torch.Tensor(cur_batch, max_text_len).zero_() 262 | batch_sub_tails = torch.Tensor(cur_batch, max_text_len).zero_() 263 | batch_sub_head = torch.Tensor(cur_batch, max_text_len).zero_() 264 | batch_sub_tail = torch.Tensor(cur_batch, max_text_len).zero_() 265 | batch_obj_heads = torch.Tensor(cur_batch, max_text_len, 24).zero_() 266 | batch_obj_tails = torch.Tensor(cur_batch, max_text_len, 24).zero_() 267 | 268 | for i in range(cur_batch): 269 | batch_token_ids[i, :text_len[i]].copy_(torch.from_numpy(token_ids[i])) 270 | batch_masks[i, :text_len[i]].copy_(torch.from_numpy(masks[i])) 271 | batch_sub_heads[i, :text_len[i]].copy_(torch.from_numpy(sub_heads[i])) 272 | batch_sub_tails[i, :text_len[i]].copy_(torch.from_numpy(sub_tails[i])) 273 | batch_sub_head[i, :text_len[i]].copy_(torch.from_numpy(sub_head[i])) 274 | batch_sub_tail[i, :text_len[i]].copy_(torch.from_numpy(sub_tail[i])) 275 | batch_obj_heads[i, :text_len[i], :].copy_(torch.from_numpy(obj_heads[i])) 276 | batch_obj_tails[i, :text_len[i], :].copy_(torch.from_numpy(obj_tails[i])) 277 | 278 | return {'token_ids': batch_token_ids, 279 | 'mask': batch_masks, 280 | 'sub_head': batch_sub_head, 281 | 'sub_tail': batch_sub_tail, 282 | 'tokens': tokens}, \ 283 | {'sub_heads': batch_sub_heads, 284 | 'sub_tails': batch_sub_tails, 285 | 'obj_heads': batch_obj_heads, 286 | 'obj_tails': batch_obj_tails, 287 | 'triples': triples, 288 | } 289 | 290 | 291 | def get_data_iter(config, dataset, rel_vocab, num_rels, is_test=False, 
num_workers=0, collate_fn=my_collate_fn): 292 | """ 293 | Build a data Iterator that combines a dataset and a sampler, and provides single- or multi-process iterators 294 | over the dataset. 295 | :param config: configuration 296 | :param dataset: certain dataset in data bundle processed by fastNLP 297 | :param rel_vocab: vocabulary of relations 298 | :param num_rels: the number of relations 299 | :param is_test: if not test, use RandomSampler; if test, use SequentialSampler 300 | :param num_workers: how many subprocesses to use for data loading 301 | :param collate_fn: merges a list of samples to form a mini-batch 302 | :return: a dataloader 303 | """ 304 | dataset = MyDataset(config, dataset, rel_vocab, num_rels, is_test) 305 | # dataset = preprocess_data(config, dataset, rel_vocab, num_rels, is_test) 306 | # print(dataset) 307 | if not is_test: 308 | sampler = RandomSampler() 309 | data_iter = TorchLoaderIter(dataset=dataset, 310 | collate_fn=collate_fn, 311 | batch_size=config.batch_size, 312 | num_workers=num_workers, 313 | pin_memory=True) 314 | # data_iter = DataSetIter(dataset=dataset, 315 | # batch_size=config.batch_size, 316 | # num_workers=num_workers, 317 | # sampler=sampler, 318 | # pin_memory=True) 319 | else: 320 | data_iter = TorchLoaderIter(dataset=dataset, 321 | collate_fn=collate_fn, 322 | batch_size=1, 323 | num_workers=num_workers, 324 | pin_memory=True) 325 | # data_iter = DataSetIter(dataset=dataset, 326 | # batch_size=1, 327 | # num_workers=num_workers, 328 | # pin_memory=True) 329 | return data_iter 330 | 331 | # class DataPreFetcher(object): 332 | # def __init__(self, loader): 333 | # self.loader = iter(loader) 334 | # # self.stream = torch.cuda.Stream() 335 | # self.preload() 336 | # 337 | # def preload(self): 338 | # try: 339 | # self.next_data = next(self.loader) 340 | # except StopIteration: 341 | # self.next_data = None 342 | # return 343 | # # with torch.cuda.stream(self.stream): 344 | # # for k, v in self.next_data.items(): 345 | # # if isinstance(v, torch.Tensor): 346 | # # self.next_data[k] = self.next_data[k].cuda(non_blocking=True) 347 | # 348 | # def next(self): 349 | # # torch.cuda.current_stream().wait_stream(self.stream) 350 | # data = self.next_data 351 | # self.preload() 352 | # return data 353 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from fastNLP import Callback 7 | from transformers import * 8 | from utils import metric 9 | 10 | 11 | class CasRel(nn.Module): 12 | def __init__(self, config, num_rels): 13 | super().__init__() 14 | self.config = config 15 | self.bert_encoder = BertModel.from_pretrained(self.config.bert_model_name) 16 | self.bert_dim = 768 # the size of hidden state 17 | self.sub_heads_linear = nn.Linear(self.bert_dim, 1) 18 | self.sub_tails_linear = nn.Linear(self.bert_dim, 1) 19 | self.obj_heads_linear = nn.Linear(self.bert_dim, num_rels) 20 | self.obj_tails_linear = nn.Linear(self.bert_dim, num_rels) 21 | 22 | def get_encoded_text(self, token_ids, mask): 23 | # [batch_size, seq_len, bert_dim(768)] 24 | encoded_text = self.bert_encoder(token_ids, attention_mask=mask)[0] 25 | return encoded_text 26 | 27 | def get_subs(self, encoded_text): 28 | """ 29 | Subject Taggers 30 | :param encoded_text: input sentenced pretrained with BERT 31 | :return: predicted subject head, predicted object head 
32 | """ 33 | # [batch_size, seq_len, 1] 34 | pred_sub_heads = self.sub_heads_linear(encoded_text) 35 | pred_sub_heads = torch.sigmoid(pred_sub_heads) 36 | # [batch_size, seq_len, 1] 37 | pred_sub_tails = self.sub_tails_linear(encoded_text) 38 | pred_sub_tails = torch.sigmoid(pred_sub_tails) 39 | return pred_sub_heads, pred_sub_tails 40 | 41 | def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text): 42 | """ 43 | Relation-specific Object Taggers 44 | :param sub_head_mapping: 45 | :param sub_tail_mapping: 46 | :param encoded_text: input sentenced pretrained with BERT 47 | :return: predicted object head, predicted object tail 48 | """ 49 | # [batch_size, 1, bert_dim] 50 | sub_head = torch.matmul(sub_head_mapping, encoded_text) 51 | # [batch_size, 1, bert_dim] 52 | sub_tail = torch.matmul(sub_tail_mapping, encoded_text) 53 | # [batch_size, 1 bert_dim] 54 | sub = (sub_head + sub_tail) / 2 55 | # [batch_size, seq_len, bert_dim] 56 | encoded_text = encoded_text + sub 57 | # [batch_size, seq_len, rel_num] 58 | pred_obj_heads = self.obj_heads_linear(encoded_text) 59 | pred_obj_heads = torch.sigmoid(pred_obj_heads) 60 | pred_obj_tails = self.obj_tails_linear(encoded_text) 61 | pred_obj_tails = torch.sigmoid(pred_obj_tails) 62 | return pred_obj_heads, pred_obj_tails 63 | 64 | def forward(self, data): 65 | # [batch_size, seq_len] 66 | token_ids = data['token_ids'] 67 | # [batch_size, seq_len] 68 | mask = data['mask'] 69 | # [batch_size, seq_len, bert_dim(768)] 70 | encoded_text = self.get_encoded_text(token_ids, mask) 71 | # [batch_size, seq_len, 1] 72 | pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text) 73 | # [batch_size, 1, seq_len] 74 | sub_head_mapping = data['sub_head'].unsqueeze(1) 75 | # [batch_size, 1, seq_len] 76 | sub_tail_mapping = data['sub_tail'].unsqueeze(1) 77 | # [batch_size, seq_len, rel_num] 78 | pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, 79 | encoded_text) 80 | return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails 81 | 82 | 83 | class MyCallBack(Callback): 84 | def __init__(self, data_iter, rel_vocab, config): 85 | super().__init__() 86 | self.loss_sum = 0 87 | self.global_step = 0 88 | 89 | self.data_iter = data_iter 90 | self.rel_vocab = rel_vocab 91 | self.config = config 92 | 93 | def logging(self, s, print_=True, log_=True): 94 | if print_: 95 | print(s) 96 | if log_: 97 | with open(os.path.join(self.config.save_logs_dir, self.config.log_save_name), 'a+') as f_log: 98 | f_log.write(s + '\n') 99 | 100 | # define the loss function 101 | def loss(self, pred, gold, mask): 102 | pred = pred.squeeze(-1) 103 | los = F.binary_cross_entropy(pred, gold, reduction='none') 104 | if los.shape != mask.shape: 105 | mask = mask.unsqueeze(-1) 106 | los = torch.sum(los * mask) / torch.sum(mask) 107 | return los 108 | 109 | def on_train_begin(self): 110 | self.best_f1_score = 0 111 | self.best_precision = 0 112 | self.best_recall = 0 113 | 114 | self.best_epoch = 0 115 | self.init_time = time.time() 116 | self.start_time = time.time() 117 | print("-" * 5 + "Initializing the model" + "-" * 5) 118 | 119 | def on_epoch_begin(self): 120 | self.eval_start_time = time.time() 121 | 122 | def on_batch_begin(self, batch_x, batch_y, indices): 123 | pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = self.model(batch_x) 124 | sub_heads_loss = self.loss(batch_y['sub_heads'], pred_sub_heads, batch_x['mask']) 125 | sub_tails_loss = self.loss(batch_y['sub_tails'], pred_sub_tails, 
batch_x['mask']) 126 | obj_heads_loss = self.loss(batch_y['obj_heads'], pred_obj_heads, batch_x['mask']) 127 | obj_tails_loss = self.loss(batch_y['obj_tails'], pred_obj_tails, batch_x['mask']) 128 | total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss) 129 | 130 | self.optimizer.zero_grad() 131 | total_loss.backward() 132 | self.optimizer.step() 133 | 134 | self.loss_sum += total_loss.item() 135 | 136 | def on_epoch_end(self): 137 | precision, recall, f1_score = metric(self.data_iter, self.rel_vocab, self.config, self.model) 138 | self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'. 139 | format(self.epoch, time.time() - self.eval_start_time, f1_score, precision, recall)) 140 | if f1_score > self.best_f1_score: 141 | self.best_f1_score = f1_score 142 | self.best_epoch = self.epoch 143 | self.best_precision = precision 144 | self.best_recall = recall 145 | self.logging("Saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}". 146 | format(self.best_epoch, self.best_f1_score, precision, recall)) 147 | # save the best model 148 | path = os.path.join(self.config.save_weights_dir, self.config.weights_save_name) 149 | torch.save(self.model.state_dict(), path) 150 | 151 | def on_train_end(self): 152 | self.logging("-" * 5 + "Finish training" + "-" * 5) 153 | self.logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2}, total time: {:5.2f}s". 154 | format(self.best_epoch, self.best_f1_score, self.best_precision, self.best_recall, 155 | time.time() - self.init_time)) 156 | -------------------------------------------------------------------------------- /pretrained_bert_models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/pretrained_bert_models/.DS_Store -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from fastNLP import Trainer, LossFunc 2 | from data_loader import load_data 3 | from model import CasRel 4 | from utils import get_tokenizer 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='Model Controller') 8 | parser.add_argument('--train', default=False, type=bool, help='to train the HBT model, python run.py --train=True') 9 | parser.add_argument('--dataset', default='WebNLG', type=str, 10 | help='specify the dataset from ["NYT","WebNLG"]') 11 | args = parser.parse_args() 12 | 13 | if __name__ == '__main__': 14 | # pre-trained bert model name 15 | bert_model_name = 'en-base-cased' 16 | 17 | vocab_path = 'data/vocab.txt' 18 | # load dataset 19 | # dataset = args.dataset 20 | dataset = "NYT" 21 | train_path = 'data/' + dataset + '/train_triples.json' 22 | dev_path = 'data/' + dataset + '/dev_triples.json' 23 | test_path = 'data/' + dataset + '/test_triples.json' # overall test 24 | # test_path = 'data/' + dataset + '/test_split_by_num/test_triples_5.json' # ['1','2','3','4','5'] 25 | # test_path = 'data/' + dataset + '/test_split_by_type/test_triples_seo.json' # ['normal', 'seo', 'epo'] 26 | rel_dict_path = 'data/' + dataset + '/rel2id.json' 27 | save_weights_path = 'saved_weights/' + dataset + '/best_model.weights' 28 | save_logs_path = '/saved_logs/' + dataset + '/log' 29 | 30 | # parameters 31 | LR = 1e-5 32 | tokenizer = get_tokenizer(vocab_path) 33 | data_bundle, rel_vocab, 
num_rels = load_data(train_path, dev_path, test_path, rel_dict_path) 34 | model = CasRel(bert_model_name, num_rels) 35 | 36 | if args.train: 37 | BATCH_SIZE = 6 38 | EPOCH = 100 39 | MAX_LEN = 100 40 | STEPS = len(data_bundle.get_dataset('train')) // BATCH_SIZE 41 | metric = SpanFPreRecMetric() 42 | optimizer = Adam(lr=LR) 43 | trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, batch_size=BATCH_SIZE, 44 | metrics=metric) 45 | -------------------------------------------------------------------------------- /saved_logs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_logs/.DS_Store -------------------------------------------------------------------------------- /saved_logs/NYT/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_logs/NYT/.DS_Store -------------------------------------------------------------------------------- /saved_weights/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_weights/.DS_Store -------------------------------------------------------------------------------- /saved_weights/NYT/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yusun-nlp/CasRel_fastNLP/eb194941c136323ba8ac92e65757041dc3914dad/saved_weights/NYT/.DS_Store -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from fastNLP import Trainer 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import config 12 | from data_loader import load_data, get_data_iter 13 | from model import CasRel, MyCallBack 14 | from utils import metric 15 | 16 | seed = 1234 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | np.random.seed(seed) 20 | random.seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | parser = argparse.ArgumentParser(description='Model Controller') 25 | parser.add_argument('--model_name', type=str, default='CasRel', help='name of the model') 26 | parser.add_argument('--dataset', type=str, default='NYT', help='specify the dataset from ["NYT","WebNLG"]') 27 | parser.add_argument('--bert_model_name', type=str, default='bert-base-cased', help='the name of pretrained bert model') 28 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate') 29 | parser.add_argument('--multi_gpu', type=bool, default=False, help='if use multiple gpu') 30 | parser.add_argument('--batch_size', type=int, default=6) 31 | parser.add_argument('--max_epoch', type=int, default=200) 32 | parser.add_argument('--test_epoch', type=int, default=1) 33 | parser.add_argument('--max_len', type=int, default=150) 34 | parser.add_argument('--period', type=int, default=50) 35 | args = parser.parse_args() 36 | 37 | con = config.Config(args) 38 | 39 | # get the data and dataloader 40 | print("-" * 5 + "Starting processing data" + "-" * 5) 41 | 
41 | data_bundle, rel_vocab, num_rels = load_data(con.train_path, con.dev_path, con.test_path, con.rel_dict_path)
42 | print("Test process data:")
43 | test_data_iter = get_data_iter(con, data_bundle.get_dataset('test'), rel_vocab, num_rels, is_test=True)
44 | print("-" * 5 + "Data processing done" + "-" * 5)
45 | 
46 | # check the checkpoint dir
47 | if not os.path.exists(con.save_weights_dir):
48 |     os.mkdir(con.save_weights_dir)
49 | # check the log dir
50 | if not os.path.exists(con.save_logs_dir):
51 |     os.mkdir(con.save_logs_dir)
52 | 
53 | 
54 | def test():
55 |     model = CasRel(con, num_rels)
56 |     path = os.path.join(con.save_weights_dir, con.weights_save_name)
57 |     model.load_state_dict(torch.load(path))
58 |     model.cuda()
59 |     model.eval()
60 |     precision, recall, f1_score = metric(test_data_iter, rel_vocab, con, model, True)
61 |     print("f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".format(f1_score, precision, recall))
62 | 
63 | 
64 | test()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import random
  4 | import time
  5 | from fastNLP import Trainer
  6 | import numpy as np
  7 | import torch
  8 | import torch.nn as nn
  9 | import torch.nn.functional as F
 10 | import torch.optim as optim
 11 | import config
 12 | from data_loader import load_data, get_data_iter
 13 | from model import CasRel, MyCallBack
 14 | from utils import metric
 15 | 
 16 | seed = 1234
 17 | torch.manual_seed(seed)
 18 | torch.cuda.manual_seed(seed)
 19 | np.random.seed(seed)
 20 | random.seed(seed)
 21 | torch.backends.cudnn.deterministic = True
 22 | torch.backends.cudnn.benchmark = False
 23 | 
 24 | parser = argparse.ArgumentParser(description='Model Controller')
 25 | parser.add_argument('--model_name', type=str, default='CasRel', help='name of the model')
 26 | parser.add_argument('--dataset', type=str, default='NYT', help='specify the dataset from ["NYT","WebNLG"]')
 27 | parser.add_argument('--bert_model_name', type=str, default='bert-base-cased', help='the name of pretrained bert model')
 28 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
 29 | parser.add_argument('--multi_gpu', type=bool, default=False, help='if use multiple gpu')
 30 | parser.add_argument('--batch_size', type=int, default=6)
 31 | parser.add_argument('--max_epoch', type=int, default=300)
 32 | parser.add_argument('--test_epoch', type=int, default=5)
 33 | parser.add_argument('--max_len', type=int, default=150)
 34 | parser.add_argument('--period', type=int, default=50)
 35 | args = parser.parse_args()
 36 | 
 37 | con = config.Config(args)
 38 | 
 39 | # get the data and dataloader
 40 | print("-" * 5 + "Starting processing data" + "-" * 5)
 41 | data_bundle, rel_vocab, num_rels = load_data(con.train_path, con.dev_path, con.test_path, con.rel_dict_path)
 42 | print("Train process data:")
 43 | train_data_iter = get_data_iter(con, data_bundle.get_dataset('train'), rel_vocab, num_rels)
 44 | print("Dev process data:")
 45 | dev_data_iter = get_data_iter(con, data_bundle.get_dataset('dev'), rel_vocab, num_rels, is_test=True)
 46 | print("-" * 5 + "Data processing done" + "-" * 5)
 47 | 
 48 | # check the checkpoint dir
 49 | if not os.path.exists(con.save_weights_dir):
 50 |     os.mkdir(con.save_weights_dir)
 51 | # check the log dir
 52 | if not os.path.exists(con.save_logs_dir):
 53 |     os.mkdir(con.save_logs_dir)
 54 | 
 55 | 
 56 | # define the loss function
 57 | def loss(pred, gold, mask):
 58 |     pred = pred.squeeze(-1)
 59 |     los = F.binary_cross_entropy(pred, gold, reduction='none')
 60 |     if los.shape != mask.shape:
 61 |         mask = mask.unsqueeze(-1)
 62 |     los = torch.sum(los * mask) / torch.sum(mask)
 63 |     return los
 64 | 
 65 | 
 66 | def logging(s, print_=True, log_=True):
 67 |     if print_:
 68 |         print(s)
 69 |     if log_:
 70 |         with open(os.path.join(con.save_logs_dir, con.log_save_name), 'a+') as f_log:
 71 |             f_log.write(s + '\n')
 72 | 
 73 | 
 74 | def train():
 75 |     # initialize the model
 76 |     print("-" * 5 + "Initializing the model" + "-" * 5)
 77 |     model = CasRel(con, num_rels)
 78 |     # model.cuda()
 79 | 
 80 |     # whether use multi GPU
 81 |     if con.multi_gpu:
 82 |         model = nn.DataParallel(model)
 83 |     model.train()
 84 | 
 85 |     # define the optimizer
 86 |     optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.learning_rate)
 87 | 
 88 |     # other
 89 |     global_step = 0
 90 |     loss_sum = 0
 91 | 
 92 |     best_f1_score = 0.0
 93 |     best_precision = 0.0
 94 |     best_recall = 0.0
 95 | 
 96 |     best_epoch = 0
 97 |     init_time = time.time()
 98 |     start_time = time.time()
 99 | 
100 |     # the training loop
101 |     print("-" * 5 + "Start training" + "-" * 5)
102 |     for epoch in range(con.max_epoch):
103 |         for batch_x, batch_y in train_data_iter:
104 |             # if global_step == 20:  # debug-only early stop; disabled so training runs for full epochs
105 |             #     break
106 |             pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(batch_x)
107 |             sub_heads_loss = loss(pred_sub_heads, batch_y['sub_heads'], batch_x['mask'])
108 |             sub_tails_loss = loss(pred_sub_tails, batch_y['sub_tails'], batch_x['mask'])
109 |             obj_heads_loss = loss(pred_obj_heads, batch_y['obj_heads'], batch_x['mask'])
110 |             obj_tails_loss = loss(pred_obj_tails, batch_y['obj_tails'], batch_x['mask'])
111 |             total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)
112 | 
113 |             optimizer.zero_grad()
114 |             total_loss.backward()
115 |             optimizer.step()
116 | 
117 |             global_step += 1
118 |             loss_sum += total_loss.item()
119 | 
120 |             if global_step % con.period == 0:
121 |                 cur_loss = loss_sum / con.period
122 |                 elapsed = time.time() - start_time
123 |                 logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
124 |                         format(epoch, global_step, elapsed * 1000 / con.period, cur_loss))
125 |                 loss_sum = 0
126 |                 start_time = time.time()
127 | 
128 |         if (epoch + 1) % con.test_epoch == 0:
129 |             eval_start_time = time.time()
130 |             model.eval()
131 |             # call the test function
132 |             precision, recall, f1_score = metric(dev_data_iter, rel_vocab, con, model)
133 |             model.train()
134 |             logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
135 |                     format(epoch, time.time() - eval_start_time, f1_score, precision, recall))
136 | 
137 |             if f1_score > best_f1_score:
138 |                 best_f1_score = f1_score
139 |                 best_epoch = epoch
140 |                 best_precision = precision
141 |                 best_recall = recall
142 |                 logging("Saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
143 |                         format(best_epoch, best_f1_score, precision, recall))
144 |                 # save the best model
145 |                 path = os.path.join(con.save_weights_dir, con.weights_save_name)
146 |                 torch.save(model.state_dict(), path)
147 | 
148 |         # manually release the unused cache
149 |         # torch.cuda.empty_cache()
150 | 
151 |     logging("-" * 5 + "Finish training" + "-" * 5)
152 |     logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}, total time: {:5.2f}s".
153 |             format(best_epoch, best_f1_score, best_precision, best_recall, time.time() - init_time))
154 | 
155 | 
156 | train()
157 | 
158 | # model = CasRel(con, num_rels)
159 | # # model.cuda()
160 | #
161 | # # define the optimizer
162 | # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.learning_rate)
163 | #
164 | # Trainer(train_data_iter, model, optimizer, batch_size=con.batch_size, n_epochs=con.max_epoch, print_every=con.period,
165 | #         callbacks=[MyCallBack(dev_data_iter, rel_vocab, con)])
166 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import os
  3 | import unicodedata
  4 | import codecs
  5 | from keras_bert import Tokenizer
  6 | import numpy as np
  7 | import torch
  8 | 
  9 | 
 10 | class HBTokenizer(Tokenizer):
 11 |     def _tokenize(self, text):
 12 |         if not self._cased:
 13 |             text = unicodedata.normalize('NFD', text)
 14 |             text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
 15 |             text = text.lower()
 16 |         spaced = ''
 17 |         for ch in text:
 18 |             if ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
 19 |                 continue
 20 |             else:
 21 |                 spaced += ch
 22 |         tokens = []
 23 |         for word in spaced.strip().split():
 24 |             tokens += self._word_piece_tokenize(word)
 25 |             tokens.append('[unused1]')
 26 |         return tokens
 27 | 
 28 | 
 29 | def get_tokenizer(vocab_path):
 30 |     token_dict = {}
 31 |     with codecs.open(vocab_path, 'r', 'utf8') as reader:
 32 |         for line in reader:
 33 |             token = line.strip()
 34 |             token_dict[token] = len(token_dict)
 35 |     return HBTokenizer(token_dict, cased=True)
 36 | 
 37 | 
 38 | def to_tup(triple_list):
 39 |     ret = []
 40 |     for triple in triple_list:
 41 |         ret.append(tuple(triple))
 42 |     return ret
 43 | 
 44 | 
 45 | def metric(data_iter, rel_vocab, config, model, output=False, h_bar=0.5, t_bar=0.5):
 46 |     if output:
 47 |         # check the result dir
 48 |         if not os.path.exists(config.result_dir):
 49 |             os.mkdir(config.result_dir)
 50 |         path = os.path.join(config.result_dir, config.result_save_name)
 51 |         fw = open(path, 'w')
 52 | 
 53 |     orders = ['subject', 'relation', 'object']
 54 |     correct_num, predict_num, gold_num = 0, 0, 0
 55 | 
 56 |     for batch_x, batch_y in data_iter:
 57 |         with torch.no_grad():
 58 |             token_ids = batch_x['token_ids']
 59 |             tokens = batch_x['tokens']
 60 |             mask = batch_x['mask']
 61 |             encoded_text = model.get_encoded_text(token_ids, mask)
 62 |             pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
 63 |             sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], \
 64 |                                    np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
 65 |             subjects = []
 66 |             for sub_head in sub_heads:
 67 |                 sub_tail = sub_tails[sub_tails >= sub_head]
 68 |                 if len(sub_tail) > 0:
 69 |                     sub_tail = sub_tail[0]
 70 |                     subject = tokens[sub_head: sub_tail]
 71 |                     subjects.append((subject, sub_head, sub_tail))
 72 | 
 73 |             if subjects:
 74 |                 triple_list = []
 75 |                 # [subject_num, seq_len, bert_dim]
 76 |                 repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
 77 |                 # [subject_num, 1, seq_len]
 78 |                 sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
 79 |                 sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
 80 |                 for subject_idx, subject in enumerate(subjects):
 81 |                     sub_head_mapping[subject_idx][0][subject[1]] = 1
 82 |                     sub_tail_mapping[subject_idx][0][subject[2]] = 1
 83 |                 sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
 84 |                 sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
 85 |                 pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping,
 86 |                                                                                  repeated_encoded_text)
 87 |                 for subject_idx, subject in enumerate(subjects):
 88 |                     sub = subject[0]
 89 |                     sub = ''.join([i.lstrip("##") for i in sub])
 90 |                     sub = ' '.join(sub.split('[unused1]'))
 91 |                     obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(
 92 |                         pred_obj_tails.cpu()[subject_idx] > t_bar)
 93 |                     for obj_head, rel_head in zip(*obj_heads):
 94 |                         for obj_tail, rel_tail in zip(*obj_tails):
 95 |                             if obj_head <= obj_tail and rel_head == rel_tail:
 96 |                                 rel = rel_vocab.to_word[int(rel_head)]
 97 |                                 obj = tokens[obj_head: obj_tail]
 98 |                                 obj = ''.join([i.lstrip("##") for i in obj])
 99 |                                 obj = ' '.join(obj.split('[unused1]'))
100 |                                 triple_list.append((sub, rel, obj))
101 |                                 break
102 |                 triple_set = set()
103 |                 for s, r, o in triple_list:
104 |                     triple_set.add((s, r, o))
105 |                 pred_list = list(triple_set)
106 |             else:
107 |                 pred_list = []
108 | 
109 |         pred_triples = set(pred_list)
110 |         gold_triples = set(to_tup(batch_y['triples'][0]))
111 | 
112 |         correct_num += len(pred_triples & gold_triples)
113 |         predict_num += len(pred_triples)
114 |         gold_num += len(gold_triples)
115 | 
116 |         if output:
117 |             result = json.dumps({
118 |                 # 'text': ' '.join(tokens),
119 |                 'triple_list_gold': [
120 |                     dict(zip(orders, triple)) for triple in gold_triples
121 |                 ],
122 |                 'triple_list_pred': [
123 |                     dict(zip(orders, triple)) for triple in pred_triples
124 |                 ],
125 |                 'new': [
126 |                     dict(zip(orders, triple)) for triple in pred_triples - gold_triples
127 |                 ],
128 |                 'lack': [
129 |                     dict(zip(orders, triple)) for triple in gold_triples - pred_triples
130 |                 ]
131 |             }, ensure_ascii=False)
132 |             fw.write(result + '\n')
133 | 
134 |     print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))
135 | 
136 |     precision = correct_num / (predict_num + 1e-10)
137 |     recall = correct_num / (gold_num + 1e-10)
138 |     f1_score = 2 * precision * recall / (precision + recall + 1e-10)
139 |     return precision, recall, f1_score
140 | 
--------------------------------------------------------------------------------
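
The subject-decoding step inside `utils.py`'s `metric()` is the densest part of the evaluation code, so here is a minimal, self-contained sketch of just that step. It is not part of the repository: the probability arrays below are toy values invented for illustration, and only the thresholding and head-tail pairing mirror the logic in `metric()` (every token position whose score exceeds `h_bar`/`t_bar` is a candidate head/tail, and each head is paired with the nearest tail at or after it).

```
import numpy as np

# Toy per-token subject head/tail probabilities (illustrative values only).
h_bar, t_bar = 0.5, 0.5
pred_sub_heads = np.array([0.1, 0.9, 0.2, 0.7, 0.1])
pred_sub_tails = np.array([0.1, 0.2, 0.8, 0.1, 0.9])

# Same decoding as in metric(): threshold, then pair each head with the
# closest tail that does not precede it.
sub_heads = np.where(pred_sub_heads > h_bar)[0]   # positions [1, 3]
sub_tails = np.where(pred_sub_tails > t_bar)[0]   # positions [2, 4]

subjects = []
for sub_head in sub_heads:
    candidates = sub_tails[sub_tails >= sub_head]  # tails not before this head
    if len(candidates) > 0:
        subjects.append((int(sub_head), int(candidates[0])))

print(subjects)   # [(1, 2), (3, 4)] -> two decoded subject spans
```

The same pattern is applied per relation for object heads and tails, which is why `metric()` iterates over `zip(*obj_heads)` and `zip(*obj_tails)` and keeps a pair only when the head and tail agree on the relation index.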