├── LICENSE
├── PLM
│   ├── bert-base-uncased
│   │   └── .placeholder
│   └── bert-large-uncased
│       └── .placeholder
├── README.md
├── code
│   ├── checkpoint
│   │   └── .placeholder
│   ├── config.py
│   ├── data.py
│   ├── eval_GAIN_BERT.sh
│   ├── eval_GAIN_GloVe.sh
│   ├── fig_result
│   │   └── .placeholder
│   ├── logs
│   │   └── .placeholder
│   ├── models
│   │   ├── GAIN.py
│   │   └── GAIN_nomention.py
│   ├── run_GAIN_BERT.sh
│   ├── run_GAIN_GloVe.sh
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── data
│   ├── README.md
│   └── prepro_data
│       └── .placeholder
├── pictures
│   ├── model.png
│   └── results.png
└── test_result_jsons
    └── test_result_62.76_F1.json


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Shuang Zeng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/PLM/bert-base-uncased/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/PLM/bert-base-uncased/.placeholder
--------------------------------------------------------------------------------
/PLM/bert-large-uncased/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/PLM/bert-large-uncased/.placeholder
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Double Graph Based Reasoning for Document-level Relation Extraction
Source code for the EMNLP 2020 paper: [Double Graph Based Reasoning for Document-level Relation Extraction](https://arxiv.org/abs/2009.13752)

> Document-level relation extraction aims to extract relations among entities within a document. Different from sentence-level relation extraction, it requires reasoning over multiple sentences across a document. In this paper, we propose Graph Aggregation-and-Inference Network (GAIN) featuring double graphs. GAIN first constructs a heterogeneous mention-level graph (hMG) to model complex interaction among different mentions across the document. It also constructs an entity-level graph (EG), based on which we propose a novel path reasoning mechanism to infer relations between entities. Experiments on the public dataset, DocRED, show GAIN achieves a significant performance improvement (2.85 on F1) over the previous state-of-the-art.

+ Architecture
![model overview](pictures/model.png)

+ Overall Results
![results](pictures/results.png)
## 0. Package Description
```
GAIN/
├─ code/
    ├── checkpoint/: save model checkpoints
    ├── fig_result/: plot AUC curves
    ├── logs/: save training / evaluation logs
    ├── models/:
        ├── GAIN.py: GAIN model for GloVe or BERT version
        ├── GAIN_nomention.py: GAIN model for -hMG ablation
    ├── config.py: parse command-line arguments
    ├── data.py: define Datasets / Dataloader for GAIN-GloVe or GAIN-BERT
    ├── test.py: evaluation code
    ├── train.py: training code
    ├── utils.py: some tools for training / evaluation
    ├── *.sh: training / evaluation shell scripts
├─ data/: raw and preprocessed data of the DocRED dataset
    ├── prepro_data/
    ├── README.md
├─ PLM/: save pre-trained language models such as BERT_base / BERT_large
    ├── bert-base-uncased/
    ├── bert-large-uncased/
├─ test_result_jsons/: save test result jsons
├─ LICENSE
├─ README.md
```

## 1. Environments

- python (3.7.4)
- cuda (10.2)
- Ubuntu-18.04 (4.15.0-65-generic)

## 2. Dependencies

- numpy (1.19.2)
- matplotlib (3.3.2)
- torch (1.6.0)
- transformers (3.1.0)
- dgl-cu102 (0.4.3)
- scikit-learn (0.23.2)

PS: dgl >= 0.5 is not compatible with our code; we will fix this compatibility problem in the future.

## 3. Preparation

### 3.1. Dataset
- Download data from the [Google Drive link](https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw) shared by the DocRED authors

- Put `train_annotated.json`, `dev.json`, `test.json`, `word2id.json`, `ner2id.json`, `rel2id.json`, `vec.npy` into the directory `data/`

- If you want to use other datasets, please first process them into the same format as DocRED.

### 3.2. (Optional) Pre-trained Language Models
Following the hint in this [link](http://viewsetting.xyz/2019/10/17/pytorch_transformers/?nsukey=v0sWRSl5BbNLDI3eWyUvd1HlPVJiEOiV%2Fk8adAy5VryF9JNLUt1TidZkzaDANBUG6yb6ZGywa9Qa7qiP3KssXrGXeNC1S21IyT6HZq6%2BZ71K1ADF1jKBTGkgRHaarcXIA5%2B1cUq%2BdM%2FhoJVzgDoM7lcmJg9%2Be6NarwsZzpwAbAwjHTLv5b2uQzsSrYwJEdPl7q9O70SmzCJ1VF511vwxKA%3D%3D), download the required files (`pytorch_model.bin`, `config.json`, `vocab.txt`, etc.) into a directory `PLM/bert-????-uncased`, such as `PLM/bert-base-uncased`.

## 4. Training

```bash
>> cd code
>> ./run_GAIN_XXX.sh gpu_id # like ./run_GAIN_BERT.sh 2
>> tail -f -n 2000 logs/train_xxx.log
```

## 5. Evaluation

```bash
>> cd code
>> ./eval_GAIN_XXX.sh gpu_id threshold(optional) # like ./eval_GAIN_BERT.sh 0 0.5521
>> tail -f -n 2000 logs/test_xxx.log
```

PS: for the dev set, we recommend using threshold = -1 (the default, so you can omit this argument).
The log then prints the optimal threshold on the dev set, and you can pass this value as the threshold when evaluating the test set.

## 6. Submission to Leaderboard (CodaLab)
- Step 5 produces a JSON output file for the test set.

- Rename it to `result.json` and compress it into `result.zip`.

- Finally, submit `result.zip` to [CodaLab](https://competitions.codalab.org/competitions/20717#participate-submit_results).

## 7. License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 8. Citation

If you use this work or code, please kindly cite the following paper:

```bib
@inproceedings{zeng-etal-2020-gain,
    title = "Double Graph Based Reasoning for Document-level Relation Extraction",
    author = "Zeng, Shuang and
      Xu, Runxin and
      Chang, Baobao and
      Li, Lei",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.emnlp-main.127",
    pages = "1630--1640",
}
```

## 9. Contacts

If you have any questions, please feel free to contact [Shuang Zeng](mailto:zengs@pku.edu.cn); we will reply as soon as possible.

--------------------------------------------------------------------------------
/code/checkpoint/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/code/checkpoint/.placeholder
--------------------------------------------------------------------------------
/code/config.py:
--------------------------------------------------------------------------------
import argparse
import json
import os

import numpy as np

data_dir = '../data/'
prepro_dir = os.path.join(data_dir, 'prepro_data/')
if not os.path.exists(prepro_dir):
    os.mkdir(prepro_dir)

rel2id = json.load(open(os.path.join(data_dir, 'rel2id.json'), "r"))
id2rel = {v: k for k, v in rel2id.items()}
word2id = json.load(open(os.path.join(data_dir, 'word2id.json'), "r"))
ner2id = json.load(open(os.path.join(data_dir, 'ner2id.json'), "r"))

word2vec = np.load(os.path.join(data_dir, 'vec.npy'))


def get_opt():
    parser = argparse.ArgumentParser()

    # datasets path
    parser.add_argument('--train_set', type=str, default=os.path.join(data_dir, 'train_annotated.json'))
    parser.add_argument('--dev_set', type=str, default=os.path.join(data_dir, 'dev.json'))
    parser.add_argument('--test_set', type=str, default=os.path.join(data_dir, 'test.json'))

    # save path of preprocessed datasets
    parser.add_argument('--train_set_save', type=str, default=os.path.join(prepro_dir, 'train.pkl'))
    parser.add_argument('--dev_set_save', type=str, default=os.path.join(prepro_dir, 'dev.pkl'))
    parser.add_argument('--test_set_save', type=str, default=os.path.join(prepro_dir, 'test.pkl'))

    # checkpoints
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint')
    parser.add_argument('--fig_result_dir', type=str, default='fig_result')
    parser.add_argument('--model_name', type=str, default='train_model')
    parser.add_argument('--pretrain_model', type=str, default='')

    # task/Dataset-related
    parser.add_argument('--vocabulary_size', type=int, default=200000)
    parser.add_argument('--relation_nums', type=int, default=97)
    parser.add_argument('--entity_type_num', type=int, default=7)
    parser.add_argument('--max_entity_num', type=int, default=80)

    # padding
    parser.add_argument('--word_pad', type=int, default=0)
    parser.add_argument('--entity_type_pad', type=int, default=0)
    parser.add_argument('--entity_id_pad', type=int, default=0)

    # word embedding
    parser.add_argument('--word_emb_size', type=int, default=10)
    parser.add_argument('--pre_train_word', action='store_true')
    parser.add_argument('--data_word_vec', type=str)
    parser.add_argument('--finetune_word', action='store_true')

    # entity type embedding
    parser.add_argument('--use_entity_type', action='store_true')
    parser.add_argument('--entity_type_size', type=int, default=20)

    # entity id embedding, i.e., coreference embedding in DocRED original paper
    parser.add_argument('--use_entity_id', action='store_true')
    parser.add_argument('--entity_id_size', type=int, default=20)

    # BiLSTM
    parser.add_argument('--nlayers', type=int, default=1)
    parser.add_argument('--lstm_hidden_size', type=int, default=32)
    parser.add_argument('--lstm_dropout', type=float, default=0.1)

    # training settings
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--test_batch_size', type=int, default=1)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--test_epoch', type=int, default=1)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--negativa_alpha', type=float, default=0.0)  # ratio of sampled negative examples to positive examples
    parser.add_argument('--log_step', type=int, default=50)
    parser.add_argument('--save_model_freq', type=int, default=1)

    # gcn
    parser.add_argument('--mention_drop', action='store_true')
    parser.add_argument('--gcn_layers', type=int, default=2)
    parser.add_argument('--gcn_dim', type=int, default=808)
    parser.add_argument('--dropout', type=float, default=0.6)
    parser.add_argument('--activation', type=str, default="relu")

    # BERT
    parser.add_argument('--bert_hid_size', type=int, default=768)
    parser.add_argument('--bert_path', type=str, default="")
    parser.add_argument('--bert_fix', action='store_true')
    parser.add_argument('--coslr', action='store_true')
    parser.add_argument('--clip', type=float, default=-1)

    parser.add_argument('--k_fold', type=str, default="none")

    # use BiLSTM / BERT encoder, default: BiLSTM encoder
    parser.add_argument('--use_model', type=str, default="bilstm", choices=['bilstm', 'bert'],
                        help='you should choose between bert and bilstm')

    # binary classification threshold, automatically find optimal threshold when -1
    parser.add_argument('--input_theta', type=float, default=-1)

    return parser.parse_args()

--------------------------------------------------------------------------------
/code/data.py:
--------------------------------------------------------------------------------
import json
import math
import os
import pickle
import random
from collections import defaultdict

import dgl
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
from transformers import *

from models.GAIN import Bert
from utils import get_cuda

IGNORE_INDEX = -100


class DGLREDataset(IterableDataset):

    def __init__(self, src_file, save_file, word2id, ner2id, rel2id,
                 dataset_type='train', instance_in_train=None, opt=None):

        super(DGLREDataset, self).__init__()

        # record training set mention triples
        self.instance_in_train = set([]) if instance_in_train is None else instance_in_train
        self.data = None
        self.document_max_length = 512
        self.INTRA_EDGE = 0
        self.INTER_EDGE = 1
        self.LOOP_EDGE = 2
        self.count = 0

        print('Reading data from {}.'.format(src_file))
        if os.path.exists(save_file):
            with open(file=save_file, mode='rb') as fr:
                info = pickle.load(fr)
                self.data = info['data']
                self.instance_in_train = info['intrain_set']
            print('load preprocessed data from {}.'.format(save_file))

        else:
            with open(file=src_file, mode='r', encoding='utf-8') as fr:
                ori_data = json.load(fr)
            print('loading..')
            self.data = []

            for i, doc in enumerate(ori_data):

                title, entity_list, labels, sentences = \
                    doc['title'], doc['vertexSet'], doc.get('labels', []), doc['sents']

                Ls = [0]
                L = 0
                for x in sentences:
                    L += len(x)
                    Ls.append(L)
                for j in range(len(entity_list)):
                    for k in range(len(entity_list[j])):
                        sent_id = int(entity_list[j][k]['sent_id'])
                        entity_list[j][k]['sent_id'] = sent_id

                        dl = Ls[sent_id]
                        pos0, pos1 = entity_list[j][k]['pos']
                        entity_list[j][k]['global_pos'] = (pos0 + dl, pos1 + dl)

                # generate positive examples
                train_triple = []
                new_labels = []
                for label in labels:
                    head, tail, relation, evidence = label['h'], label['t'], label['r'], label['evidence']
                    assert (relation in rel2id), 'no such relation {} in rel2id'.format(relation)
                    label['r'] = rel2id[relation]

                    train_triple.append((head, tail))

                    label['in_train'] = False

                    # record training set mention triples and mark it for dev and test set
                    for n1 in entity_list[head]:
                        for n2 in entity_list[tail]:
                            mention_triple = (n1['name'], n2['name'], relation)
                            if dataset_type == 'train':
                                self.instance_in_train.add(mention_triple)
                            else:
                                if mention_triple in self.instance_in_train:
                                    label['in_train'] = True
                                    break

                    new_labels.append(label)

                # generate negative examples
                na_triple = []
                for j in range(len(entity_list)):
                    for k in range(len(entity_list)):
                        if j != k and (j, k) not in train_triple:
                            na_triple.append((j, k))

                # generate document ids
                words = []
                for sentence in sentences:
                    for word in sentence:
                        words.append(word)
                if len(words) > self.document_max_length:
                    words = words[:self.document_max_length]

                word_id = np.zeros((self.document_max_length,), dtype=np.int32)
                pos_id = np.zeros((self.document_max_length,), dtype=np.int32)
                ner_id = np.zeros((self.document_max_length,), dtype=np.int32)
                mention_id = np.zeros((self.document_max_length,), dtype=np.int32)

                for iii, w in enumerate(words):
                    word = word2id.get(w.lower(), word2id['UNK'])
                    word_id[iii] = word

                entity2mention = defaultdict(list)
                mention_idx = 1
                already_exist = set()  # dealing with NER overlapping problem
                for idx, vertex in enumerate(entity_list, 1):
                    for v in vertex:
                        sent_id, (pos0, pos1), ner_type = v['sent_id'], v['global_pos'], v['type']
                        if (pos0, pos1) in already_exist:
                            continue
                        pos_id[pos0:pos1] = idx
                        ner_id[pos0:pos1] = ner2id[ner_type]
                        mention_id[pos0:pos1] = mention_idx
                        entity2mention[idx].append(mention_idx)
                        mention_idx += 1
                        already_exist.add((pos0, pos1))

                # construct graph
                graph = self.create_graph(Ls, mention_id, pos_id, entity2mention)

                # construct entity graph & path
                entity_graph, path = self.create_entity_graph(Ls, pos_id, entity2mention)

                assert pos_id.max() == len(entity_list)
                assert mention_id.max() == graph.number_of_nodes() - 1

                overlap = doc.get('overlap_entity_pair', [])
                new_overlap = [tuple(item) for item in overlap]

                self.data.append({
                    'title': title,
                    'entities': entity_list,
                    'labels': new_labels,
                    'na_triple': na_triple,
                    'word_id': word_id,
                    'pos_id': pos_id,
                    'ner_id': ner_id,
                    'mention_id': mention_id,
                    'entity2mention': entity2mention,
                    'graph': graph,
                    'entity_graph': entity_graph,
                    'path': path,
                    'overlap': new_overlap
                })

            # save data
            with open(file=save_file, mode='wb') as fw:
                pickle.dump({'data': self.data, 'intrain_set': self.instance_in_train}, fw)
            print('finish reading {} and save preprocessed data to {}.'.format(src_file, save_file))

        # k_fold is "k,total": drop the k-th (1-indexed) of `total` equal folds.
        # Slice bounds must be ints, so the scaled float bounds are cast with int().
        if opt is not None and opt.k_fold != "none":
            k_fold = opt.k_fold.split(',')
            k, total = float(k_fold[0]), float(k_fold[1])
            a = int((k - 1) / total * len(self.data))
            b = int(k / total * len(self.data))
            self.data = self.data[:a] + self.data[b:]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __iter__(self):
        return iter(self.data)

    def create_graph(self, Ls, mention_id, entity_id, entity2mention):

        d = defaultdict(list)

        # add intra-entity edges
        for _, mentions in entity2mention.items():
            for i in range(len(mentions)):
                for j in range(i + 1, len(mentions)):
                    d[('node', 'intra', 'node')].append((mentions[i], mentions[j]))
                    d[('node', 'intra', 'node')].append((mentions[j], mentions[i]))

        if d[('node', 'intra', 'node')] == []:
            d[('node', 'intra', 'node')].append((entity2mention[1][0], 0))

        for i in range(1, len(Ls)):
            tmp = dict()
            for j in range(Ls[i - 1], Ls[i]):
                if mention_id[j] != 0:
                    tmp[mention_id[j]] = entity_id[j]
            mention_entity_info = [(k, v) for k, v in tmp.items()]

            # add self-loop & to-global-node edges
            for m in range(len(mention_entity_info)):
                # self-loop
                # d[('node', 'loop', 'node')].append((mention_entity_info[m][0], mention_entity_info[m][0]))

                # to global node
                d[('node', 'global', 'node')].append((mention_entity_info[m][0], 0))
                d[('node', 'global', 'node')].append((0, mention_entity_info[m][0]))

            # add inter edges
            for m in range(len(mention_entity_info)):
                for n in range(m + 1, len(mention_entity_info)):
                    if mention_entity_info[m][1] != mention_entity_info[n][1]:
                        # inter edge
                        d[('node', 'inter', 'node')].append((mention_entity_info[m][0], mention_entity_info[n][0]))
                        d[('node', 'inter', 'node')].append((mention_entity_info[n][0], mention_entity_info[m][0]))

        # add self-loop for global node
        # d[('node', 'loop', 'node')].append((0, 0))
        if d[('node', 'inter', 'node')] == []:
            d[('node', 'inter', 'node')].append((entity2mention[1][0], 0))

        graph = dgl.heterograph(d)

        return graph

    def create_entity_graph(self, Ls, entity_id, entity2mention):

        graph = dgl.DGLGraph()
        graph.add_nodes(entity_id.max())

        d = defaultdict(set)

        for i in range(1, len(Ls)):
            tmp = set()
            for j in range(Ls[i - 1], Ls[i]):
                if entity_id[j] != 0:
                    tmp.add(entity_id[j])
            tmp = list(tmp)
            for ii in range(len(tmp)):
                for jj in range(ii + 1, len(tmp)):
                    d[tmp[ii] - 1].add(tmp[jj] - 1)
                    d[tmp[jj] - 1].add(tmp[ii] - 1)
        a = []
        b = []
        for k, v in d.items():
            for vv in v:
                a.append(k)
                b.append(vv)
        graph.add_edges(a, b)

        path = dict()
        for i in range(0, graph.number_of_nodes()):
            for j in range(i + 1, graph.number_of_nodes()):
                a = set(graph.successors(i).numpy())
                b = set(graph.successors(j).numpy())
                c = [val + 1 for val in list(a & b)]
                path[(i + 1, j + 1)] = c

        return graph, path


class BERTDGLREDataset(IterableDataset):

    def __init__(self, src_file, save_file, word2id, ner2id, rel2id,
                 dataset_type='train', instance_in_train=None, opt=None):

        super(BERTDGLREDataset, self).__init__()

        # record training set mention triples
        self.instance_in_train = set([]) if instance_in_train is None else instance_in_train
        self.data = None
        self.document_max_length = 512
        self.INTRA_EDGE = 0
        self.INTER_EDGE = 1
        self.LOOP_EDGE = 2
        self.count = 0

        print('Reading data from {}.'.format(src_file))
        if os.path.exists(save_file):
            with open(file=save_file, mode='rb') as fr:
                info = pickle.load(fr)
                self.data = info['data']
                self.instance_in_train = info['intrain_set']
            print('load preprocessed data from {}.'.format(save_file))

        else:
            bert = Bert(BertModel, 'bert-base-uncased', opt.bert_path)

            with open(file=src_file, mode='r', encoding='utf-8') as fr:
                ori_data = json.load(fr)
            print('loading..')
            self.data = []

            for i, doc in enumerate(ori_data):

                title, entity_list, labels, sentences = \
                    doc['title'], doc['vertexSet'], doc.get('labels', []), doc['sents']

                Ls = [0]
                L = 0
                for x in sentences:
                    L += len(x)
                    Ls.append(L)
                for j in range(len(entity_list)):
                    for k in range(len(entity_list[j])):
                        sent_id = int(entity_list[j][k]['sent_id'])
                        entity_list[j][k]['sent_id'] = sent_id

                        dl = Ls[sent_id]
                        pos0, pos1 = entity_list[j][k]['pos']
                        entity_list[j][k]['global_pos'] = (pos0 + dl, pos1 + dl)

                # generate positive examples
                train_triple = []
                new_labels = []
                for label in labels:
                    head, tail, relation, evidence = label['h'], label['t'], label['r'], label['evidence']
                    assert (relation in rel2id), 'no such relation {} in rel2id'.format(relation)
                    label['r'] = rel2id[relation]

                    train_triple.append((head, tail))

                    label['in_train'] = False

                    # record training set mention triples and mark it for dev and test set
                    for n1 in entity_list[head]:
                        for n2 in entity_list[tail]:
                            mention_triple = (n1['name'], n2['name'], relation)
                            if dataset_type == 'train':
                                self.instance_in_train.add(mention_triple)
                            else:
                                if mention_triple in self.instance_in_train:
                                    label['in_train'] = True
                                    break

                    new_labels.append(label)

                # generate negative examples
                na_triple = []
                for j in range(len(entity_list)):
                    for k in range(len(entity_list)):
                        if j != k and (j, k) not in train_triple:
                            na_triple.append((j, k))

                # generate document ids
                words = []
                for sentence in sentences:
                    for word in sentence:
                        words.append(word)

                bert_token, bert_starts, bert_subwords = bert.subword_tokenize_to_ids(words)

                word_id = np.zeros((self.document_max_length,), dtype=np.int32)
                pos_id = np.zeros((self.document_max_length,), dtype=np.int32)
                ner_id = np.zeros((self.document_max_length,), dtype=np.int32)
                mention_id = np.zeros((self.document_max_length,), dtype=np.int32)
                word_id[:] = bert_token[0]

                entity2mention = defaultdict(list)
                mention_idx = 1
                already_exist = set()
                for idx, vertex in enumerate(entity_list, 1):
                    for v in vertex:

                        sent_id, (pos0, pos1), ner_type = v['sent_id'], v['global_pos'], v['type']

                        pos0 = bert_starts[pos0]
                        pos1 = bert_starts[pos1] if pos1 < len(bert_starts) else 1024

                        if (pos0, pos1) in already_exist:
                            continue

                        if pos0 >= len(pos_id):
                            continue

                        pos_id[pos0:pos1] = idx
                        ner_id[pos0:pos1] = ner2id[ner_type]
                        mention_id[pos0:pos1] = mention_idx
                        entity2mention[idx].append(mention_idx)
                        mention_idx += 1
                        already_exist.add((pos0, pos1))

                # if every mention of the last entity was truncated away, attach a placeholder
                # mention at the first free position so that pos_id.max() == len(entity_list) holds
                replace_i = 0
                idx = len(entity_list)
                if entity2mention[idx] == []:
                    entity2mention[idx].append(mention_idx)
                    while mention_id[replace_i] != 0:
                        replace_i += 1
                    mention_id[replace_i] = mention_idx
                    pos_id[replace_i] = idx
                    ner_id[replace_i] = ner2id[vertex[0]['type']]
                    mention_idx += 1

                new_Ls = [0]
                for ii in range(1, len(Ls)):
                    new_Ls.append(bert_starts[Ls[ii]] if Ls[ii] < len(bert_starts) else len(bert_subwords))
                Ls = new_Ls

                # construct graph
                graph = self.create_graph(Ls, mention_id, pos_id, entity2mention)

                # construct entity graph & path
                entity_graph, path = self.create_entity_graph(Ls, pos_id, entity2mention)

                assert pos_id.max() == len(entity_list)
                assert mention_id.max() == graph.number_of_nodes() - 1

                overlap = doc.get('overlap_entity_pair', [])
                new_overlap = [tuple(item) for item in overlap]

                self.data.append({
                    'title': title,
                    'entities': entity_list,
                    'labels': new_labels,
                    'na_triple': na_triple,
                    'word_id': word_id,
                    'pos_id': pos_id,
                    'ner_id': ner_id,
                    'mention_id': mention_id,
                    'entity2mention': entity2mention,
                    'graph': graph,
                    'entity_graph': entity_graph,
                    'path': path,
                    'overlap': new_overlap
                })

            # save data
            with open(file=save_file, mode='wb') as fw:
                pickle.dump({'data': self.data, 'intrain_set': self.instance_in_train}, fw)
            print('finish reading {} and save preprocessed data to {}.'.format(src_file, save_file))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __iter__(self):
        return iter(self.data)

    def create_graph(self, Ls, mention_id, entity_id, entity2mention):

        d = defaultdict(list)

        # add intra edges
        for _, mentions in entity2mention.items():
            for i in range(len(mentions)):
                for j in range(i + 1, len(mentions)):
                    d[('node', 'intra', 'node')].append((mentions[i], mentions[j]))
                    d[('node', 'intra', 'node')].append((mentions[j], mentions[i]))

        if d[('node', 'intra', 'node')] == []:
            d[('node', 'intra', 'node')].append((entity2mention[1][0], 0))

        for i in range(1, len(Ls)):
            tmp = dict()
            for j in range(Ls[i - 1], Ls[i]):
                if mention_id[j] != 0:
                    tmp[mention_id[j]] = entity_id[j]
            mention_entity_info = [(k, v) for k, v in tmp.items()]

            # add self-loop & to-global-node edges
            for m in range(len(mention_entity_info)):
                # self-loop
                # d[('node', 'loop', 'node')].append((mention_entity_info[m][0], mention_entity_info[m][0]))

                # to global node
                d[('node', 'global', 'node')].append((mention_entity_info[m][0], 0))
                d[('node', 'global', 'node')].append((0, mention_entity_info[m][0]))

            # add inter edges
            for m in range(len(mention_entity_info)):
                for n in range(m + 1, len(mention_entity_info)):
                    if mention_entity_info[m][1] != mention_entity_info[n][1]:
                        # inter edge
                        d[('node', 'inter', 'node')].append((mention_entity_info[m][0], mention_entity_info[n][0]))
                        d[('node', 'inter', 'node')].append((mention_entity_info[n][0], mention_entity_info[m][0]))

        # add self-loop for global node
        # d[('node', 'loop', 'node')].append((0, 0))
        if d[('node', 'inter', 'node')] == []:
            d[('node', 'inter', 'node')].append((entity2mention[1][0], 0))

        graph = dgl.heterograph(d)

        return graph

    def create_entity_graph(self, Ls, entity_id, entity2mention):

        graph = dgl.DGLGraph()
        graph.add_nodes(entity_id.max())

        d = defaultdict(set)

        for i in range(1, len(Ls)):
            tmp = set()
            for j in range(Ls[i - 1], Ls[i]):
                if entity_id[j] != 0:
                    tmp.add(entity_id[j])
            tmp = list(tmp)
            for ii in range(len(tmp)):
                for jj in range(ii + 1, len(tmp)):
                    d[tmp[ii] - 1].add(tmp[jj] - 1)
                    d[tmp[jj] - 1].add(tmp[ii] - 1)
        a = []
        b = []
        for k, v in d.items():
            for vv in v:
                a.append(k)
                b.append(vv)
        graph.add_edges(a, b)

        path = dict()
        for i in range(0, graph.number_of_nodes()):
            for j in range(i + 1, graph.number_of_nodes()):
                a = set(graph.successors(i).numpy())
                b = set(graph.successors(j).numpy())
                c = [val + 1 for val in list(a & b)]
                path[(i + 1, j + 1)] = c

        return graph, path


class DGLREDataloader(DataLoader):

    def __init__(self, dataset, batch_size, shuffle=False, h_t_limit_per_batch=300, h_t_limit=1722, relation_num=97,
                 max_length=512, negativa_alpha=0.0, dataset_type='train'):
        super(DGLREDataloader, self).__init__(dataset, batch_size=batch_size)
        self.shuffle = shuffle
        self.length = len(self.dataset)
        self.max_length = max_length
        self.negativa_alpha = negativa_alpha
        self.dataset_type = dataset_type

        self.h_t_limit_per_batch = h_t_limit_per_batch
        self.h_t_limit = h_t_limit
        self.relation_num = relation_num
        self.dis2idx = np.zeros((512), dtype='int64')
        self.dis2idx[1] = 1
        self.dis2idx[2:] = 2
        self.dis2idx[4:] = 3
        self.dis2idx[8:] = 4
        self.dis2idx[16:] = 5
        self.dis2idx[32:] = 6
        self.dis2idx[64:] = 7
        self.dis2idx[128:] = 8
        self.dis2idx[256:] = 9
        self.dis_size = 20

        self.order = list(range(self.length))

    def __iter__(self):
        # shuffle
        if self.shuffle:
            random.shuffle(self.order)
            self.data = [self.dataset[idx] for idx in self.order]
        else:
            self.data = self.dataset
        batch_num = math.ceil(self.length / self.batch_size)
        self.batches = [self.data[idx * self.batch_size: min(self.length, (idx + 1) * self.batch_size)]
                        for idx in range(0, batch_num)]
        self.batches_order = [self.order[idx * self.batch_size: min(self.length, (idx + 1) * self.batch_size)]
                              for idx in range(0, batch_num)]

        # begin
        context_word_ids = torch.LongTensor(self.batch_size, self.max_length).cpu()
        context_pos_ids = torch.LongTensor(self.batch_size, self.max_length).cpu()
        context_ner_ids = torch.LongTensor(self.batch_size, self.max_length).cpu()
        context_mention_ids = torch.LongTensor(self.batch_size, self.max_length).cpu()
        context_word_mask = torch.LongTensor(self.batch_size, self.max_length).cpu()
        context_word_length = torch.LongTensor(self.batch_size).cpu()
        ht_pairs = torch.LongTensor(self.batch_size, self.h_t_limit, 2).cpu()
        relation_multi_label = torch.Tensor(self.batch_size, self.h_t_limit, self.relation_num).cpu()
        relation_mask = torch.Tensor(self.batch_size, self.h_t_limit).cpu()
        relation_label = torch.LongTensor(self.batch_size, self.h_t_limit).cpu()
        ht_pair_distance = torch.LongTensor(self.batch_size, self.h_t_limit).cpu()

        for idx, minibatch in enumerate(self.batches):
            cur_bsz = len(minibatch)

            for mapping in [context_word_ids, context_pos_ids, context_ner_ids, context_mention_ids,
                            context_word_mask, context_word_length,
                            ht_pairs, ht_pair_distance, relation_multi_label, relation_mask, relation_label]:
                if mapping is not None:
                    mapping.zero_()

            relation_label.fill_(IGNORE_INDEX)

            max_h_t_cnt = 0

            label_list = []
            L_vertex = []
            titles = []
            indexes = []
            graph_list = []
            entity_graph_list = []
            entity2mention_table = []
            path_table = []
            overlaps = []

            for i, example in enumerate(minibatch):
                title, entities, labels, na_triple, word_id, pos_id, ner_id, mention_id, entity2mention, graph, entity_graph, path = \
                    example['title'], example['entities'], example['labels'], example['na_triple'], \
                    example['word_id'], example['pos_id'], example['ner_id'], example['mention_id'], example[
                        'entity2mention'], example['graph'], example['entity_graph'], example['path']
                graph_list.append(graph)
                entity_graph_list.append(entity_graph)
                path_table.append(path)
                overlaps.append(example['overlap'])

                entity2mention_t = get_cuda(torch.zeros((pos_id.max() + 1, mention_id.max() + 1)))
                for e, ms in entity2mention.items():
                    for m in ms:
                        entity2mention_t[e, m] = 1
                entity2mention_table.append(entity2mention_t)

                L = len(entities)
                word_num = word_id.shape[0]

                context_word_ids[i, :word_num].copy_(torch.from_numpy(word_id))
                context_pos_ids[i, :word_num].copy_(torch.from_numpy(pos_id))
                context_ner_ids[i, :word_num].copy_(torch.from_numpy(ner_id))
                context_mention_ids[i, :word_num].copy_(torch.from_numpy(mention_id))

                idx2label = defaultdict(list)
                label_set = {}
                for label in labels:
| head, tail, relation, intrain, evidence = \ 631 | label['h'], label['t'], label['r'], label['in_train'], label['evidence'] 632 | idx2label[(head, tail)].append(relation) 633 | label_set[(head, tail, relation)] = intrain 634 | 635 | label_list.append(label_set) 636 | 637 | if self.dataset_type == 'train': 638 | train_tripe = list(idx2label.keys()) 639 | for j, (h_idx, t_idx) in enumerate(train_tripe): 640 | hlist, tlist = entities[h_idx], entities[t_idx] 641 | ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1]) 642 | label = idx2label[(h_idx, t_idx)] 643 | 644 | delta_dis = hlist[0]['global_pos'][0] - tlist[0]['global_pos'][0] 645 | if delta_dis < 0: 646 | ht_pair_distance[i, j] = -int(self.dis2idx[-delta_dis]) + self.dis_size // 2 647 | else: 648 | ht_pair_distance[i, j] = int(self.dis2idx[delta_dis]) + self.dis_size // 2 649 | 650 | for r in label: 651 | relation_multi_label[i, j, r] = 1 652 | 653 | relation_mask[i, j] = 1 654 | rt = np.random.randint(len(label)) 655 | relation_label[i, j] = label[rt] 656 | 657 | lower_bound = len(na_triple) 658 | if self.negativa_alpha > 0.0: 659 | random.shuffle(na_triple) 660 | lower_bound = int(max(20, len(train_tripe) * self.negativa_alpha)) 661 | 662 | for j, (h_idx, t_idx) in enumerate(na_triple[:lower_bound], len(train_tripe)): 663 | hlist, tlist = entities[h_idx], entities[t_idx] 664 | ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1]) 665 | 666 | delta_dis = hlist[0]['global_pos'][0] - tlist[0]['global_pos'][0] 667 | if delta_dis < 0: 668 | ht_pair_distance[i, j] = -int(self.dis2idx[-delta_dis]) + self.dis_size // 2 669 | else: 670 | ht_pair_distance[i, j] = int(self.dis2idx[delta_dis]) + self.dis_size // 2 671 | 672 | relation_multi_label[i, j, 0] = 1 673 | relation_label[i, j] = 0 674 | relation_mask[i, j] = 1 675 | 676 | max_h_t_cnt = max(max_h_t_cnt, len(train_tripe) + lower_bound) 677 | else: 678 | j = 0 679 | for h_idx in range(L): 680 | for t_idx in range(L): 681 | if h_idx != t_idx: 682 | hlist, 
tlist = entities[h_idx], entities[t_idx] 683 | ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1]) 684 | 685 | relation_mask[i, j] = 1 686 | 687 | delta_dis = hlist[0]['global_pos'][0] - tlist[0]['global_pos'][0] 688 | if delta_dis < 0: 689 | ht_pair_distance[i, j] = -int(self.dis2idx[-delta_dis]) + self.dis_size // 2 690 | else: 691 | ht_pair_distance[i, j] = int(self.dis2idx[delta_dis]) + self.dis_size // 2 692 | 693 | j += 1 694 | 695 | max_h_t_cnt = max(max_h_t_cnt, j) 696 | L_vertex.append(L) 697 | titles.append(title) 698 | indexes.append(self.batches_order[idx][i]) 699 | 700 | context_word_mask = context_word_ids > 0 701 | context_word_length = context_word_mask.sum(1) 702 | batch_max_length = context_word_length.max() 703 | 704 | yield {'context_idxs': get_cuda(context_word_ids[:cur_bsz, :batch_max_length].contiguous()), 705 | 'context_pos': get_cuda(context_pos_ids[:cur_bsz, :batch_max_length].contiguous()), 706 | 'context_ner': get_cuda(context_ner_ids[:cur_bsz, :batch_max_length].contiguous()), 707 | 'context_mention': get_cuda(context_mention_ids[:cur_bsz, :batch_max_length].contiguous()), 708 | 'context_word_mask': get_cuda(context_word_mask[:cur_bsz, :batch_max_length].contiguous()), 709 | 'context_word_length': get_cuda(context_word_length[:cur_bsz].contiguous()), 710 | 'h_t_pairs': get_cuda(ht_pairs[:cur_bsz, :max_h_t_cnt, :2]), 711 | 'relation_label': get_cuda(relation_label[:cur_bsz, :max_h_t_cnt]).contiguous(), 712 | 'relation_multi_label': get_cuda(relation_multi_label[:cur_bsz, :max_h_t_cnt]), 713 | 'relation_mask': get_cuda(relation_mask[:cur_bsz, :max_h_t_cnt]), 714 | 'ht_pair_distance': get_cuda(ht_pair_distance[:cur_bsz, :max_h_t_cnt]), 715 | 'labels': label_list, 716 | 'L_vertex': L_vertex, 717 | 'titles': titles, 718 | 'indexes': indexes, 719 | 'graphs': graph_list, 720 | 'entity2mention_table': entity2mention_table, 721 | 'entity_graphs': entity_graph_list, 722 | 'path_table': path_table, 723 | 'overlaps': overlaps 724 | } 725 | 
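`DGLREDataloader.__init__` builds `dis2idx`, a table that groups token distances into log-scale buckets (0, 1, 2-3, 4-7, ..., 256-511), and the loader then stores a signed bucket id shifted by `dis_size // 2` in `ht_pair_distance`. The following is a minimal standalone sketch of that encoding, assuming only NumPy; `build_dis2idx` and `encode_distance` are hypothetical helper names, not functions in `data.py`.

```python
import numpy as np


def build_dis2idx(max_len: int = 512) -> np.ndarray:
    """Log-scale distance buckets: 0->0, 1->1, 2-3->2, 4-7->3, ..., 256-511->9."""
    dis2idx = np.zeros(max_len, dtype="int64")
    dis2idx[1] = 1
    # Each power-of-two threshold starts the next bucket, mirroring the loader's table.
    for bucket, start in enumerate([2, 4, 8, 16, 32, 64, 128, 256], start=2):
        dis2idx[start:] = bucket
    return dis2idx


def encode_distance(delta: int, dis2idx: np.ndarray, dis_size: int = 20) -> int:
    """Signed bucket id, shifted by dis_size // 2 so the result is non-negative."""
    if delta < 0:
        return -int(dis2idx[-delta]) + dis_size // 2
    return int(dis2idx[delta]) + dis_size // 2


table = build_dis2idx()
# delta = global_pos[0] of the first head mention minus that of the first tail mention.
print([encode_distance(d, table) for d in (-300, -5, 0, 5, 300)])  # [1, 7, 10, 13, 19]
```

Bucketing keeps nearby pairs distinguishable while collapsing long distances, so a 20-slot range covers every offset that fits in a 512-token document.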
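`create_entity_graph` links two entities whenever they appear in the same sentence (sentence boundaries come from `Ls`, per-token entity ids from `entity_id`, with 0 meaning "no entity"), then records for every pair `i < j` the shared neighbours, i.e. the middle nodes of the two-hop paths the model later attends over. Below is a dependency-free sketch of the same logic, assuming plain Python sequences; `build_entity_paths` is a hypothetical name, and unlike the original it sorts each path list for deterministic output.

```python
from itertools import combinations
from typing import Dict, List, Sequence, Set, Tuple


def build_entity_paths(
    sentence_bounds: Sequence[int], entity_id: Sequence[int]
) -> Tuple[Dict[int, Set[int]], Dict[Tuple[int, int], List[int]]]:
    """Sentence co-occurrence graph over 1-based entity ids plus a 2-hop path table."""
    num_entities = max(entity_id)
    neighbors: Dict[int, Set[int]] = {e: set() for e in range(1, num_entities + 1)}

    # Connect every pair of distinct entities that share a sentence.
    for start, end in zip(sentence_bounds[:-1], sentence_bounds[1:]):
        in_sentence = {e for e in entity_id[start:end] if e != 0}
        for a, b in combinations(sorted(in_sentence), 2):
            neighbors[a].add(b)
            neighbors[b].add(a)

    # For each unordered pair (i < j), the middle nodes are the common neighbours.
    paths = {
        (i, j): sorted(neighbors[i] & neighbors[j])
        for i, j in combinations(range(1, num_entities + 1), 2)
    }
    return neighbors, paths


# Three sentences: [e1, -, e2], [e2, e3], [e1, e3]; token 1 is not an entity.
bounds = [0, 3, 5, 7]
ids = [1, 0, 2, 2, 3, 1, 3]
_, paths = build_entity_paths(bounds, ids)
print(paths)  # {(1, 2): [3], (1, 3): [2], (2, 3): [1]}
```

Because keys are stored only for `i < j`, a lookup for an arbitrary head/tail pair has to try both orders, which is what the model's `path_t` lookup does.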
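Per document, `DGLREDataloader.__iter__` fills `ht_pairs` with every labelled pair plus a subset of NA pairs when `dataset_type == 'train'`, and with every ordered pair of distinct entities otherwise; the default `h_t_limit=1722` equals 42 × 41, i.e. all ordered pairs of 42 entities. A small sketch of that count, assuming `num_positive = len(idx2label)` and `num_negative = len(na_triple)`; `num_filled_pairs` is a hypothetical helper, not part of `data.py`.

```python
def num_filled_pairs(num_entities: int, num_positive: int, num_negative: int,
                     negativa_alpha: float = 0.0, dataset_type: str = "train") -> int:
    """Number of (head, tail) rows the loader writes into ht_pairs for one document."""
    if dataset_type != "train":
        # Evaluation: every ordered pair (h, t) with h != t.
        return num_entities * (num_entities - 1)
    lower_bound = num_negative  # default: keep every NA pair
    if negativa_alpha > 0.0:
        # Sample roughly alpha NA pairs per positive pair, but at least 20.
        lower_bound = int(max(20, num_positive * negativa_alpha))
    # na_triple[:lower_bound] cannot yield more than num_negative pairs.
    return num_positive + min(lower_bound, num_negative)


print(num_filled_pairs(42, 0, 0, dataset_type="test"))  # 1722
```

Note that the loader's `max_h_t_cnt` adds the unclamped `lower_bound`, so the batch slice can be wider than this count; the extra slots stay zero in `ht_pairs` and `relation_mask`.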
-------------------------------------------------------------------------------- /code/eval_GAIN_BERT.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export CUDA_VISIBLE_DEVICES=$1 3 | 4 | # binary classification threshold, automatically find optimal threshold when -1, default:-1 5 | input_theta=${2--1} 6 | batch_size=5 7 | test_batch_size=16 8 | dataset=test 9 | 10 | # -------------------GAIN_BERT_base Evaluation Shell Script-------------------- 11 | 12 | if true; then 13 | model_name=GAIN_BERT_base 14 | 15 | nohup python3 -u test.py \ 16 | --train_set ../data/train_annotated.json \ 17 | --train_set_save ../data/prepro_data/train_BERT.pkl \ 18 | --dev_set ../data/dev.json \ 19 | --dev_set_save ../data/prepro_data/dev_BERT.pkl \ 20 | --test_set ../data/${dataset}.json \ 21 | --test_set_save ../data/prepro_data/${dataset}_BERT.pkl \ 22 | --model_name ${model_name} \ 23 | --use_model bert \ 24 | --pretrain_model checkpoint/GAIN_BERT_base_best.pt \ 25 | --batch_size ${batch_size} \ 26 | --test_batch_size ${test_batch_size} \ 27 | --gcn_dim 808 \ 28 | --gcn_layers 2 \ 29 | --bert_hid_size 768 \ 30 | --bert_path ../PLM/bert-base-uncased \ 31 | --use_entity_type \ 32 | --use_entity_id \ 33 | --dropout 0.6 \ 34 | --activation relu \ 35 | --input_theta ${input_theta} \ 36 | >logs/test_${model_name}.log 2>&1 & 37 | fi 38 | 39 | # -------------------GAIN_BERT_large Evaluation Shell Script-------------------- 40 | 41 | if false; then 42 | model_name=GAIN_BERT_large 43 | 44 | nohup python3 -u test.py \ 45 | --train_set ../data/train_annotated.json \ 46 | --train_set_save ../data/prepro_data/train_BERT.pkl \ 47 | --dev_set ../data/dev.json \ 48 | --dev_set_save ../data/prepro_data/dev_BERT.pkl \ 49 | --test_set ../data/${dataset}.json \ 50 | --test_set_save ../data/prepro_data/${dataset}_BERT.pkl \ 51 | --model_name ${model_name} \ 52 | --use_model bert \ 53 | --pretrain_model checkpoint/GAIN_BERT_large_best.pt \ 54 | 
--batch_size ${batch_size} \ 55 | --test_batch_size ${test_batch_size} \ 56 | --gcn_dim 1064 \ 57 | --gcn_layers 2 \ 58 | --bert_hid_size 1024 \ 59 | --bert_path ../PLM/bert-large-uncased \ 60 | --use_entity_type \ 61 | --use_entity_id \ 62 | --dropout 0.6 \ 63 | --activation relu \ 64 | --input_theta ${input_theta} \ 65 | >logs/test_${model_name}.log 2>&1 & 66 | fi 67 | -------------------------------------------------------------------------------- /code/eval_GAIN_GloVe.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export CUDA_VISIBLE_DEVICES=$1 3 | 4 | # -------------------GAIN_GloVe Evaluation Shell Script-------------------- 5 | 6 | model_name=GAIN_GloVe 7 | batch_size=32 8 | test_batch_size=16 9 | # binary classification threshold, automatically find optimal threshold when -1, default:-1 10 | input_theta=${2--1} 11 | dataset=test 12 | 13 | nohup python3 -u test.py \ 14 | --train_set ../data/train_annotated.json \ 15 | --train_set_save ../data/prepro_data/train_GloVe.pkl \ 16 | --dev_set ../data/dev.json \ 17 | --dev_set_save ../data/prepro_data/dev_GloVe.pkl \ 18 | --test_set ../data/${dataset}.json \ 19 | --test_set_save ../data/prepro_data/${dataset}_GloVe.pkl \ 20 | --use_model bilstm \ 21 | --model_name ${model_name} \ 22 | --pretrain_model checkpoint/GAIN_GloVe_best.pt \ 23 | --batch_size ${batch_size} \ 24 | --test_batch_size ${test_batch_size} \ 25 | --gcn_dim 512 \ 26 | --gcn_layers 2 \ 27 | --lstm_hidden_size 256 \ 28 | --use_entity_type \ 29 | --use_entity_id \ 30 | --word_emb_size 100 \ 31 | --finetune_word \ 32 | --pre_train_word \ 33 | --activation relu \ 34 | --input_theta ${input_theta} \ 35 | >>logs/test_${model_name}.log 2>&1 & 36 | -------------------------------------------------------------------------------- /code/fig_result/.placeholder: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/code/fig_result/.placeholder -------------------------------------------------------------------------------- /code/logs/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/code/logs/.placeholder -------------------------------------------------------------------------------- /code/models/GAIN.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import dgl.nn.pytorch as dglnn 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from transformers import * 7 | 8 | from utils import get_cuda 9 | 10 | 11 | class GAIN_GloVe(nn.Module): 12 | def __init__(self, config): 13 | super(GAIN_GloVe, self).__init__() 14 | self.config = config 15 | 16 | word_emb_size = config.word_emb_size 17 | vocabulary_size = config.vocabulary_size 18 | encoder_input_size = word_emb_size 19 | self.activation = nn.Tanh() if config.activation == 'tanh' else nn.ReLU() 20 | 21 | self.word_emb = nn.Embedding(vocabulary_size, word_emb_size, padding_idx=config.word_pad) 22 | if config.pre_train_word: 23 | self.word_emb = nn.Embedding(config.data_word_vec.shape[0], word_emb_size, padding_idx=config.word_pad) 24 | self.word_emb.weight.data.copy_(torch.from_numpy(config.data_word_vec[:, :word_emb_size])) 25 | 26 | self.word_emb.weight.requires_grad = config.finetune_word 27 | if config.use_entity_type: 28 | encoder_input_size += config.entity_type_size 29 | self.entity_type_emb = nn.Embedding(config.entity_type_num, config.entity_type_size, 30 | padding_idx=config.entity_type_pad) 31 | 32 | if config.use_entity_id: 33 | encoder_input_size += config.entity_id_size 34 | self.entity_id_emb = nn.Embedding(config.max_entity_num + 1, config.entity_id_size, 35 | padding_idx=config.entity_id_pad) 36 | 37 | self.encoder = 
BiLSTM(encoder_input_size, config) 38 | 39 | self.gcn_dim = config.gcn_dim 40 | assert self.gcn_dim == 2 * config.lstm_hidden_size, 'gcn dim should be the lstm hidden dim * 2' 41 | rel_name_lists = ['intra', 'inter', 'global'] 42 | self.GCN_layers = nn.ModuleList([RelGraphConvLayer(self.gcn_dim, self.gcn_dim, rel_name_lists, 43 | num_bases=len(rel_name_lists), activation=self.activation, 44 | self_loop=True, dropout=self.config.dropout) 45 | for i in range(config.gcn_layers)]) 46 | 47 | self.bank_size = self.config.gcn_dim * (self.config.gcn_layers + 1) 48 | self.dropout = nn.Dropout(self.config.dropout) 49 | 50 | self.predict = nn.Sequential( 51 | nn.Linear(self.bank_size * 5 + self.gcn_dim * 4, self.bank_size * 2), # [h, t, |h-t|, h*t, global] (5 * bank_size) + path info (4 * gcn_dim) 52 | self.activation, 53 | self.dropout, 54 | nn.Linear(self.bank_size * 2, config.relation_nums), 55 | ) 56 | 57 | self.edge_layer = RelEdgeLayer(node_feat=self.gcn_dim, edge_feat=self.gcn_dim, 58 | activation=self.activation, dropout=config.dropout) 59 | 60 | self.path_info_mapping = nn.Linear(self.gcn_dim * 4, self.gcn_dim * 4) 61 | self.attention = Attention(self.bank_size * 2, self.gcn_dim * 4) 62 | 63 | def forward(self, **params): 64 | ''' 65 | words: [batch_size, max_length] 66 | src_lengths: [batch_size] 67 | mask: [batch_size, max_length] 68 | entity_type: [batch_size, max_length] 69 | entity_id: [batch_size, max_length] 70 | mention_id: [batch_size, max_length] 71 | distance: [batch_size, max_length] 72 | entity2mention_table: list of [local_entity_num, local_mention_num] 73 | graphs: list of DGLHeteroGraph 74 | h_t_pairs: [batch_size, h_t_limit, 2] 75 | ''' 76 | src = self.word_emb(params['words']) 77 | mask = params['mask'] 78 | bsz, slen, _ = src.size() 79 | 80 | if self.config.use_entity_type: 81 | src = torch.cat([src, self.entity_type_emb(params['entity_type'])], dim=-1) 82 | 83 | if self.config.use_entity_id: 84 | src = torch.cat([src, self.entity_id_emb(params['entity_id'])], dim=-1) 85 | 86 | # src: [batch_size, slen,
encoder_input_size] 87 | # src_lengths: [batch_size] 88 | 89 | encoder_outputs, (output_h_t, _) = self.encoder(src, params['src_lengths']) 90 | encoder_outputs[mask == 0] = 0 91 | # encoder_outputs: [batch_size, slen, 2*encoder_hid_size] 92 | # output_h_t: [batch_size, 2*encoder_hid_size] 93 | 94 | graphs = params['graphs'] 95 | 96 | mention_id = params['mention_id'] 97 | features = None 98 | 99 | for i in range(len(graphs)): 100 | encoder_output = encoder_outputs[i] # [slen, 2*encoder_hid_size] 101 | mention_num = torch.max(mention_id[i]) 102 | mention_index = get_cuda( 103 | (torch.arange(mention_num) + 1).unsqueeze(1).expand(-1, slen)) # [mention_num, slen] 104 | mentions = mention_id[i].unsqueeze(0).expand(mention_num, -1) # [mention_num, slen] 105 | select_metrix = (mention_index == mentions).float() # [mention_num, slen] 106 | # average word -> mention 107 | word_total_numbers = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, slen) # [mention_num, slen] 108 | select_metrix = torch.where(word_total_numbers > 0, select_metrix / word_total_numbers, select_metrix) 109 | x = torch.mm(select_metrix, encoder_output) # [mention_num, 2*encoder_hid_size] 110 | 111 | x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0) 112 | 113 | if features is None: 114 | features = x 115 | else: 116 | features = torch.cat((features, x), dim=0) 117 | 118 | graph_big = dgl.batch_hetero(graphs) 119 | output_features = [features] 120 | 121 | for GCN_layer in self.GCN_layers: 122 | features = GCN_layer(graph_big, {"node": features})["node"] # [total_mention_nums, gcn_dim] 123 | output_features.append(features) 124 | 125 | output_feature = torch.cat(output_features, dim=-1) 126 | 127 | graphs = dgl.unbatch_hetero(graph_big) 128 | 129 | # mention -> entity 130 | entity2mention_table = params['entity2mention_table'] # list of [entity_num, mention_num] 131 | entity_num = torch.max(params['entity_id']) 132 | entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size)) 133
| global_info = get_cuda(torch.Tensor(bsz, self.bank_size)) 134 | 135 | cur_idx = 0 136 | entity_graph_feature = None 137 | for i in range(len(graphs)): 138 | # average mention -> entity 139 | select_metrix = entity2mention_table[i].float() # [local_entity_num, mention_num] 140 | select_metrix[0][0] = 1 141 | mention_nums = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, select_metrix.size(1)) 142 | select_metrix = torch.where(mention_nums > 0, select_metrix / mention_nums, select_metrix) 143 | node_num = graphs[i].number_of_nodes('node') 144 | entity_representation = torch.mm(select_metrix, output_feature[cur_idx:cur_idx + node_num]) 145 | entity_bank[i, :select_metrix.size(0) - 1] = entity_representation[1:] 146 | global_info[i] = output_feature[cur_idx] 147 | cur_idx += node_num 148 | 149 | if entity_graph_feature is None: 150 | entity_graph_feature = entity_representation[1:, -self.config.gcn_dim:] 151 | else: 152 | entity_graph_feature = torch.cat( 153 | (entity_graph_feature, entity_representation[1:, -self.config.gcn_dim:]), dim=0) 154 | 155 | h_t_pairs = params['h_t_pairs'] 156 | h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1 # [batch_size, h_t_limit, 2] 157 | h_t_limit = h_t_pairs.size(1) 158 | 159 | # [batch_size, h_t_limit, bank_size] 160 | h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(-1, -1, self.bank_size) 161 | t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(-1, -1, self.bank_size) 162 | 163 | # [batch_size, h_t_limit, bank_size] 164 | h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index) 165 | t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index) 166 | 167 | global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1) 168 | 169 | entity_graphs = params['entity_graphs'] 170 | entity_graph_big = dgl.batch(entity_graphs) 171 | self.edge_layer(entity_graph_big, entity_graph_feature) 172 | entity_graphs = dgl.unbatch(entity_graph_big) 173 | path_info = get_cuda(torch.zeros((bsz, 
h_t_limit, self.gcn_dim * 4))) 174 | relation_mask = params['relation_mask'] 175 | path_table = params['path_table'] 176 | for i in range(len(entity_graphs)): 177 | path_t = path_table[i] 178 | for j in range(h_t_limit): 179 | if relation_mask is not None and relation_mask[i, j].item() == 0: 180 | break 181 | 182 | h = h_t_pairs[i, j, 0].item() 183 | t = h_t_pairs[i, j, 1].item() 184 | # for evaluate 185 | if relation_mask is None and h == 0 and t == 0: 186 | continue 187 | 188 | if (h + 1, t + 1) in path_t: 189 | v = [val - 1 for val in path_t[(h + 1, t + 1)]] 190 | elif (t + 1, h + 1) in path_t: 191 | v = [val - 1 for val in path_t[(t + 1, h + 1)]] 192 | else: 193 | print(h, t) 194 | print(entity_graphs[i].all_edges()) 195 | print(h_t_pairs) 196 | print(relation_mask) 197 | raise RuntimeError('no path entry for entity pair ({}, {})'.format(h, t)) 198 | 199 | middle_node_num = len(v) 200 | 201 | if middle_node_num == 0: 202 | continue 203 | 204 | # forward 205 | edge_ids = get_cuda(entity_graphs[i].edge_ids([h for _ in range(middle_node_num)], v)) 206 | forward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 207 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [t for _ in range(middle_node_num)])) 208 | forward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 209 | 210 | # backward 211 | edge_ids = get_cuda(entity_graphs[i].edge_ids([t for _ in range(middle_node_num)], v)) 212 | backward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 213 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [h for _ in range(middle_node_num)])) 214 | backward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 215 | 216 | tmp_path_info = torch.cat((forward_first, forward_second, backward_first, backward_second), dim=-1) 217 | _, attn_value = self.attention(torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1), tmp_path_info) 218 | path_info[i, j] = attn_value 219 | 220 | entity_graphs[i].edata.pop('h') 221 | 222 |
path_info = self.dropout( 223 | self.activation( 224 | self.path_info_mapping(path_info) 225 | ) 226 | ) 227 | 228 | predictions = self.predict(torch.cat( 229 | (h_entity, t_entity, torch.abs(h_entity - t_entity), torch.mul(h_entity, t_entity), global_info, path_info), 230 | dim=-1)) 231 | return predictions 232 | 233 | 234 | class GAIN_BERT(nn.Module): 235 | def __init__(self, config): 236 | super(GAIN_BERT, self).__init__() 237 | self.config = config 238 | if config.activation == 'tanh': 239 | self.activation = nn.Tanh() 240 | elif config.activation == 'relu': 241 | self.activation = nn.ReLU() 242 | else: 243 | raise ValueError("activation must be 'tanh' or 'relu', got {!r}".format(config.activation)) 244 | 245 | if config.use_entity_type: 246 | self.entity_type_emb = nn.Embedding(config.entity_type_num, config.entity_type_size, 247 | padding_idx=config.entity_type_pad) 248 | if config.use_entity_id: 249 | self.entity_id_emb = nn.Embedding(config.max_entity_num + 1, config.entity_id_size, 250 | padding_idx=config.entity_id_pad) 251 | 252 | self.bert = BertModel.from_pretrained(config.bert_path) 253 | if config.bert_fix: 254 | for p in self.bert.parameters(): 255 | p.requires_grad = False 256 | 257 | self.gcn_dim = config.gcn_dim 258 | assert self.gcn_dim == config.bert_hid_size + config.entity_id_size + config.entity_type_size, 'gcn dim should be bert_hid_size + entity_id_size + entity_type_size' 259 | 260 | rel_name_lists = ['intra', 'inter', 'global'] 261 | self.GCN_layers = nn.ModuleList([RelGraphConvLayer(self.gcn_dim, self.gcn_dim, rel_name_lists, 262 | num_bases=len(rel_name_lists), activation=self.activation, 263 | self_loop=True, dropout=self.config.dropout) 264 | for i in range(config.gcn_layers)]) 265 | 266 | self.bank_size = self.gcn_dim * (self.config.gcn_layers + 1) 267 | self.dropout = nn.Dropout(self.config.dropout) 268 | self.predict = nn.Sequential( 269 | nn.Linear(self.bank_size * 5 + self.gcn_dim * 4, self.bank_size * 2), 270 | self.activation, 271 | self.dropout, 272 | nn.Linear(self.bank_size * 2, config.relation_nums), 273 | ) 274 | 275 |
self.edge_layer = RelEdgeLayer(node_feat=self.gcn_dim, edge_feat=self.gcn_dim, 276 | activation=self.activation, dropout=config.dropout) 277 | 278 | self.path_info_mapping = nn.Linear(self.gcn_dim * 4, self.gcn_dim * 4) 279 | 280 | self.attention = Attention(self.bank_size * 2, self.gcn_dim * 4) 281 | 282 | def forward(self, **params): 283 | ''' 284 | words: [batch_size, max_length] 285 | src_lengths: [batch_size] 286 | mask: [batch_size, max_length] 287 | entity_type: [batch_size, max_length] 288 | entity_id: [batch_size, max_length] 289 | mention_id: [batch_size, max_length] 290 | distance: [batch_size, max_length] 291 | entity2mention_table: list of [local_entity_num, local_mention_num] 292 | graphs: list of DGLHeteroGraph 293 | h_t_pairs: [batch_size, h_t_limit, 2] 294 | ht_pair_distance: [batch_size, h_t_limit] 295 | ''' 296 | words = params['words'] 297 | mask = params['mask'] 298 | bsz, slen = words.size() 299 | 300 | encoder_outputs, sentence_cls = self.bert(input_ids=words, attention_mask=mask) # tuple unpacking assumes BertModel returns (sequence_output, pooled_output); on transformers 4.x, where return_dict defaults to True, pass return_dict=False 301 | # encoder_outputs[mask == 0] = 0 302 | 303 | if self.config.use_entity_type: 304 | encoder_outputs = torch.cat([encoder_outputs, self.entity_type_emb(params['entity_type'])], dim=-1) 305 | 306 | if self.config.use_entity_id: 307 | encoder_outputs = torch.cat([encoder_outputs, self.entity_id_emb(params['entity_id'])], dim=-1) 308 | 309 | sentence_cls = torch.cat( 310 | (sentence_cls, get_cuda(torch.zeros((bsz, self.config.entity_type_size + self.config.entity_id_size)))), 311 | dim=-1) 312 | # encoder_outputs: [batch_size, slen, bert_hid+type_size+id_size] 313 | # sentence_cls: [batch_size, bert_hid+type_size+id_size] 314 | 315 | graphs = params['graphs'] 316 | 317 | mention_id = params['mention_id'] 318 | features = None 319 | 320 | for i in range(len(graphs)): 321 | encoder_output = encoder_outputs[i] # [slen, bert_hid] 322 | mention_num = torch.max(mention_id[i]) 323 | mention_index = get_cuda( 324 | (torch.arange(mention_num) + 1).unsqueeze(1).expand(-1,
slen)) # [mention_num, slen] 325 | mentions = mention_id[i].unsqueeze(0).expand(mention_num, -1) # [mention_num, slen] 326 | select_metrix = (mention_index == mentions).float() # [mention_num, slen] 327 | # average word -> mention 328 | word_total_numbers = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, slen) # [mention_num, slen] 329 | select_metrix = torch.where(word_total_numbers > 0, select_metrix / word_total_numbers, select_metrix) 330 | 331 | x = torch.mm(select_metrix, encoder_output) # [mention_num, bert_hid] 332 | x = torch.cat((sentence_cls[i].unsqueeze(0), x), dim=0) 333 | 334 | if features is None: 335 | features = x 336 | else: 337 | features = torch.cat((features, x), dim=0) 338 | 339 | graph_big = dgl.batch_hetero(graphs) 340 | output_features = [features] 341 | 342 | for GCN_layer in self.GCN_layers: 343 | features = GCN_layer(graph_big, {"node": features})["node"] # [total_mention_nums, gcn_dim] 344 | output_features.append(features) 345 | 346 | output_feature = torch.cat(output_features, dim=-1) 347 | 348 | graphs = dgl.unbatch_hetero(graph_big) 349 | 350 | # mention -> entity 351 | entity2mention_table = params['entity2mention_table'] # list of [entity_num, mention_num] 352 | entity_num = torch.max(params['entity_id']) 353 | entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size)) 354 | global_info = get_cuda(torch.Tensor(bsz, self.bank_size)) 355 | 356 | cur_idx = 0 357 | entity_graph_feature = None 358 | for i in range(len(graphs)): 359 | # average mention -> entity 360 | select_metrix = entity2mention_table[i].float() # [local_entity_num, mention_num] 361 | select_metrix[0][0] = 1 362 | mention_nums = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, select_metrix.size(1)) 363 | select_metrix = torch.where(mention_nums > 0, select_metrix / mention_nums, select_metrix) 364 | node_num = graphs[i].number_of_nodes('node') 365 | entity_representation = torch.mm(select_metrix, output_feature[cur_idx:cur_idx + 
node_num]) 366 | entity_bank[i, :select_metrix.size(0) - 1] = entity_representation[1:] 367 | global_info[i] = output_feature[cur_idx] 368 | cur_idx += node_num 369 | 370 | if entity_graph_feature is None: 371 | entity_graph_feature = entity_representation[1:, -self.gcn_dim:] 372 | else: 373 | entity_graph_feature = torch.cat((entity_graph_feature, entity_representation[1:, -self.gcn_dim:]), 374 | dim=0) 375 | 376 | h_t_pairs = params['h_t_pairs'] 377 | h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1 # [batch_size, h_t_limit, 2] 378 | h_t_limit = h_t_pairs.size(1) 379 | 380 | # [batch_size, h_t_limit, bank_size] 381 | h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(-1, -1, self.bank_size) 382 | t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(-1, -1, self.bank_size) 383 | 384 | # [batch_size, h_t_limit, bank_size] 385 | h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index) 386 | t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index) 387 | 388 | global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1) 389 | 390 | entity_graphs = params['entity_graphs'] 391 | entity_graph_big = dgl.batch(entity_graphs) 392 | self.edge_layer(entity_graph_big, entity_graph_feature) 393 | 394 | entity_graphs = dgl.unbatch(entity_graph_big) 395 | path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4))) 396 | relation_mask = params['relation_mask'] 397 | path_table = params['path_table'] 398 | for i in range(len(entity_graphs)): 399 | path_t = path_table[i] 400 | for j in range(h_t_limit): 401 | if relation_mask is not None and relation_mask[i, j].item() == 0: 402 | break 403 | 404 | h = h_t_pairs[i, j, 0].item() 405 | t = h_t_pairs[i, j, 1].item() 406 | # for evaluate 407 | if relation_mask is None and h == 0 and t == 0: 408 | continue 409 | 410 | if (h + 1, t + 1) in path_t: 411 | v = [val - 1 for val in path_t[(h + 1, t + 1)]] 412 | elif (t + 1, h + 1) in path_t: 413 | v = [val - 1 for val in path_t[(t + 1, h 
+ 1)]] 414 | else: 415 | print(h, t) 416 | print(entity_graphs[i].number_of_nodes()) 417 | print(entity_graphs[i].all_edges()) 418 | print(path_table) 419 | print(h_t_pairs) 420 | print(relation_mask) 421 | raise RuntimeError('no path entry for entity pair ({}, {})'.format(h, t)) 422 | 423 | middle_node_num = len(v) 424 | 425 | if middle_node_num == 0: 426 | continue 427 | 428 | # forward 429 | edge_ids = get_cuda(entity_graphs[i].edge_ids([h for _ in range(middle_node_num)], v)) 430 | forward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 431 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [t for _ in range(middle_node_num)])) 432 | forward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 433 | 434 | # backward 435 | edge_ids = get_cuda(entity_graphs[i].edge_ids([t for _ in range(middle_node_num)], v)) 436 | backward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 437 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [h for _ in range(middle_node_num)])) 438 | backward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 439 | 440 | tmp_path_info = torch.cat((forward_first, forward_second, backward_first, backward_second), dim=-1) 441 | _, attn_value = self.attention(torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1), tmp_path_info) 442 | path_info[i, j] = attn_value 443 | 444 | entity_graphs[i].edata.pop('h') 445 | 446 | path_info = self.dropout( 447 | self.activation( 448 | self.path_info_mapping(path_info) 449 | ) 450 | ) 451 | 452 | predictions = self.predict(torch.cat( 453 | (h_entity, t_entity, torch.abs(h_entity - t_entity), torch.mul(h_entity, t_entity), global_info, path_info), 454 | dim=-1)) 455 | # predictions = self.predict(torch.cat((h_entity, t_entity, torch.abs(h_entity-t_entity), torch.mul(h_entity, t_entity), global_info), dim=-1)) 456 | return predictions 457 | 458 | 459 | class Attention(nn.Module): 460 | def __init__(self, src_size, trg_size): 461 | super().__init__() 462 |
self.W = nn.Bilinear(src_size, trg_size, 1) 463 | self.softmax = nn.Softmax(dim=-1) 464 | 465 | def forward(self, src, trg, attention_mask=None): 466 | ''' 467 | src: [src_size] 468 | trg: [middle_node, trg_size] 469 | ''' 470 | 471 | score = self.W(src.unsqueeze(0).expand(trg.size(0), -1), trg) 472 | score = self.softmax(score) # score is [middle_node, 1]; softmax over dim=-1 (size 1) yields all ones, so value below is a plain sum over middle nodes 473 | value = torch.mm(score.permute(1, 0), trg) 474 | 475 | return score.squeeze(0), value.squeeze(0) 476 | 477 | 478 | class BiLSTM(nn.Module): 479 | def __init__(self, input_size, config): 480 | super().__init__() 481 | self.config = config 482 | self.lstm = nn.LSTM(input_size=input_size, hidden_size=config.lstm_hidden_size, 483 | num_layers=config.nlayers, batch_first=True, 484 | bidirectional=True) 485 | self.in_dropout = nn.Dropout(config.dropout) 486 | self.out_dropout = nn.Dropout(config.dropout) 487 | 488 | def forward(self, src, src_lengths): 489 | ''' 490 | src: [batch_size, slen, input_size] 491 | src_lengths: [batch_size] 492 | ''' 493 | 494 | self.lstm.flatten_parameters() 495 | bsz, slen, input_size = src.size() 496 | 497 | src = self.in_dropout(src) 498 | 499 | new_src_lengths, sort_index = torch.sort(src_lengths, dim=-1, descending=True) 500 | new_src = torch.index_select(src, dim=0, index=sort_index) 501 | 502 | packed_src = nn.utils.rnn.pack_padded_sequence(new_src, new_src_lengths.cpu(), batch_first=True, enforce_sorted=True) # lengths must be a CPU tensor on PyTorch >= 1.7 503 | packed_outputs, (src_h_t, src_c_t) = self.lstm(packed_src) 504 | 505 | outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True, 506 | padding_value=self.config.word_pad) 507 | 508 | unsort_index = torch.argsort(sort_index) 509 | outputs = torch.index_select(outputs, dim=0, index=unsort_index) 510 | 511 | src_h_t = src_h_t.view(self.config.nlayers, 2, bsz, self.config.lstm_hidden_size) 512 | src_c_t = src_c_t.view(self.config.nlayers, 2, bsz, self.config.lstm_hidden_size) 513 | output_h_t = torch.cat((src_h_t[-1, 0], src_h_t[-1, 1]), dim=-1) 514 | output_c_t =
torch.cat((src_c_t[-1, 0], src_c_t[-1, 1]), dim=-1) 515 | output_h_t = torch.index_select(output_h_t, dim=0, index=unsort_index) 516 | output_c_t = torch.index_select(output_c_t, dim=0, index=unsort_index) 517 | 518 | outputs = self.out_dropout(outputs) 519 | output_h_t = self.out_dropout(output_h_t) 520 | output_c_t = self.out_dropout(output_c_t) 521 | 522 | return outputs, (output_h_t, output_c_t) 523 | 524 | 525 | class RelGraphConvLayer(nn.Module): 526 | r"""Relational graph convolution layer. 527 | Parameters 528 | ---------- 529 | in_feat : int 530 | Input feature size. 531 | out_feat : int 532 | Output feature size. 533 | rel_names : list[str] 534 | Relation names. 535 | num_bases : int 536 | Number of bases. Basis decomposition is used only when this is smaller than the number of relations and weight is True. 537 | weight : bool, optional 538 | True if a linear layer is applied after message passing. Default: True 539 | bias : bool, optional 540 | True if bias is added. Default: True 541 | activation : callable, optional 542 | Activation function. Default: None 543 | self_loop : bool, optional 544 | True to include self loop message. Default: False 545 | dropout : float, optional 546 | Dropout rate.
Default: 0.0 547 | """ 548 | 549 | def __init__(self, 550 | in_feat, 551 | out_feat, 552 | rel_names, 553 | num_bases, 554 | *, 555 | weight=True, 556 | bias=True, 557 | activation=None, 558 | self_loop=False, 559 | dropout=0.0): 560 | super(RelGraphConvLayer, self).__init__() 561 | self.in_feat = in_feat 562 | self.out_feat = out_feat 563 | self.rel_names = rel_names 564 | self.num_bases = num_bases 565 | self.bias = bias 566 | self.activation = activation 567 | self.self_loop = self_loop 568 | 569 | self.conv = dglnn.HeteroGraphConv({ 570 | rel: dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False) 571 | for rel in rel_names 572 | }) 573 | 574 | self.use_weight = weight 575 | self.use_basis = num_bases < len(self.rel_names) and weight 576 | if self.use_weight: 577 | if self.use_basis: 578 | self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names)) 579 | else: 580 | self.weight = nn.Parameter(torch.Tensor(len(self.rel_names), in_feat, out_feat)) 581 | nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) 582 | 583 | # bias 584 | if bias: 585 | self.h_bias = nn.Parameter(torch.Tensor(out_feat)) 586 | nn.init.zeros_(self.h_bias) 587 | 588 | # weight for self loop 589 | if self.self_loop: 590 | self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat)) 591 | nn.init.xavier_uniform_(self.loop_weight, 592 | gain=nn.init.calculate_gain('relu')) 593 | 594 | self.dropout = nn.Dropout(dropout) 595 | 596 | def forward(self, g, inputs): 597 | """Forward computation 598 | Parameters 599 | ---------- 600 | g : DGLHeteroGraph 601 | Input graph. 602 | inputs : dict[str, torch.Tensor] 603 | Node feature for each node type. 604 | Returns 605 | ------- 606 | dict[str, torch.Tensor] 607 | New node features for each node type. 
608 | """ 609 | g = g.local_var() 610 | if self.use_weight: 611 | weight = self.basis() if self.use_basis else self.weight 612 | wdict = {self.rel_names[i]: {'weight': w.squeeze(0)} 613 | for i, w in enumerate(torch.split(weight, 1, dim=0))} 614 | else: 615 | wdict = {} 616 | hs = self.conv(g, inputs, mod_kwargs=wdict) 617 | 618 | def _apply(ntype, h): 619 | if self.self_loop: 620 | h = h + torch.matmul(inputs[ntype], self.loop_weight) 621 | if self.bias: 622 | h = h + self.h_bias 623 | if self.activation: 624 | h = self.activation(h) 625 | return self.dropout(h) 626 | 627 | return {ntype: _apply(ntype, h) for ntype, h in hs.items()} 628 | 629 | 630 | class RelEdgeLayer(nn.Module): 631 | def __init__(self, 632 | node_feat, 633 | edge_feat, 634 | activation, 635 | dropout=0.0): 636 | super(RelEdgeLayer, self).__init__() 637 | self.node_feat = node_feat 638 | self.edge_feat = edge_feat 639 | self.activation = activation 640 | self.dropout = nn.Dropout(dropout) 641 | self.mapping = nn.Linear(node_feat * 2, edge_feat) 642 | 643 | def forward(self, g, inputs): 644 | # g = g.local_var() # left disabled so the edge features written to g.edata['h'] stay available to the caller 645 | 646 | g.ndata['h'] = inputs # [total_node_num, node_feat] 647 | g.apply_edges(lambda edges: { 648 | 'h': self.dropout(self.activation(self.mapping(torch.cat((edges.src['h'], edges.dst['h']), dim=-1))))}) 649 | g.ndata.pop('h') 650 | 651 | 652 | class Bert(): 653 | MASK = '[MASK]' 654 | CLS = "[CLS]" 655 | SEP = "[SEP]" 656 | 657 | def __init__(self, model_class, model_name, model_path=None): 658 | super().__init__() 659 | self.model_name = model_name 660 | print(model_path) 661 | self.tokenizer = BertTokenizer.from_pretrained(model_path) 662 | self.max_len = 512 663 | 664 | def tokenize(self, text, masked_idxs=None): 665 | tokenized_text = self.tokenizer.tokenize(text) 666 | if masked_idxs is not None: 667 | for idx in masked_idxs: 668 | tokenized_text[idx] = self.MASK 669 | # prepend [CLS] and 
append [SEP] 670 | # see
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py#L195 # NOQA 671 | tokenized = [self.CLS] + tokenized_text + [self.SEP] 672 | return tokenized 673 | 674 | def tokenize_to_ids(self, text, masked_idxs=None, pad=True): 675 | tokens = self.tokenize(text, masked_idxs) 676 | return tokens, self.convert_tokens_to_ids(tokens, pad=pad) 677 | 678 | def convert_tokens_to_ids(self, tokens, pad=True): 679 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 680 | ids = torch.tensor([token_ids]) 681 | # assert ids.size(1) < self.max_len 682 | ids = ids[:, :self.max_len] # https://github.com/DreamInvoker/GAIN/issues/4 683 | if pad: 684 | padded_ids = torch.zeros(1, self.max_len).to(ids) 685 | padded_ids[0, :ids.size(1)] = ids 686 | mask = torch.zeros(1, self.max_len).to(ids) 687 | mask[0, :ids.size(1)] = 1 688 | return padded_ids, mask 689 | else: 690 | return ids 691 | 692 | def flatten(self, list_of_lists): 693 | for sublist in list_of_lists: # avoid shadowing the built-in list 694 | for item in sublist: 695 | yield item 696 | 697 | def subword_tokenize(self, tokens): 698 | """Segment each token into subwords while keeping track of 699 | token boundaries. 700 | Parameters 701 | ---------- 702 | tokens: A sequence of strings, representing input tokens. 703 | Returns 704 | ------- 705 | A tuple consisting of: 706 | - A list of subwords, truncated to 509 and flanked by the special symbols required 707 | by Bert (CLS and SEP). 708 | - An array of indices into the list of subwords, indicating 709 | that the corresponding subword is the start of a new 710 | token. For example, [1, 3, 4, 7] means that the subwords 711 | 1, 3, 4, 7 are token starts, while all other subwords 712 | (0, 2, 5, 6, 8...) are in or at the end of tokens. 713 | This list allows selecting Bert hidden states that 714 | represent tokens, which is necessary in sequence 715 | labeling.
716 | """ 717 | subwords = list(map(self.tokenizer.tokenize, tokens)) 718 | subword_lengths = list(map(len, subwords)) 719 | subwords = [self.CLS] + list(self.flatten(subwords))[:509] + [self.SEP] 720 | token_start_idxs = 1 + np.cumsum([0] + subword_lengths[:-1]) 721 | token_start_idxs[token_start_idxs > 509] = 512 # tokens whose first subword was truncated away point to index 512 (== max_len) 722 | return subwords, token_start_idxs 723 | 724 | def subword_tokenize_to_ids(self, tokens): 725 | """Segment each token into subwords while keeping track of 726 | token boundaries and convert subwords into IDs. 727 | Parameters 728 | ---------- 729 | tokens: A sequence of strings, representing input tokens. 730 | Returns 731 | ------- 732 | A tuple consisting of: 733 | - A [1, max_len] NumPy array of subword IDs, padded to max_len and 734 | including IDs of the special symbols (CLS and SEP) required by Bert. 735 | - An array of indices into the list of subwords. See 736 | doc of subword_tokenize. 737 | - The list of subwords (the padding mask is computed but not returned). 738 | """ 739 | subwords, token_start_idxs = self.subword_tokenize(tokens) 740 | subword_ids, mask = self.convert_tokens_to_ids(subwords) 741 | return subword_ids.numpy(), token_start_idxs, subwords 742 | 743 | def segment_ids(self, segment1_len, segment2_len): 744 | ids = [0] * segment1_len + [1] * segment2_len 745 | return torch.tensor([ids]) 746 | -------------------------------------------------------------------------------- /code/models/GAIN_nomention.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import dgl.nn.pytorch as dglnn 3 | import torch 4 | import torch.nn as nn 5 | from transformers import * 6 | 7 | from utils import get_cuda 8 | 9 | 10 | # variant for the ablation study without the mention module 11 | 12 | class GAIN_GloVe(nn.Module): 13 | def __init__(self, config): 14 | super(GAIN_GloVe, self).__init__() 15 | self.config = config 16 | 17 | word_emb_size = 
config.word_emb_size 18 | vocabulary_size = config.vocabulary_size 19 | encoder_input_size = word_emb_size 20 | self.activation = nn.Tanh() if
config.activation == 'tanh' else nn.ReLU() 21 | 22 | self.word_emb = nn.Embedding(vocabulary_size, word_emb_size, padding_idx=config.word_pad) 23 | if config.pre_train_word: 24 | self.word_emb = nn.Embedding(config.data_word_vec.shape[0], word_emb_size, padding_idx=config.word_pad) 25 | self.word_emb.weight.data.copy_(torch.from_numpy(config.data_word_vec[:, :word_emb_size])) 26 | 27 | self.word_emb.weight.requires_grad = config.finetune_word 28 | if config.use_entity_type: 29 | encoder_input_size += config.entity_type_size 30 | self.entity_type_emb = nn.Embedding(config.entity_type_num, config.entity_type_size, 31 | padding_idx=config.entity_type_pad) 32 | 33 | if config.use_entity_id: 34 | encoder_input_size += config.entity_id_size 35 | self.entity_id_emb = nn.Embedding(config.max_entity_num + 1, config.entity_id_size, 36 | padding_idx=config.entity_id_pad) 37 | 38 | self.encoder = BiLSTM(encoder_input_size, config) 39 | 40 | self.gcn_dim = config.gcn_dim 41 | assert self.gcn_dim == 2 * config.lstm_hidden_size, 'gcn dim should be the lstm hidden dim * 2' 42 | rel_name_lists = ['intra', 'inter', 'global'] 43 | self.GCN_layers = nn.ModuleList([dglnn.GraphConv(self.gcn_dim, self.gcn_dim, norm='right', weight=True, 44 | bias=True, activation=self.activation) 45 | for i in range(config.gcn_layers)]) 46 | 47 | self.bank_size = self.config.gcn_dim * (self.config.gcn_layers + 1) 48 | self.dropout = nn.Dropout(self.config.dropout) 49 | 50 | self.predict = nn.Sequential( 51 | nn.Linear(self.bank_size * 4 + self.gcn_dim * 5, self.bank_size * 2), 52 | self.activation, 53 | self.dropout, 54 | nn.Linear(self.bank_size * 2, config.relation_nums), 55 | ) 56 | 57 | self.edge_layer = RelEdgeLayer(node_feat=self.gcn_dim, edge_feat=self.gcn_dim, 58 | activation=self.activation, dropout=config.dropout) 59 | self.path_info_mapping = nn.Linear(self.gcn_dim * 4, self.gcn_dim * 4) 60 | self.attention = Attention(self.bank_size * 2, self.gcn_dim * 4) 61 | 62 | def forward(self, 
**params): 63 | ''' 64 | words: [batch_size, max_length] 65 | src_lengths: [batch_size] 66 | mask: [batch_size, max_length] 67 | entity_type: [batch_size, max_length] 68 | entity_id: [batch_size, max_length] 69 | mention_id: [batch_size, max_length] 70 | distance: [batch_size, max_length] 71 | entity2mention_table: list of [local_entity_num, local_mention_num] 72 | graphs: list of DGLHeteroGraph 73 | h_t_pairs: [batch_size, h_t_limit, 2] 74 | ''' 75 | src = self.word_emb(params['words']) 76 | mask = params['mask'] 77 | bsz, slen, _ = src.size() 78 | 79 | if self.config.use_entity_type: 80 | src = torch.cat([src, self.entity_type_emb(params['entity_type'])], dim=-1) 81 | 82 | if self.config.use_entity_id: 83 | src = torch.cat([src, self.entity_id_emb(params['entity_id'])], dim=-1) 84 | 85 | # src: [batch_size, slen, encoder_input_size] 86 | # src_lengths: [batch_size] 87 | 88 | encoder_outputs, (output_h_t, _) = self.encoder(src, params['src_lengths']) 89 | encoder_outputs[mask == 0] = 0 90 | # encoder_outputs: [batch_size, slen, 2*encoder_hid_size] 91 | # output_h_t: [batch_size, 2*encoder_hid_size] 92 | 93 | graphs = params['graphs'] 94 | 95 | mention_id = params['mention_id'] 96 | features = None 97 | 98 | for i in range(len(graphs)): 99 | encoder_output = encoder_outputs[i] # [slen, 2*encoder_hid_size] 100 | mention_num = torch.max(mention_id[i]) 101 | mention_index = get_cuda( 102 | (torch.arange(mention_num) + 1).unsqueeze(1).expand(-1, slen)) # [mention_num, slen] 103 | mentions = mention_id[i].unsqueeze(0).expand(mention_num, -1) # [mention_num, slen] 104 | select_metrix = (mention_index == mentions).float() # [mention_num, slen] 105 | # average word -> mention 106 | word_total_numbers = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, slen) # [mention_num, slen] 107 | select_metrix = torch.where(word_total_numbers > 0, select_metrix / word_total_numbers, select_metrix) 108 | x = torch.mm(select_metrix, encoder_output) # [mention_num,
2*encoder_hid_size] 109 | 110 | x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0) 111 | # x = torch.cat((torch.max(encoder_output, dim=0)[0].unsqueeze(0), x), dim=0) 112 | 113 | if features is None: 114 | features = x 115 | else: 116 | features = torch.cat((features, x), dim=0) 117 | 118 | # mention -> entity 119 | entity2mention_table = params['entity2mention_table'] # list of [entity_num, mention_num] 120 | entity_num = torch.max(params['entity_id']) 121 | global_info = get_cuda(torch.Tensor(bsz, self.gcn_dim)) 122 | 123 | cur_idx = 0 124 | entity_graph_feature = None 125 | for i in range(len(graphs)): 126 | # average mention -> entity 127 | select_metrix = entity2mention_table[i].float() # [local_entity_num, mention_num] 128 | select_metrix[0][0] = 1 129 | mention_nums = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, select_metrix.size(1)) 130 | select_metrix = torch.where(mention_nums > 0, select_metrix / mention_nums, select_metrix) 131 | node_num = graphs[i].number_of_nodes('node') 132 | entity_representation = torch.mm(select_metrix, features[cur_idx:cur_idx + node_num]) 133 | global_info[i] = features[cur_idx] 134 | cur_idx += node_num 135 | 136 | if entity_graph_feature is None: 137 | entity_graph_feature = entity_representation[1:] 138 | else: 139 | entity_graph_feature = torch.cat((entity_graph_feature, entity_representation[1:]), dim=0) 140 | 141 | entity_graphs = params['entity_graphs'] 142 | entity_graph_big = dgl.batch(entity_graphs) 143 | output_features = [entity_graph_feature] 144 | 145 | for GCN_layer in self.GCN_layers: 146 | entity_graph_feature = GCN_layer(entity_graph_big, entity_graph_feature) 147 | output_features.append(entity_graph_feature) 148 | output_features = torch.cat(output_features, dim=-1) 149 | self.edge_layer(entity_graph_big, entity_graph_feature) 150 | entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size)) 151 | entity_graphs = dgl.unbatch(entity_graph_big) 152 | 153 | cur_idx = 0 154 | for i in 
range(len(entity_graphs)): 155 | node_num = entity_graphs[i].number_of_nodes() 156 | entity_bank[i, :node_num] = output_features[cur_idx:cur_idx + node_num] 157 | cur_idx += node_num 158 | 159 | h_t_pairs = params['h_t_pairs'] 160 | h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1 # [batch_size, h_t_limit, 2]; 1-based entity ids -> 0-based, padding (0) stays 0 161 | h_t_limit = h_t_pairs.size(1) 162 | 163 | # [batch_size, h_t_limit, bank_size] 164 | h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(-1, -1, self.bank_size) 165 | t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(-1, -1, self.bank_size) 166 | 167 | # [batch_size, h_t_limit, bank_size] 168 | h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index) 169 | t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index) 170 | 171 | global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1) 172 | path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4))) 173 | relation_mask = params['relation_mask'] 174 | path_table = params['path_table'] 175 | for i in range(len(entity_graphs)): 176 | path_t = path_table[i] 177 | for j in range(h_t_limit): 178 | if relation_mask is not None and relation_mask[i, j].item() == 0: 179 | break 180 | 181 | h = h_t_pairs[i, j, 0].item() 182 | t = h_t_pairs[i, j, 1].item() 183 | # at evaluation time (no relation_mask), skip padded (0, 0) pairs 184 | if relation_mask is None and h == 0 and t == 0: 185 | continue 186 | 187 | if (h + 1, t + 1) in path_t: 188 | v = [val - 1 for val in path_t[(h + 1, t + 1)]] 189 | elif (t + 1, h + 1) in path_t: 190 | v = [val - 1 for val in path_t[(t + 1, h + 1)]] 191 | else: 192 | print(h, t) 193 | print(entity_graphs[i].all_edges()) 194 | print(h_t_pairs) 195 | print(relation_mask) 196 | raise RuntimeError(f'no path found between entities {h} and {t}') 197 | 198 | middle_node_num = len(v) 199 | 200 | if middle_node_num == 0: 201 | continue 202 | 203 | # forward 204 | edge_ids = 
get_cuda(entity_graphs[i].edge_ids([h for _ in range(middle_node_num)], v)) 205 | forward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0,
index=edge_ids) 206 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [t for _ in range(middle_node_num)])) 207 | forward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 208 | 209 | # backward 210 | edge_ids = get_cuda(entity_graphs[i].edge_ids([t for _ in range(middle_node_num)], v)) 211 | backward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 212 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [h for _ in range(middle_node_num)])) 213 | backward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 214 | 215 | tmp_path_info = torch.cat((forward_first, forward_second, backward_first, backward_second), dim=-1) 216 | _, attn_value = self.attention(torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1), tmp_path_info) 217 | path_info[i, j] = attn_value 218 | 219 | entity_graphs[i].edata.pop('h') 220 | 221 | path_info = self.dropout( 222 | self.activation( 223 | self.path_info_mapping(path_info) 224 | ) 225 | ) 226 | 227 | predictions = self.predict(torch.cat( 228 | (h_entity, t_entity, torch.abs(h_entity - t_entity), torch.mul(h_entity, t_entity), global_info, path_info), 229 | dim=-1)) 230 | return predictions 231 | 232 | 233 | class GAIN_BERT(nn.Module): 234 | def __init__(self, config): 235 | super(GAIN_BERT, self).__init__() 236 | self.config = config 237 | self.activation = nn.Tanh() if config.activation == 'tanh' else nn.ReLU() 238 | 239 | if config.use_entity_type: 240 | self.entity_type_emb = nn.Embedding(config.entity_type_num, config.entity_type_size, 241 | padding_idx=config.entity_type_pad) 242 | 243 | if config.use_entity_id: 244 | self.entity_id_emb = nn.Embedding(config.max_entity_num + 1, config.entity_id_size, 245 | padding_idx=config.entity_id_pad) 246 | 247 | self.bert = BertModel.from_pretrained(config.bert_path) 248 | if config.bert_fix: 249 | for p in self.bert.parameters(): 250 | p.requires_grad = False 251 | 252 | self.gcn_dim = config.gcn_dim 253 | 
assert self.gcn_dim == config.bert_hid_size + config.entity_id_size + config.entity_type_size, 'gcn dim should be bert_hid_size + entity_id_size + entity_type_size' 254 | 255 | rel_name_lists = ['intra', 'inter', 'global'] 256 | self.GCN_layers = nn.ModuleList([dglnn.GraphConv(self.gcn_dim, self.gcn_dim, norm='right', weight=True, 257 | bias=True, activation=self.activation) 258 | for i in range(config.gcn_layers)]) 259 | 260 | self.bank_size = self.gcn_dim * (self.config.gcn_layers + 1) 261 | 262 | self.dropout = nn.Dropout(self.config.dropout) 263 | 264 | self.predict = nn.Sequential( 265 | nn.Linear(self.bank_size * 4 + self.gcn_dim * 5, self.bank_size * 2), 266 | self.activation, 267 | self.dropout, 268 | nn.Linear(self.bank_size * 2, config.relation_nums), 269 | ) 270 | 271 | self.edge_layer = RelEdgeLayer(node_feat=self.gcn_dim, edge_feat=self.gcn_dim, 272 | activation=self.activation, dropout=config.dropout) 273 | 274 | self.path_info_mapping = nn.Linear(self.gcn_dim * 4, self.gcn_dim * 4) 275 | 276 | self.attention = Attention(self.bank_size * 2, self.gcn_dim * 4) 277 | # self.attention = Attention2(self.bank_size*2, self.gcn_dim*4, self.activation, config) 278 | 279 | def forward(self, **params): 280 | ''' 281 | words: [batch_size, max_length] 282 | src_lengths: [batch_size] 283 | mask: [batch_size, max_length] 284 | entity_type: [batch_size, max_length] 285 | entity_id: [batch_size, max_length] 286 | mention_id: [batch_size, max_length] 287 | distance: [batch_size, max_length] 288 | entity2mention_table: list of [local_entity_num, local_mention_num] 289 | graphs: list of DGLHeteroGraph 290 | h_t_pairs: [batch_size, h_t_limit, 2] 291 | ht_pair_distance: [batch_size, h_t_limit] 292 | ''' 293 | words = params['words'] 294 | mask = params['mask'] 295 | bsz, slen = words.size() 296 | 297 | encoder_outputs, sentence_cls = self.bert(input_ids=words, attention_mask=mask)[:2] # [:2] works for both the tuple and the ModelOutput returned by different transformers versions 298 | # 
encoder_outputs[mask == 0] = 0 299 | 300 | if self.config.use_entity_type: 301 | encoder_outputs = torch.cat([encoder_outputs,
self.entity_type_emb(params['entity_type'])], dim=-1) 302 | 303 | if self.config.use_entity_id: 304 | encoder_outputs = torch.cat([encoder_outputs, self.entity_id_emb(params['entity_id'])], dim=-1) 305 | 306 | sentence_cls = torch.cat( 307 | (sentence_cls, get_cuda(torch.zeros((bsz, self.config.entity_type_size + self.config.entity_id_size)))), 308 | dim=-1) 309 | # encoder_outputs: [batch_size, slen, bert_hid+type_size+id_size] 310 | # sentence_cls: [batch_size, bert_hid+type_size+id_size] 311 | 312 | graphs = params['graphs'] 313 | 314 | mention_id = params['mention_id'] 315 | features = None 316 | 317 | for i in range(len(graphs)): 318 | encoder_output = encoder_outputs[i] # [slen, bert_hid+type_size+id_size] 319 | mention_num = torch.max(mention_id[i]) 320 | mention_index = get_cuda( 321 | (torch.arange(mention_num) + 1).unsqueeze(1).expand(-1, slen)) # [mention_num, slen] 322 | mentions = mention_id[i].unsqueeze(0).expand(mention_num, -1) # [mention_num, slen] 323 | select_metrix = (mention_index == mentions).float() # [mention_num, slen] 324 | # average word -> mention 325 | word_total_numbers = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, slen) # [mention_num, slen] 326 | select_metrix = torch.where(word_total_numbers > 0, select_metrix / word_total_numbers, select_metrix) 327 | 328 | x = torch.mm(select_metrix, encoder_output) # [mention_num, bert_hid+type_size+id_size] 329 | x = torch.cat((sentence_cls[i].unsqueeze(0), x), dim=0) 330 | 331 | if features is None: 332 | features = x 333 | else: 334 | features = torch.cat((features, x), dim=0) 335 | 336 | # mention -> entity 337 | entity2mention_table = params['entity2mention_table'] # list of [entity_num, mention_num] 338 | entity_num = torch.max(params['entity_id']) 339 | global_info = get_cuda(torch.Tensor(bsz, self.gcn_dim)) 340 | 341 | cur_idx = 0 342 | entity_graph_feature = None 343 | for i in range(len(graphs)): 344 | # average mention -> entity 345 | select_metrix = entity2mention_table[i].float() # 
[local_entity_num,
mention_num] 346 | select_metrix[0][0] = 1 347 | mention_nums = torch.sum(select_metrix, dim=-1).unsqueeze(-1).expand(-1, select_metrix.size(1)) 348 | select_metrix = torch.where(mention_nums > 0, select_metrix / mention_nums, select_metrix) 349 | node_num = graphs[i].number_of_nodes('node') 350 | entity_representation = torch.mm(select_metrix, features[cur_idx:cur_idx + node_num]) 351 | global_info[i] = features[cur_idx] 352 | cur_idx += node_num 353 | 354 | if entity_graph_feature is None: 355 | entity_graph_feature = entity_representation[1:] 356 | else: 357 | entity_graph_feature = torch.cat((entity_graph_feature, entity_representation[1:]), dim=0) 358 | 359 | entity_graphs = params['entity_graphs'] 360 | entity_graph_big = dgl.batch(entity_graphs) 361 | output_features = [entity_graph_feature] 362 | for GCN_layer in self.GCN_layers: 363 | entity_graph_feature = GCN_layer(entity_graph_big, entity_graph_feature) 364 | output_features.append(entity_graph_feature) 365 | output_features = torch.cat(output_features, dim=-1) 366 | self.edge_layer(entity_graph_big, entity_graph_feature) 367 | entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size)) 368 | entity_graphs = dgl.unbatch(entity_graph_big) 369 | 370 | cur_idx = 0 371 | for i in range(len(entity_graphs)): 372 | node_num = entity_graphs[i].number_of_nodes() 373 | entity_bank[i, :node_num] = output_features[cur_idx:cur_idx + node_num] 374 | cur_idx += node_num 375 | 376 | h_t_pairs = params['h_t_pairs'] 377 | h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1 # [batch_size, h_t_limit, 2] 378 | h_t_limit = h_t_pairs.size(1) 379 | 380 | # [batch_size, h_t_limit, bank_size] 381 | h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(-1, -1, self.bank_size) 382 | t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(-1, -1, self.bank_size) 383 | 384 | # [batch_size, h_t_limit, bank_size] 385 | h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index) 386 | t_entity = 
torch.gather(input=entity_bank, dim=1, index=t_entity_index) 387 | 388 | global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1) 389 | path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4))) 390 | relation_mask = params['relation_mask'] 391 | path_table = params['path_table'] 392 | for i in range(len(entity_graphs)): 393 | path_t = path_table[i] 394 | for j in range(h_t_limit): 395 | if relation_mask is not None and relation_mask[i, j].item() == 0: 396 | break 397 | 398 | h = h_t_pairs[i, j, 0].item() 399 | t = h_t_pairs[i, j, 1].item() 400 | # at evaluation time (no relation_mask), skip padded (0, 0) pairs 401 | if relation_mask is None and h == 0 and t == 0: 402 | continue 403 | 404 | if (h + 1, t + 1) in path_t: 405 | v = [val - 1 for val in path_t[(h + 1, t + 1)]] 406 | elif (t + 1, h + 1) in path_t: 407 | v = [val - 1 for val in path_t[(t + 1, h + 1)]] 408 | else: 409 | print(h, t) 410 | print(entity_graphs[i].number_of_nodes()) 411 | print(entity_graphs[i].all_edges()) 412 | print(path_table) 413 | print(h_t_pairs) 414 | print(relation_mask) 415 | raise RuntimeError(f'no path found between entities {h} and {t}') 416 | 417 | middle_node_num = len(v) 418 | 419 | if middle_node_num == 0: 420 | continue 421 | 422 | # forward 423 | edge_ids = get_cuda(entity_graphs[i].edge_ids([h for _ in range(middle_node_num)], v)) 424 | forward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 425 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [t for _ in range(middle_node_num)])) 426 | forward_second = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 427 | 428 | # backward 429 | edge_ids = get_cuda(entity_graphs[i].edge_ids([t for _ in range(middle_node_num)], v)) 430 | backward_first = torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 431 | edge_ids = get_cuda(entity_graphs[i].edge_ids(v, [h for _ in range(middle_node_num)])) 432 | backward_second = 
torch.index_select(entity_graphs[i].edata['h'], dim=0, index=edge_ids) 433 | 434 | tmp_path_info = torch.cat((forward_first,
forward_second, backward_first, backward_second), dim=-1) 435 | _, attn_value = self.attention(torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1), tmp_path_info) 436 | path_info[i, j] = attn_value 437 | 438 | entity_graphs[i].edata.pop('h') 439 | 440 | path_info = self.dropout( 441 | self.activation( 442 | self.path_info_mapping(path_info) 443 | ) 444 | ) 445 | 446 | predictions = self.predict(torch.cat( 447 | (h_entity, t_entity, torch.abs(h_entity - t_entity), torch.mul(h_entity, t_entity), global_info, path_info), 448 | dim=-1)) 449 | 450 | return predictions 451 | 452 | 453 | class Attention(nn.Module): 454 | def __init__(self, src_size, trg_size): 455 | super().__init__() 456 | self.W = nn.Bilinear(src_size, trg_size, 1) 457 | self.softmax = nn.Softmax(dim=0) # normalize over the middle nodes; dim=-1 would act on the size-1 score axis and give every path weight 1 458 | 459 | def forward(self, src, trg, attention_mask=None): 460 | ''' 461 | src: [src_size] 462 | trg: [middle_node, trg_size] 463 | ''' 464 | 465 | score = self.W(src.unsqueeze(0).expand(trg.size(0), -1), trg) # [middle_node, 1] 466 | score = self.softmax(score) 467 | value = torch.mm(score.permute(1, 0), trg) # [1, trg_size] 468 | 469 | return score.squeeze(0), value.squeeze(0) 470 | 471 | 472 | class BiLSTM(nn.Module): 473 | def __init__(self, input_size, config): 474 | super().__init__() 475 | self.config = config 476 | self.lstm = nn.LSTM(input_size=input_size, hidden_size=config.lstm_hidden_size, 477 | num_layers=config.nlayers, batch_first=True, 478 | bidirectional=True) 479 | self.in_dropout = nn.Dropout(config.dropout) 480 | self.out_dropout = nn.Dropout(config.dropout) 481 | 482 | def forward(self, src, src_lengths): 483 | ''' 484 | src: [batch_size, slen, input_size] 485 | src_lengths: [batch_size] 486 | ''' 487 | 488 | self.lstm.flatten_parameters() 489 | bsz, slen, input_size = src.size() 490 | 491 | src = self.in_dropout(src) 492 | 493 | new_src_lengths, sort_index = torch.sort(src_lengths, dim=-1, descending=True) 494 | 
new_src = torch.index_select(src, dim=0, index=sort_index) 495 | 496 | packed_src =
nn.utils.rnn.pack_padded_sequence(new_src, new_src_lengths.cpu(), batch_first=True, enforce_sorted=True) # recent PyTorch requires lengths on the CPU 497 | packed_outputs, (src_h_t, src_c_t) = self.lstm(packed_src) 498 | 499 | outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True, 500 | padding_value=self.config.word_pad, total_length=slen) # pad back to slen so outputs align with the input mask 501 | 502 | unsort_index = torch.argsort(sort_index) 503 | outputs = torch.index_select(outputs, dim=0, index=unsort_index) 504 | 505 | src_h_t = src_h_t.view(self.config.nlayers, 2, bsz, self.config.lstm_hidden_size) 506 | src_c_t = src_c_t.view(self.config.nlayers, 2, bsz, self.config.lstm_hidden_size) 507 | output_h_t = torch.cat((src_h_t[-1, 0], src_h_t[-1, 1]), dim=-1) 508 | output_c_t = torch.cat((src_c_t[-1, 0], src_c_t[-1, 1]), dim=-1) 509 | output_h_t = torch.index_select(output_h_t, dim=0, index=unsort_index) 510 | output_c_t = torch.index_select(output_c_t, dim=0, index=unsort_index) 511 | 512 | outputs = self.out_dropout(outputs) 513 | output_h_t = self.out_dropout(output_h_t) 514 | output_c_t = self.out_dropout(output_c_t) 515 | 516 | return outputs, (output_h_t, output_c_t) 517 | 518 | 519 | class RelGraphConvLayer(nn.Module): 520 | r"""Relational graph convolution layer. 521 | Parameters 522 | ---------- 523 | in_feat : int 524 | Input feature size. 525 | out_feat : int 526 | Output feature size. 527 | rel_names : list[str] 528 | Relation names. 529 | num_bases : int 530 | Number of bases. Basis decomposition is used only when num_bases < len(rel_names) and weight is True; otherwise each relation gets its own weight matrix. 531 | weight : bool, optional 532 | True if a linear layer is applied after message passing. Default: True 533 | bias : bool, optional 534 | True if bias is added. Default: True 535 | activation : callable, optional 536 | Activation function. Default: None 537 | self_loop : bool, optional 538 | True to include the self-loop message. Default: False 539 | dropout : float, optional 540 | Dropout rate.
Default: 0.0 541 | """ 542 | 543 | def __init__(self, 544 | in_feat, 545 | out_feat, 546 | rel_names, 547 | num_bases, 548 | *, 549 | weight=True, 550 | bias=True, 551 | activation=None, 552 | self_loop=False, 553 | dropout=0.0): 554 | super(RelGraphConvLayer, self).__init__() 555 | self.in_feat = in_feat 556 | self.out_feat = out_feat 557 | self.rel_names = rel_names 558 | self.num_bases = num_bases 559 | self.bias = bias 560 | self.activation = activation 561 | self.self_loop = self_loop 562 | 563 | self.conv = dglnn.HeteroGraphConv({ 564 | rel: dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False) 565 | for rel in rel_names 566 | }) 567 | 568 | self.use_weight = weight 569 | self.use_basis = num_bases < len(self.rel_names) and weight 570 | if self.use_weight: 571 | if self.use_basis: 572 | self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names)) 573 | else: 574 | self.weight = nn.Parameter(torch.Tensor(len(self.rel_names), in_feat, out_feat)) 575 | nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) 576 | 577 | # bias 578 | if bias: 579 | self.h_bias = nn.Parameter(torch.Tensor(out_feat)) 580 | nn.init.zeros_(self.h_bias) 581 | 582 | # weight for self loop 583 | if self.self_loop: 584 | self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat)) 585 | nn.init.xavier_uniform_(self.loop_weight, 586 | gain=nn.init.calculate_gain('relu')) 587 | 588 | self.dropout = nn.Dropout(dropout) 589 | 590 | def forward(self, g, inputs): 591 | """Forward computation 592 | Parameters 593 | ---------- 594 | g : DGLHeteroGraph 595 | Input graph. 596 | inputs : dict[str, torch.Tensor] 597 | Node feature for each node type. 598 | Returns 599 | ------- 600 | dict[str, torch.Tensor] 601 | New node features for each node type. 
602 | """ 603 | g = g.local_var() 604 | if self.use_weight: 605 | weight = self.basis() if self.use_basis else self.weight 606 | wdict = {self.rel_names[i]: {'weight': w.squeeze(0)} 607 | for i, w in enumerate(torch.split(weight, 1, dim=0))} 608 | else: 609 | wdict = {} 610 | hs = self.conv(g, inputs, mod_kwargs=wdict) 611 | 612 | def _apply(ntype, h): 613 | if self.self_loop: 614 | h = h + torch.matmul(inputs[ntype], self.loop_weight) 615 | if self.bias: 616 | h = h + self.h_bias 617 | if self.activation: 618 | h = self.activation(h) 619 | return self.dropout(h) 620 | 621 | return {ntype: _apply(ntype, h) for ntype, h in hs.items()} 622 | 623 | 624 | class RelEdgeLayer(nn.Module): 625 | def __init__(self, 626 | node_feat, 627 | edge_feat, 628 | activation, 629 | dropout=0.0): 630 | super(RelEdgeLayer, self).__init__() 631 | self.node_feat = node_feat 632 | self.edge_feat = edge_feat 633 | self.activation = activation 634 | self.dropout = nn.Dropout(dropout) 635 | self.mapping = nn.Linear(node_feat * 2, edge_feat) 636 | 637 | def forward(self, g, inputs): 638 | # g = g.local_var() # left disabled: the path-reasoning step later reads entity_graphs[i].edata['h'] written here 639 | 640 | g.ndata['h'] = inputs # [total_node_num, node_feat] 641 | g.apply_edges(lambda edges: { 642 | 'h': self.dropout(self.activation(self.mapping(torch.cat((edges.src['h'], edges.dst['h']), dim=-1))))}) 643 | g.ndata.pop('h') 644 | -------------------------------------------------------------------------------- /code/run_GAIN_BERT.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | export CUDA_VISIBLE_DEVICES=$1 3 | 4 | # -------------------GAIN_BERT_base Training Shell Script-------------------- 5 | 6 | if true; then 7 | model_name=GAIN_BERT_base 8 | lr=0.001 9 | batch_size=5 10 | test_batch_size=16 11 | epoch=300 12 | test_epoch=5 13 | log_step=20 14 | save_model_freq=3 15 | negativa_alpha=4 16 | 17 | nohup python3 -u train.py \ 18 | --train_set ../data/train_annotated.json \ 19 | --train_set_save ../data/prepro_data/train_BERT.pkl \ 20 | --dev_set ../data/dev.json \ 21 | --dev_set_save ../data/prepro_data/dev_BERT.pkl \ 22 | --test_set ../data/test.json \ 23 | --test_set_save ../data/prepro_data/test_BERT.pkl \ 24 | --use_model bert \ 25 | --model_name ${model_name} \ 26 | --lr ${lr} \ 27 | --batch_size ${batch_size} \ 28 | --test_batch_size ${test_batch_size} \ 29 | --epoch ${epoch} \ 30 | --test_epoch ${test_epoch} \ 31 | --log_step ${log_step} \ 32 | --save_model_freq ${save_model_freq} \ 33 | --negativa_alpha ${negativa_alpha} \ 34 | --gcn_dim 808 \ 35 | --gcn_layers 2 \ 36 | --bert_hid_size 768 \ 37 | --bert_path ../PLM/bert-base-uncased \ 38 | --use_entity_type \ 39 | --use_entity_id \ 40 | --dropout 0.6 \ 41 | --activation relu \ 42 | --coslr \ 43 | >logs/train_${model_name}.log 2>&1 & 44 | fi 45 | 46 | # -------------------GAIN_BERT_large Training Shell Script-------------------- 47 | 48 | if false; then 49 | model_name=GAIN_BERT_large 50 | lr=0.001 51 | batch_size=5 52 | test_batch_size=16 53 | epoch=300 54 | test_epoch=5 55 | log_step=20 56 | save_model_freq=3 57 | negativa_alpha=4 58 | 59 | nohup python3 -u train.py \ 60 | --train_set ../data/train_annotated.json \ 61 | --train_set_save ../data/prepro_data/train_BERT.pkl \ 62 | --dev_set ../data/dev.json \ 63 | --dev_set_save ../data/prepro_data/dev_BERT.pkl \ 64 | --test_set ../data/test.json \ 65 | --test_set_save ../data/prepro_data/test_BERT.pkl \ 66 | --use_model bert \ 67 | --model_name ${model_name} \ 68 | --lr ${lr} \ 69 | --batch_size ${batch_size} \ 70 | 
--test_batch_size ${test_batch_size} \ 71 | --epoch ${epoch} \ 72 | --test_epoch ${test_epoch} \ 73 | --log_step ${log_step} \ 74 | --save_model_freq ${save_model_freq} \ 75 | --negativa_alpha ${negativa_alpha} \ 76 | --gcn_dim 1064 \ 77 | --gcn_layers 2 \ 78 | --bert_hid_size 1024 \ 79 | --bert_path ../PLM/bert-large-uncased \ 80 | --use_entity_type \ 81 | --use_entity_id \ 82 | --dropout 0.6 \ 83 | --activation relu \ 84 | --coslr \ 85 | >logs/train_${model_name}.log 2>&1 & 86 | fi 87 | 88 | # -------------------additional options-------------------- 89 | 90 | # option below is used to resume training, it should be add into the shell scripts above 91 | # --pretrain_model checkpoint/GAIN_BERT_base_10.pt \ 92 | -------------------------------------------------------------------------------- /code/run_GAIN_GloVe.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export CUDA_VISIBLE_DEVICES=$1 3 | 4 | # -------------------GAIN_GloVe Training Shell Script-------------------- 5 | 6 | model_name=GAIN_GloVe 7 | lr=0.001 8 | batch_size=32 9 | test_batch_size=16 10 | epoch=300 11 | test_epoch=5 12 | log_step=20 13 | save_model_freq=3 14 | negativa_alpha=4 15 | weight_decay=0.0001 16 | 17 | nohup python3 -u train.py \ 18 | --train_set ../data/train_annotated.json \ 19 | --train_set_save ../data/prepro_data/train_GloVe.pkl \ 20 | --dev_set ../data/dev.json \ 21 | --dev_set_save ../data/prepro_data/dev_GloVe.pkl \ 22 | --test_set ../data/test.json \ 23 | --test_set_save ../data/prepro_data/test_GloVe.pkl \ 24 | --use_model bilstm \ 25 | --model_name ${model_name} \ 26 | --lr ${lr} \ 27 | --batch_size ${batch_size} \ 28 | --test_batch_size ${test_batch_size} \ 29 | --epoch ${epoch} \ 30 | --test_epoch ${test_epoch} \ 31 | --log_step ${log_step} \ 32 | --save_model_freq ${save_model_freq} \ 33 | --negativa_alpha ${negativa_alpha} \ 34 | --gcn_dim 512 \ 35 | --gcn_layers 2 \ 36 | --lstm_hidden_size 256 \ 37 | 
--use_entity_type \ 38 | --use_entity_id \ 39 | --word_emb_size 100 \ 40 | --finetune_word \ 41 | --pre_train_word \ 42 | --dropout 0.6 \ 43 | --activation relu \ 44 | --weight_decay ${weight_decay} \ 45 | >logs/train_${model_name}.log 2>&1 & 46 | 47 | # -------------------additional options-------------------- 48 | 49 | # option below is used to resume training, it should be add into the shell scripts above 50 | # --pretrain_model checkpoint/GAIN_GloVe_10.pt \ 51 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | import torch 3 | 4 | from config import * 5 | from data import DGLREDataset, DGLREDataloader, BERTDGLREDataset 6 | from models.GAIN import GAIN_GloVe, GAIN_BERT 7 | from utils import get_cuda, logging, print_params 8 | 9 | 10 | # for ablation 11 | # from models.GCNRE_nomention import GAIN_GloVe, GAIN_BERT 12 | 13 | 14 | def test(model, dataloader, modelname, id2rel, input_theta=-1, output=False, is_test=False, test_prefix='dev', 15 | relation_num=97, ours=False): 16 | # ours: inter-sentence F1 in LSR 17 | 18 | total_recall_ignore = 0 19 | 20 | test_result = [] 21 | total_recall = 0 22 | total_steps = len(dataloader) 23 | for cur_i, d in enumerate(dataloader): 24 | print('step: {}/{}'.format(cur_i, total_steps)) 25 | 26 | with torch.no_grad(): 27 | labels = d['labels'] 28 | L_vertex = d['L_vertex'] 29 | titles = d['titles'] 30 | indexes = d['indexes'] 31 | overlaps = d['overlaps'] 32 | 33 | predictions = model(words=d['context_idxs'], 34 | src_lengths=d['context_word_length'], 35 | mask=d['context_word_mask'], 36 | entity_type=d['context_ner'], 37 | entity_id=d['context_pos'], 38 | mention_id=d['context_mention'], 39 | distance=None, 40 | entity2mention_table=d['entity2mention_table'], 41 | graphs=d['graphs'], 42 | h_t_pairs=d['h_t_pairs'], 43 | relation_mask=None, 44 | path_table=d['path_table'], 45 
| entity_graphs=d['entity_graphs'], 46 | ht_pair_distance=d['ht_pair_distance'] 47 | ) 48 | 49 | predict_re = torch.sigmoid(predictions) 50 | 51 | predict_re = predict_re.data.cpu().numpy() 52 | 53 | for i in range(len(labels)): 54 | label = labels[i] 55 | L = L_vertex[i] 56 | title = titles[i] 57 | index = indexes[i] 58 | overlap = overlaps[i] 59 | total_recall += len(label) 60 | 61 | for l in label.values(): 62 | if not l: 63 | total_recall_ignore += 1 64 | 65 | j = 0 66 | 67 | for h_idx in range(L): 68 | for t_idx in range(L): 69 | if h_idx != t_idx: 70 | for r in range(1, relation_num): 71 | rel_ins = (h_idx, t_idx, r) 72 | intrain = label.get(rel_ins, False) 73 | 74 | if (ours and (h_idx, t_idx) in overlap) or not ours: 75 | test_result.append((rel_ins in label, float(predict_re[i, j, r]), intrain, 76 | title, id2rel[r], index, h_idx, t_idx, r)) 77 | 78 | j += 1 79 | 80 | test_result.sort(key=lambda x: x[1], reverse=True) 81 | 82 | if ours: 83 | total_recall = 0 84 | for item in test_result: 85 | if item[0]: 86 | total_recall += 1 87 | 88 | pr_x = [] 89 | pr_y = [] 90 | correct = 0 91 | w = 0 92 | 93 | if total_recall == 0: 94 | total_recall = 1 95 | 96 | for i, item in enumerate(test_result): 97 | correct += item[0] 98 | pr_y.append(float(correct) / (i + 1)) # Precision 99 | pr_x.append(float(correct) / total_recall) # Recall 100 | if item[1] > input_theta: 101 | w = i 102 | 103 | pr_x = np.asarray(pr_x, dtype='float32') 104 | pr_y = np.asarray(pr_y, dtype='float32') 105 | f1_arr = (2 * pr_x * pr_y / (pr_x + pr_y + 1e-20)) 106 | f1 = f1_arr.max() 107 | f1_pos = f1_arr.argmax() 108 | theta = test_result[f1_pos][1] 109 | 110 | if input_theta == -1: 111 | w = f1_pos 112 | input_theta = theta 113 | 114 | auc = sklearn.metrics.auc(x=pr_x, y=pr_y) 115 | if not is_test: 116 | logging('ALL : Theta {:3.4f} | F1 {:3.4f} | AUC {:3.4f}'.format(theta, f1, auc)) 117 | else: 118 | logging( 119 | 'ma_f1 {:3.4f} | input_theta {:3.4f} test_result P {:3.4f} test_result R 
{:3.4f} test_result F1 {:3.4f} | AUC {:3.4f}' \ 120 | .format(f1, input_theta, pr_y[w], pr_x[w], f1_arr[w], auc)) 121 | 122 | if output: 123 | # output = [x[-4:] for x in test_result[:w+1]] 124 | output = [{'index': x[-4], 'h_idx': x[-3], 't_idx': x[-2], 'r_idx': x[-1], 125 | 'score': x[1], 'intrain': x[2], 126 | 'r': x[-5], 'title': x[-6]} for x in test_result[:w + 1]] 127 | json.dump(output, open(test_prefix + "_index.json", "w")) 128 | 129 | pr_x = [] 130 | pr_y = [] 131 | correct = correct_in_train = 0 132 | w = 0 133 | 134 | # https://github.com/thunlp/DocRED/issues/47 135 | for i, item in enumerate(test_result): 136 | correct += item[0] 137 | if item[0] & item[2]: 138 | correct_in_train += 1 139 | if correct_in_train == correct: 140 | p = 0 141 | else: 142 | p = float(correct - correct_in_train) / (i + 1 - correct_in_train) 143 | pr_y.append(p) 144 | pr_x.append(float(correct) / total_recall) 145 | 146 | if item[1] > input_theta: 147 | w = i 148 | 149 | pr_x = np.asarray(pr_x, dtype='float32') 150 | pr_y = np.asarray(pr_y, dtype='float32') 151 | f1_arr = (2 * pr_x * pr_y / (pr_x + pr_y + 1e-20)) 152 | f1 = f1_arr.max() 153 | 154 | auc = sklearn.metrics.auc(x=pr_x, y=pr_y) 155 | 156 | logging( 157 | 'Ignore ma_f1 {:3.4f} | inhput_theta {:3.4f} test_result P {:3.4f} test_result R {:3.4f} test_result F1 {:3.4f} | AUC {:3.4f}' \ 158 | .format(f1, input_theta, pr_y[w], pr_x[w], f1_arr[w], auc)) 159 | 160 | return f1, auc, pr_x, pr_y 161 | 162 | 163 | if __name__ == '__main__': 164 | print('processId:', os.getpid()) 165 | print('prarent processId:', os.getppid()) 166 | opt = get_opt() 167 | print(json.dumps(opt.__dict__, indent=4)) 168 | opt.data_word_vec = word2vec 169 | 170 | if opt.use_model == 'bert': 171 | # datasets 172 | train_set = BERTDGLREDataset(opt.train_set, opt.train_set_save, word2id, ner2id, rel2id, dataset_type='train', 173 | opt=opt) 174 | test_set = BERTDGLREDataset(opt.test_set, opt.test_set_save, word2id, ner2id, rel2id, dataset_type='test', 
                                    instance_in_train=train_set.instance_in_train, opt=opt)

        test_loader = DGLREDataloader(test_set, batch_size=opt.test_batch_size, dataset_type='test')

        model = GAIN_BERT(opt)
    elif opt.use_model == 'bilstm':
        # datasets
        train_set = DGLREDataset(opt.train_set, opt.train_set_save, word2id, ner2id, rel2id, dataset_type='train',
                                 opt=opt)
        test_set = DGLREDataset(opt.test_set, opt.test_set_save, word2id, ner2id, rel2id, dataset_type='test',
                                instance_in_train=train_set.instance_in_train, opt=opt)

        test_loader = DGLREDataloader(test_set, batch_size=opt.test_batch_size, dataset_type='test')

        model = GAIN_GloVe(opt)
    else:
        raise ValueError('please choose a model from [bert, bilstm].')

    import gc

    del train_set
    gc.collect()

    # print(model.parameters)
    print_params(model)

    start_epoch = 1
    pretrain_model = opt.pretrain_model
    lr = opt.lr
    model_name = opt.model_name

    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoint'])
        logging('load checkpoint from {}'.format(pretrain_model))
    else:
        raise ValueError('please provide a checkpoint to evaluate.')

    model = get_cuda(model)
    model.eval()

    f1, auc, pr_x, pr_y = test(model, test_loader, model_name, id2rel=id2rel,
                               input_theta=opt.input_theta, output=True, test_prefix='test', is_test=True, ours=False)
    print('finished')
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
import json
import os
import time

import matplotlib

matplotlib.use('Agg')  # non-interactive backend; selected before pyplot is imported
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim

from config import *
from data import DGLREDataset, DGLREDataloader, BERTDGLREDataset
from models.GAIN import GAIN_GloVe, GAIN_BERT
from test import test
from utils import Accuracy, get_cuda, logging, print_params


# for ablation
# from models.GAIN_nomention import GAIN_GloVe, GAIN_BERT

def train(opt):
    if opt.use_model == 'bert':
        # datasets
        train_set = BERTDGLREDataset(opt.train_set, opt.train_set_save, word2id, ner2id, rel2id, dataset_type='train',
                                     opt=opt)
        dev_set = BERTDGLREDataset(opt.dev_set, opt.dev_set_save, word2id, ner2id, rel2id, dataset_type='dev',
                                   instance_in_train=train_set.instance_in_train, opt=opt)

        # dataloaders
        train_loader = DGLREDataloader(train_set, batch_size=opt.batch_size, shuffle=True,
                                       negativa_alpha=opt.negativa_alpha)
        dev_loader = DGLREDataloader(dev_set, batch_size=opt.test_batch_size, dataset_type='dev')

        model = GAIN_BERT(opt)

    elif opt.use_model == 'bilstm':
        # datasets
        train_set = DGLREDataset(opt.train_set, opt.train_set_save, word2id, ner2id, rel2id, dataset_type='train',
                                 opt=opt)
        dev_set = DGLREDataset(opt.dev_set, opt.dev_set_save, word2id, ner2id, rel2id, dataset_type='dev',
                               instance_in_train=train_set.instance_in_train, opt=opt)

        # dataloaders
        train_loader = DGLREDataloader(train_set, batch_size=opt.batch_size, shuffle=True,
                                       negativa_alpha=opt.negativa_alpha)
        dev_loader = DGLREDataloader(dev_set, batch_size=opt.test_batch_size, dataset_type='dev')

        model = GAIN_GloVe(opt)
    else:
        raise ValueError('please choose a model from [bert, bilstm].')

    print(model.parameters)
    print_params(model)

    start_epoch = 1
    pretrain_model = opt.pretrain_model
    lr = opt.lr
    model_name = opt.model_name

    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoint'])
        logging('load model from {}'.format(pretrain_model))
        start_epoch = chkpt['epoch'] + 1
        lr = chkpt['lr']
        logging('resume from epoch {} with lr {}'.format(start_epoch, lr))
    else:
        logging('training from scratch with lr {}'.format(lr))

    model = get_cuda(model)

    if opt.use_model == 'bert':
        # BERT parameters get 1% of the base learning rate; all other trainable
        # parameters use the base lr plus weight decay.
        bert_param_ids = list(map(id, model.bert.parameters()))
        base_params = filter(lambda p: p.requires_grad and id(p) not in bert_param_ids, model.parameters())

        optimizer = optim.AdamW([
            {'params': model.bert.parameters(), 'lr': lr * 0.01},
            {'params': base_params, 'weight_decay': opt.weight_decay}
        ], lr=lr)
    else:
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr,
                                weight_decay=opt.weight_decay)

    BCE = nn.BCEWithLogitsLoss(reduction='none')

    if opt.coslr:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(opt.epoch // 4) + 1)

    checkpoint_dir = opt.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    fig_result_dir = opt.fig_result_dir
    if not os.path.exists(fig_result_dir):
        os.mkdir(fig_result_dir)

    best_ign_auc = 0.0
    best_ign_f1 = 0.0
    best_epoch = 0

    model.train()

    global_step = 0
    total_loss = 0

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim(0.0, 1.0)
    plt.xlim(0.0, 1.0)
    plt.title('Precision-Recall')
    plt.grid(True)

    acc_NA, acc_not_NA, acc_total = Accuracy(), Accuracy(), Accuracy()
    logging('begin..')

    for epoch in range(start_epoch, opt.epoch + 1):
        start_time = time.time()
        for acc in [acc_NA, acc_not_NA, acc_total]:
            acc.clear()

        for ii, d in enumerate(train_loader):
            relation_multi_label = d['relation_multi_label']
            relation_mask = d['relation_mask']
            relation_label = d['relation_label']

            predictions = model(words=d['context_idxs'],
                                src_lengths=d['context_word_length'],
                                mask=d['context_word_mask'],
                                entity_type=d['context_ner'],
                                entity_id=d['context_pos'],
                                mention_id=d['context_mention'],
                                distance=None,
                                entity2mention_table=d['entity2mention_table'],
                                graphs=d['graphs'],
                                h_t_pairs=d['h_t_pairs'],
                                relation_mask=relation_mask,
                                path_table=d['path_table'],
                                entity_graphs=d['entity_graphs'],
                                ht_pair_distance=d['ht_pair_distance']
                                )
            # Multi-label BCE averaged over valid (masked-in) entity pairs and all relation classes.
            loss = torch.sum(BCE(predictions, relation_multi_label) * relation_mask.unsqueeze(2)) / (
                    opt.relation_nums * torch.sum(relation_mask))

            optimizer.zero_grad()
            loss.backward()

            if opt.clip != -1:
                nn.utils.clip_grad_value_(model.parameters(), opt.clip)
            optimizer.step()
            if opt.coslr:
                scheduler.step(epoch)

            output = torch.argmax(predictions, dim=-1)
            output = output.data.cpu().numpy()
            relation_label = relation_label.data.cpu().numpy()

            for i in range(output.shape[0]):
                for j in range(output.shape[1]):
                    label = relation_label[i][j]
                    if label < 0:
                        break

                    is_correct = (output[i][j] == label)
                    if label == 0:
                        acc_NA.add(is_correct)
                    else:
                        acc_not_NA.add(is_correct)

                    acc_total.add(is_correct)

            global_step += 1
            total_loss += loss.item()

            log_step = opt.log_step
            if global_step % log_step == 0:
                cur_loss = total_loss / log_step
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:2d} | step {:4d} | ms/b {:5.2f} | train loss {:5.3f} | NA acc: {:4.2f} | not NA acc: {:4.2f} | tot acc: {:4.2f} '.format(
                        epoch, global_step, elapsed * 1000 / log_step, cur_loss * 1000, acc_NA.get(), acc_not_NA.get(),
                        acc_total.get()))
                total_loss = 0
                start_time = time.time()

        if epoch % opt.test_epoch == 0:
            logging('-' * 89)
            eval_start_time = time.time()
            model.eval()
            ign_f1, ign_auc, pr_x, pr_y = test(model, dev_loader, model_name, id2rel=id2rel)
            model.train()
            logging('| epoch {:3d} | time: {:5.2f}s'.format(epoch, time.time() - eval_start_time))
            logging('-' * 89)

            if ign_f1 > best_ign_f1:
                best_ign_f1 = ign_f1
                best_ign_auc = ign_auc
                best_epoch = epoch
                path = os.path.join(checkpoint_dir, model_name + '_best.pt')
                torch.save({
                    'epoch': epoch,
                    'checkpoint': model.state_dict(),
                    'lr': lr,
                    'best_ign_f1': ign_f1,
                    'best_ign_auc': ign_auc,
                    'best_epoch': epoch
                }, path)

                plt.plot(pr_x, pr_y, lw=2, label=str(epoch))
                plt.legend(loc="upper right")
                plt.savefig(os.path.join(fig_result_dir, model_name))

        if epoch % opt.save_model_freq == 0:
            path = os.path.join(checkpoint_dir, model_name + '_{}.pt'.format(epoch))
            torch.save({
                'epoch': epoch,
                'lr': lr,
                'checkpoint': model.state_dict()
            }, path)

    print("Finish training")
    print("Best epoch = %d | Best Ign F1 = %f" % (best_epoch, best_ign_f1))
    print("Storing best result...")
    print("Finish storing")


if __name__ == '__main__':
    print('processId:', os.getpid())
    print('parent processId:', os.getppid())
    opt = get_opt()
    print(json.dumps(opt.__dict__, indent=4))
    opt.data_word_vec = word2vec
    train(opt)
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
from datetime import datetime

import numpy as np
import torch


def get_cuda(tensor):
    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor


def logging(s):
    print(datetime.now(), s)


class Accuracy(object):
    def __init__(self):
        self.correct = 0
        self.total = 0

    def add(self, is_correct):
        self.total += 1
        if is_correct:
            self.correct += 1

    def get(self):
        if self.total == 0:
            return 0.0
        else:
            return float(self.correct) / self.total

    def clear(self):
        self.correct = 0
        self.total = 0


def print_params(model):
    print('total parameters:', sum([np.prod(list(p.size())) for p in model.parameters() if p.requires_grad]))
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Data

```
Data Format:
{
  'title',
  'sents': [
    [word in sent 0],
    [word in sent 1]
  ],
  'vertexSet': [
    [
      { 'name': mention_name,
        'sent_id': mention in which sentence,
        'pos': position of mention in a sentence,
        'type': NER_type},
      {another mention}
    ],
    [another entity]
  ],
  'labels': [
    {
      'h': idx of head entity in vertexSet,
      't': idx of tail entity in vertexSet,
      'r': relation,
      'evidence': ids of evidence sentences
    }
  ]
}

```

Please submit the test set results to CodaLab.
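
The format above can be walked with a few lines of Python. The sketch below is illustrative only. The helper `iter_relation_facts` and the toy document are hypothetical, and the sketch assumes each document has already been parsed (e.g. with `json.load`) into a dict with exactly the keys listed above. It resolves each label's `h`/`t` indices through `vertexSet`, where every entity is a list of mentions, and uses the first mention's `name` as the entity's display name.

```python
from typing import Dict, Iterator, List, Tuple


def iter_relation_facts(doc: Dict) -> Iterator[Tuple[str, str, str, List[int]]]:
    """Yield (head_name, tail_name, relation, evidence) for each label (hypothetical helper)."""
    entities = doc['vertexSet']  # list of entities; each entity is a list of mention dicts
    # A split released without gold labels may lack 'labels', so default to an empty list.
    for lab in doc.get('labels', []):
        head = entities[lab['h']][0]['name']  # first mention's surface name
        tail = entities[lab['t']][0]['name']
        yield head, tail, lab['r'], lab.get('evidence', [])


# Toy document in the format above; all values are made up for illustration.
doc = {
    'title': 'Toy',
    'sents': [['Alice', 'works', 'at', 'Acme', '.'], ['She', 'lives', 'in', 'Paris', '.']],
    'vertexSet': [
        [{'name': 'Alice', 'sent_id': 0, 'pos': [0, 1], 'type': 'PER'},
         {'name': 'She', 'sent_id': 1, 'pos': [0, 1], 'type': 'PER'}],
        [{'name': 'Acme', 'sent_id': 0, 'pos': [3, 4], 'type': 'ORG'}],
    ],
    'labels': [{'h': 0, 't': 1, 'r': 'P108', 'evidence': [0]}],
}

for fact in iter_relation_facts(doc):
    print(fact)  # ('Alice', 'Acme', 'P108', [0])
```

Mentions are grouped per entity, so labels always point at entities, never at individual mentions. This matches the entity-level graph (EG) that GAIN reasons over.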
--------------------------------------------------------------------------------
/data/prepro_data/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/data/prepro_data/.placeholder
--------------------------------------------------------------------------------
/pictures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/pictures/model.png
--------------------------------------------------------------------------------
/pictures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamInvoker/GAIN/178344cf00789c7ba05cfe4dca90df4b17c2caa9/pictures/results.png
--------------------------------------------------------------------------------