├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── Config.py
│   └── __init__.py
├── draw_plot.py
├── gen_data.py
├── models
│   ├── CNN_ATT.py
│   ├── CNN_AVE.py
│   ├── CNN_ONE.py
│   ├── Model.py
│   ├── PCNN_ATT.py
│   ├── PCNN_AVE.py
│   ├── PCNN_ONE.py
│   └── __init__.py
├── networks
│   ├── __init__.py
│   ├── classifier.py
│   ├── embedding.py
│   ├── encoder.py
│   └── selector.py
├── test.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Shulin Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenNRE-PyTorch 2 | 3 | An open-source framework for neural relation extraction implemented in PyTorch. 4 | 5 | Contributed by [Shulin Cao](https://github.com/ShulinCao), [Tianyu Gao](https://github.com/gaotianyu1350), [Xu Han](https://github.com/THUCSTHanxu13), [Lumin Tang](https://github.com/Tsingularity), [Yankai Lin](https://github.com/Mrlyk423), [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/) 6 | 7 | ## Overview 8 | 9 | OpenNRE-PyTorch is a PyTorch-based framework for easily building relation extraction models. We divide the relation extraction pipeline into four parts: embedding, encoder, selector and classifier. For each part we have implemented several methods. 10 | 11 | * Embedding 12 | * Word embedding 13 | * Position embedding 14 | * Concatenation method 15 | * Encoder 16 | * PCNN 17 | * CNN 18 | * Selector 19 | * Attention 20 | * Maximum 21 | * Average 22 | * Classifier 23 | * Softmax loss function 24 | * Output 25 | 26 | All these methods can be combined freely. 27 | 28 | We also provide fast training and testing code. You can change hyper-parameters or specify model architectures through command-line arguments. A plotting script is also included. 29 | 30 | This project is under the MIT license. 31 | 32 | ## Requirements 33 | 34 | - Python (>=2.7) 35 | - PyTorch (==0.3.1) 36 | - CUDA (>=8.0) 37 | - Matplotlib (>=2.0.0) 38 | - scikit-learn (>=0.18) 39 | 40 | ## Installation 41 | 42 | 1. Install PyTorch 43 | 2. Clone the OpenNRE-PyTorch repository: 44 | ```bash 45 | git clone https://github.com/ShulinCao/OpenNRE-PyTorch 46 | ``` 47 | 3. Download the NYT dataset from [Google Drive](https://drive.google.com/file/d/1g95gbMUsGfeEmihZSb0kXPbMTuRA4lid/view?usp=sharing) 48 | 4. Extract the dataset to `./raw_data`: 49 | ```bash 50 | unzip raw_data.zip 51 | ``` 52 | ## Dataset 53 | 54 | ### NYT10 Dataset 55 | 56 | NYT10 is a distantly supervised dataset originally released with the paper "Modeling Relations and Their Mentions without Labeled Text" by Sebastian Riedel, Limin Yao, and Andrew McCallum. The original data can be downloaded [here](http://iesl.cs.umass.edu/riedel/ecml/). 57 | You can download the preprocessed NYT10 dataset from [Google Drive](https://drive.google.com/file/d/1g95gbMUsGfeEmihZSb0kXPbMTuRA4lid/view?usp=sharing). The data details are as follows. 58 | 59 | ### Training Data & Testing Data 60 | 61 | The training and testing data files, which contain sentences together with their corresponding entity pairs and relations, should be in the following format: 62 | 63 | ``` 64 | [ 65 | { 66 | 'sentence': 'Bill Gates is the founder of Microsoft .', 67 | 'head': {'word': 'Bill Gates', 'id': 'm.03_3d', ...(other information)}, 68 | 'tail': {'word': 'Microsoft', 'id': 'm.07dfk', ...(other information)}, 69 | 'relation': 'founder' 70 | }, 71 | ... 72 | ] 73 | ``` 74 | 75 | **IMPORTANT**: In the sentence, words and punctuation should be separated by blank spaces. 76 | 77 | ### Word Embedding Data 78 | 79 | Word embedding data is used to initialize the word embeddings in the networks, and should be in the following format: 80 | 81 | ``` 82 | [ 83 | {'word': 'the', 'vec': [0.418, 0.24968, ...]}, 84 | {'word': ',', 'vec': [0.013441, 0.23682, ...]}, 85 | ...
86 | ] 87 | ``` 88 | 89 | ### Relation-ID Mapping Data 90 | 91 | This file maps each relation to an ID, so that the same ID refers to the same relation across every training and testing run. Its format is as follows: 92 | 93 | ``` 94 | { 95 | 'NA': 0, 96 | 'relation_1': 1, 97 | 'relation_2': 2, 98 | ... 99 | } 100 | ``` 101 | 102 | **IMPORTANT**: Make sure the ID of `NA` is always 0. 103 | 104 | ## Quick Start 105 | 106 | ### Process Data 107 | 108 | ```bash 109 | python gen_data.py 110 | ``` 111 | The processed data will be stored in `./data`. 112 | 113 | ### Train Model 114 | ```bash 115 | python train.py --model_name pcnn_att 116 | ``` 117 | 118 | The argument `model_name` specifies the model architecture, and `pcnn_att` is the name of one of our models. All available models are in `./models`. For other arguments, please refer to `./train.py`. Once you start training, all checkpoints are stored in `./checkpoint`. 119 | 120 | ### Test Model 121 | ```bash 122 | python test.py --model_name pcnn_att 123 | ``` 124 | 125 | The usage is the same as for training. When testing finishes, the PR-curve data of the best checkpoint will be stored in `./test_result`. 126 | 127 | ### Plot 128 | ```bash 129 | python draw_plot.py PCNN_ATT 130 | ``` 131 | 132 | The plot will be saved as `./test_result/pr_curve.png`. You can specify several models as arguments, like `python draw_plot.py PCNN_ATT PCNN_ONE PCNN_AVE`, as long as these models' results are in `./test_result`. 133 | 134 | ## Build Your Own Model 135 | 136 | Besides training and testing the existing models in our package, you can also build your own model or add methods to the four basic modules. To add a new model, create a Python file in `./models` with the same name as the model and implement it as follows: 137 | 138 | ```python 139 | import torch 140 | import torch.autograd as autograd 141 | import torch.nn as nn 142 | import torch.nn.functional as F 143 | import torch.optim as optim 144 | from torch.autograd import Variable 145 | from networks.embedding import * 146 | from networks.encoder import * 147 | from networks.selector import * 148 | from networks.classifier import * 149 | from .Model import Model 150 | class PCNN_ATT(Model): 151 | def __init__(self, config): 152 | super(PCNN_ATT, self).__init__(config) 153 | self.encoder = PCNN(config) 154 | self.selector = Attention(config, config.hidden_size * 3) 155 | ``` 156 | 157 | Then register the model as sketched below, and you can train, test and plot!
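A minimal sketch of the registration step, assuming the class above lives in `./models/PCNN_ATT.py`. The `model` dictionary shown here is the one that already exists in `train.py` and `test.py`; the commented `my_model` entry is a hypothetical placeholder:

```python
# train.py / test.py look the --model_name string up in this dict; a new model
# only needs an export in models/__init__.py (from .MY_MODEL import MY_MODEL)
# and an entry alongside the built-in ones below.
import models

model = {
    'pcnn_att': models.PCNN_ATT,
    'pcnn_one': models.PCNN_ONE,
    'pcnn_ave': models.PCNN_AVE,
    'cnn_att': models.CNN_ATT,
    'cnn_one': models.CNN_ONE,
    'cnn_ave': models.CNN_AVE,
    # 'my_model': models.MY_MODEL,  # hypothetical entry for a new model
}
```

With the entry in place, `python train.py --model_name my_model` resolves the name to your class and hands it to `Config.set_train_model`.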
158 | 159 | -------------------------------------------------------------------------------- /config/Config.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | import numpy as np 7 | import os 8 | import time 9 | import datetime 10 | import json 11 | import sys 12 | import sklearn.metrics 13 | from tqdm import tqdm 14 | 15 | def to_var(x): 16 | return Variable(torch.from_numpy(x).cuda()) 17 | 18 | class Accuracy(object): 19 | def __init__(self): 20 | self.correct = 0 21 | self.total = 0 22 | def add(self, is_correct): 23 | self.total += 1 24 | if is_correct: 25 | self.correct += 1 26 | def get(self): 27 | if self.total == 0: 28 | return 0.0 29 | else: 30 | return float(self.correct) / self.total 31 | def clear(self): 32 | self.correct = 0 33 | self.total = 0 34 | 35 | class Config(object): 36 | def __init__(self): 37 | self.acc_NA = Accuracy() 38 | self.acc_not_NA = Accuracy() 39 | self.acc_total = Accuracy() 40 | self.data_path = './data' 41 | self.use_bag = True 42 | self.use_gpu = True 43 | self.is_training = True 44 | self.max_length = 120 45 | self.pos_num = 2 * self.max_length 46 | self.num_classes = 53 47 | self.hidden_size = 230 48 | self.pos_size = 5 49 | self.max_epoch = 15 50 | self.opt_method = 'SGD' 51 | self.optimizer = None 52 | self.learning_rate = 0.5 53 | self.weight_decay = 1e-5 54 | self.drop_prob = 0.5 55 | self.checkpoint_dir = './checkpoint' 56 | self.test_result_dir = './test_result' 57 | self.save_epoch = 1 58 | self.test_epoch = 1 59 | self.pretrain_model = None 60 | self.trainModel = None 61 | self.testModel = None 62 | self.batch_size = 160 63 | self.word_size = 50 64 | self.window_size = 3 65 | self.epoch_range = None 66 | def set_data_path(self, data_path): 67 | self.data_path = data_path 68 | def set_max_length(self, max_length): 69 | self.max_length = max_length 70 | self.pos_num = 2 * self.max_length 71 | def set_num_classes(self, num_classes): 72 | self.num_classes = num_classes 73 | def set_hidden_size(self, hidden_size): 74 | self.hidden_size = hidden_size 75 | def set_window_size(self, window_size): 76 | self.window_size = window_size 77 | def set_pos_size(self, pos_size): 78 | self.pos_size = pos_size 79 | def set_word_size(self, word_size): 80 | self.word_size = word_size 81 | def set_max_epoch(self, max_epoch): 82 | self.max_epoch = max_epoch 83 | def set_batch_size(self, batch_size): 84 | self.batch_size = batch_size 85 | def set_opt_method(self, opt_method): 86 | self.opt_method = opt_method 87 | def set_learning_rate(self, learning_rate): 88 | self.learning_rate = learning_rate 89 | def set_weight_decay(self, weight_decay): 90 | self.weight_decay = weight_decay 91 | def set_drop_prob(self, drop_prob): 92 | self.drop_prob = drop_prob 93 | def set_checkpoint_dir(self, checkpoint_dir): 94 | self.checkpoint_dir = checkpoint_dir 95 | def set_test_epoch(self, test_epoch): 96 | self.test_epoch = test_epoch 97 | def set_save_epoch(self, save_epoch): 98 | self.save_epoch = save_epoch 99 | def set_pretrain_model(self, pretrain_model): 100 | self.pretrain_model = pretrain_model 101 | def set_is_training(self, is_training): 102 | self.is_training = is_training 103 | def set_use_bag(self, use_bag): 104 | self.use_bag = use_bag 105 | def set_use_gpu(self, use_gpu): 106 | self.use_gpu = use_gpu 107 | def set_epoch_range(self, epoch_range): 108 | self.epoch_range = epoch_range 109 | 110 | def 
load_train_data(self): 111 | print("Reading training data...") 112 | self.data_word_vec = np.load(os.path.join(self.data_path, 'vec.npy')) 113 | self.data_train_word = np.load(os.path.join(self.data_path, 'train_word.npy')) 114 | self.data_train_pos1 = np.load(os.path.join(self.data_path, 'train_pos1.npy')) 115 | self.data_train_pos2 = np.load(os.path.join(self.data_path, 'train_pos2.npy')) 116 | self.data_train_mask = np.load(os.path.join(self.data_path, 'train_mask.npy')) 117 | if self.use_bag: 118 | self.data_query_label = np.load(os.path.join(self.data_path, 'train_ins_label.npy')) 119 | self.data_train_label = np.load(os.path.join(self.data_path, 'train_bag_label.npy')) 120 | self.data_train_scope = np.load(os.path.join(self.data_path, 'train_bag_scope.npy')) 121 | else: 122 | self.data_train_label = np.load(os.path.join(self.data_path, 'train_ins_label.npy')) 123 | self.data_train_scope = np.load(os.path.join(self.data_path, 'train_ins_scope.npy')) 124 | print("Finish reading") 125 | self.train_order = list(range(len(self.data_train_label))) 126 | self.train_batches = len(self.data_train_label) // self.batch_size  # floor division keeps the batch count an integer 127 | if len(self.data_train_label) % self.batch_size != 0: 128 | self.train_batches += 1 129 | 130 | def load_test_data(self): 131 | print("Reading testing data...") 132 | self.data_word_vec = np.load(os.path.join(self.data_path, 'vec.npy')) 133 | self.data_test_word = np.load(os.path.join(self.data_path, 'test_word.npy')) 134 | self.data_test_pos1 = np.load(os.path.join(self.data_path, 'test_pos1.npy')) 135 | self.data_test_pos2 = np.load(os.path.join(self.data_path, 'test_pos2.npy')) 136 | self.data_test_mask = np.load(os.path.join(self.data_path, 'test_mask.npy')) 137 | if self.use_bag: 138 | self.data_test_label = np.load(os.path.join(self.data_path, 'test_bag_label.npy')) 139 | self.data_test_scope = np.load(os.path.join(self.data_path, 'test_bag_scope.npy')) 140 | else: 141 | self.data_test_label = np.load(os.path.join(self.data_path, 'test_ins_label.npy')) 142 | self.data_test_scope = np.load(os.path.join(self.data_path, 'test_ins_scope.npy')) 143 | print("Finish reading") 144 | self.test_batches = len(self.data_test_label) // self.batch_size 145 | if len(self.data_test_label) % self.batch_size != 0: 146 | self.test_batches += 1 147 | 148 | self.total_recall = self.data_test_label[:, 1:].sum() 149 | 150 | def set_train_model(self, model): 151 | print("Initializing training model...") 152 | self.model = model 153 | self.trainModel = self.model(config = self) 154 | if self.pretrain_model is not None: 155 | self.trainModel.load_state_dict(torch.load(self.pretrain_model)) 156 | self.trainModel.cuda() 157 | if self.optimizer is not None: 158 | pass 159 | elif self.opt_method == "Adagrad" or self.opt_method == "adagrad": 160 | self.optimizer = optim.Adagrad(self.trainModel.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay)  # lr_decay is left at Adagrad's default of 0 161 | elif self.opt_method == "Adadelta" or self.opt_method == "adadelta": 162 | self.optimizer = optim.Adadelta(self.trainModel.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay) 163 | elif self.opt_method == "Adam" or self.opt_method == "adam": 164 | self.optimizer = optim.Adam(self.trainModel.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay) 165 | else: 166 | self.optimizer = optim.SGD(self.trainModel.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay) 167 | print("Finish initializing") 168 | 169 | def set_test_model(self, model): 170 |
print("Initializing test model...") 171 | self.model = model 172 | self.testModel = self.model(config = self) 173 | self.testModel.cuda() 174 | self.testModel.eval() 175 | print("Finish initializing") 176 | 177 | def get_train_batch(self, batch): 178 | input_scope = np.take(self.data_train_scope, self.train_order[batch * self.batch_size : (batch + 1) * self.batch_size], axis = 0) 179 | index = [] 180 | scope = [0] 181 | for num in input_scope: 182 | index = index + list(range(num[0], num[1] + 1)) 183 | scope.append(scope[len(scope) - 1] + num[1] - num[0] + 1) 184 | self.batch_word = self.data_train_word[index, :] 185 | self.batch_pos1 = self.data_train_pos1[index, :] 186 | self.batch_pos2 = self.data_train_pos2[index, :] 187 | self.batch_mask = self.data_train_mask[index, :] 188 | self.batch_label = np.take(self.data_train_label, self.train_order[batch * self.batch_size : (batch + 1) * self.batch_size], axis = 0) 189 | self.batch_attention_query = self.data_query_label[index] 190 | self.batch_scope = scope 191 | 192 | def get_test_batch(self, batch): 193 | input_scope = self.data_test_scope[batch * self.batch_size : (batch + 1) * self.batch_size] 194 | index = [] 195 | scope = [0] 196 | for num in input_scope: 197 | index = index + list(range(num[0], num[1] + 1)) 198 | scope.append(scope[len(scope) - 1] + num[1] - num[0] + 1) 199 | self.batch_word = self.data_test_word[index, :] 200 | self.batch_pos1 = self.data_test_pos1[index, :] 201 | self.batch_pos2 = self.data_test_pos2[index, :] 202 | self.batch_mask = self.data_test_mask[index, :] 203 | self.batch_scope = scope 204 | def train_one_step(self): 205 | self.trainModel.embedding.word = to_var(self.batch_word) 206 | self.trainModel.embedding.pos1 = to_var(self.batch_pos1) 207 | self.trainModel.embedding.pos2 = to_var(self.batch_pos2) 208 | self.trainModel.encoder.mask = to_var(self.batch_mask) 209 | self.trainModel.selector.scope = self.batch_scope 210 | self.trainModel.selector.attention_query = to_var(self.batch_attention_query) 211 | self.trainModel.selector.label = to_var(self.batch_label) 212 | self.trainModel.classifier.label = to_var(self.batch_label) 213 | self.optimizer.zero_grad() 214 | loss, _output = self.trainModel() 215 | loss.backward() 216 | self.optimizer.step() 217 | for i, prediction in enumerate(_output): 218 | if self.batch_label[i] == 0: 219 | self.acc_NA.add(prediction == self.batch_label[i]) 220 | else: 221 | self.acc_not_NA.add(prediction == self.batch_label[i]) 222 | self.acc_total.add(prediction == self.batch_label[i]) 223 | return loss.data[0] 224 | 225 | def test_one_step(self): 226 | self.testModel.embedding.word = to_var(self.batch_word) 227 | self.testModel.embedding.pos1 = to_var(self.batch_pos1) 228 | self.testModel.embedding.pos2 = to_var(self.batch_pos2) 229 | self.testModel.encoder.mask = to_var(self.batch_mask) 230 | self.testModel.selector.scope = self.batch_scope 231 | return self.testModel.test() 232 | 233 | def train(self): 234 | if not os.path.exists(self.checkpoint_dir): 235 | os.mkdir(self.checkpoint_dir) 236 | best_auc = 0.0 237 | best_p = None 238 | best_r = None 239 | best_epoch = 0 240 | for epoch in range(self.max_epoch): 241 | print('Epoch ' + str(epoch) + ' starts...') 242 | self.acc_NA.clear() 243 | self.acc_not_NA.clear() 244 | self.acc_total.clear() 245 | np.random.shuffle(self.train_order) 246 | for batch in range(self.train_batches): 247 | self.get_train_batch(batch) 248 | loss = self.train_one_step() 249 | time_str = datetime.datetime.now().isoformat() 250 | sys.stdout.write("epoch 
%d step %d time %s | loss: %f, NA accuracy: %f, not NA accuracy: %f, total accuracy: %f\r" % (epoch, batch, time_str, loss, self.acc_NA.get(), self.acc_not_NA.get(), self.acc_total.get())) 251 | sys.stdout.flush() 252 | if (epoch + 1) % self.save_epoch == 0: 253 | print('Epoch ' + str(epoch) + ' has finished') 254 | print('Saving model...') 255 | path = os.path.join(self.checkpoint_dir, self.model.__name__ + '-' + str(epoch)) 256 | torch.save(self.trainModel.state_dict(), path) 257 | print('Have saved model to ' + path) 258 | if (epoch + 1) % self.test_epoch == 0: 259 | self.testModel = self.trainModel 260 | auc, pr_x, pr_y = self.test_one_epoch() 261 | if auc > best_auc: 262 | best_auc = auc 263 | best_p = pr_x 264 | best_r = pr_y 265 | best_epoch = epoch 266 | print("Finish training") 267 | print("Best epoch = %d | auc = %f" % (best_epoch, best_auc)) 268 | print("Storing best result...") 269 | if not os.path.isdir(self.test_result_dir): 270 | os.mkdir(self.test_result_dir) 271 | np.save(os.path.join(self.test_result_dir, self.model.__name__ + '_x.npy'), best_p) 272 | np.save(os.path.join(self.test_result_dir, self.model.__name__ + '_y.npy'), best_r) 273 | print("Finish storing") 274 | def test_one_epoch(self): 275 | test_score = [] 276 | for batch in tqdm(range(self.test_batches)): 277 | self.get_test_batch(batch) 278 | batch_score = self.test_one_step() 279 | test_score = test_score + batch_score 280 | test_result = [] 281 | for i in range(len(test_score)): 282 | for j in range(1, len(test_score[i])): 283 | test_result.append([self.data_test_label[i][j], test_score[i][j]]) 284 | test_result = sorted(test_result, key = lambda x: x[1]) 285 | test_result = test_result[::-1] 286 | pr_x = [] 287 | pr_y = [] 288 | correct = 0 289 | for i, item in enumerate(test_result): 290 | correct += item[0] 291 | pr_y.append(float(correct) / (i + 1)) 292 | pr_x.append(float(correct) / self.total_recall) 293 | auc = sklearn.metrics.auc(x = pr_x, y = pr_y) 294 | print("auc: ", auc) 295 | return auc, pr_x, pr_y 296 | def test(self): 297 | best_epoch = None 298 | best_auc = 0.0 299 | best_p = None 300 | best_r = None 301 | for epoch in self.epoch_range: 302 | path = os.path.join(self.checkpoint_dir, self.model.__name__ + '-' + str(epoch)) 303 | if not os.path.exists(path): 304 | continue 305 | print("Start testing epoch %d" % (epoch)) 306 | self.testModel.load_state_dict(torch.load(path)) 307 | auc, p, r = self.test_one_epoch() 308 | if auc > best_auc: 309 | best_auc = auc 310 | best_epoch = epoch 311 | best_p = p 312 | best_r = r 313 | print("Finish testing epoch %d" % (epoch)) 314 | print("Best epoch = %d | auc = %f" % (best_epoch, best_auc)) 315 | print("Storing best result...") 316 | if not os.path.isdir(self.test_result_dir): 317 | os.mkdir(self.test_result_dir) 318 | np.save(os.path.join(self.test_result_dir, self.model.__name__ + '_x.npy'), best_p) 319 | np.save(os.path.join(self.test_result_dir, self.model.__name__ + '_y.npy'), best_r) 320 | print("Finish storing") 321 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .Config import Config 2 | -------------------------------------------------------------------------------- /draw_plot.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import 
sys 7 | import os 8 | 9 | result_dir = './test_result' 10 | 11 | def main(): 12 | models = sys.argv[1:] 13 | for model in models: 14 | x = np.load(os.path.join(result_dir, model + '_x.npy')) 15 | y = np.load(os.path.join(result_dir, model + '_y.npy')) 16 | f1 = (2 * x * y / (x + y + 1e-20)).max() 17 | auc = sklearn.metrics.auc(x = x, y = y) 18 | plt.plot(x, y, lw = 2, label = model) 19 | print(model + ' : ' + 'auc = ' + str(auc) + ' | ' + 'max F1 = ' + str(f1) + ' P@100: {} | P@200: {} | P@300: {} | Mean: {}'.format(y[100], y[200], y[300], (y[100] + y[200] + y[300]) / 3)) 20 | 21 | plt.xlabel('Recall') 22 | plt.ylabel('Precision') 23 | plt.ylim(0.3, 1.0) 24 | plt.xlim(0.0, 0.4) 25 | plt.title('Precision-Recall') 26 | plt.legend(loc = "upper right") 27 | plt.grid(True) 28 | plt.savefig(os.path.join(result_dir, 'pr_curve')) 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | 5 | in_path = "./raw_data/" 6 | out_path = "./data" 7 | case_sensitive = False 8 | if not os.path.exists('./data'): 9 | os.mkdir('./data') 10 | train_file_name = in_path + 'train.json' 11 | test_file_name = in_path + 'test.json' 12 | word_file_name = in_path + 'word_vec.json' 13 | rel_file_name = in_path + 'rel2id.json' 14 | 15 | def find_pos(sentence, head, tail): 16 | def find(sentence, entity): 17 | p = sentence.find(' ' + entity + ' ') 18 | if p == -1: 19 | if sentence[:len(entity) + 1] == entity + ' ': 20 | p = 0 21 | elif sentence[-len(entity) - 1:] == ' ' + entity: 22 | p = len(sentence) - len(entity) 23 | else: 24 | p = 0 25 | else: 26 | p += 1 27 | return p 28 | 29 | sentence = ' '.join(sentence.split()) 30 | p1 = find(sentence, head) 31 | p2 = find(sentence, tail) 32 | words = sentence.split() 33 | cur_pos = 0 34 | pos1 = -1 35 | pos2 = -1 36 | for i, word in enumerate(words): 37 | if cur_pos == p1: 38 | pos1 = i 39 | if cur_pos == p2: 40 | pos2 = i 41 | cur_pos += len(word) + 1 42 | return pos1, pos2 43 | 44 | def init(file_name, word_vec_file_name, rel2id_file_name, max_length = 120, case_sensitive = False, is_training = True): 45 | if file_name is None or not os.path.isfile(file_name): 46 | raise Exception("[ERROR] Data file doesn't exist") 47 | if word_vec_file_name is None or not os.path.isfile(word_vec_file_name): 48 | raise Exception("[ERROR] Word vector file doesn't exist") 49 | if rel2id_file_name is None or not os.path.isfile(rel2id_file_name): 50 | raise Exception("[ERROR] rel2id file doesn't exist") 51 | 52 | print("Loading data file...") 53 | ori_data = json.load(open(file_name, "r")) 54 | print("Finish loading") 55 | print("Loading word_vec file...") 56 | ori_word_vec = json.load(open(word_vec_file_name, "r")) 57 | print("Finish loading") 58 | print("Loading rel2id file...") 59 | rel2id = json.load(open(rel2id_file_name, "r")) 60 | print("Finish loading") 61 | 62 | if not case_sensitive: 63 | print("Eliminating case sensitive problem...") 64 | for i in ori_data: 65 | i['sentence'] = i['sentence'].lower() 66 | i['head']['word'] = i['head']['word'].lower() 67 | i['tail']['word'] = i['tail']['word'].lower() 68 | for i in ori_word_vec: 69 | i['word'] = i['word'].lower() 70 | print("Finish eliminating") 71 | 72 | # vec 73 | print("Building word vector matrix and mapping...") 74 | word2id = {} 75 | word_vec_mat = [] 76 | word_size = len(ori_word_vec[0]['vec']) 77 | 
print("Got {} words of {} dims".format(len(ori_word_vec), word_size)) 78 | for i in ori_word_vec: 79 | word2id[i['word']] = len(word2id) 80 | word_vec_mat.append(i['vec']) 81 | word2id['UNK'] = len(word2id) 82 | word2id['BLANK'] = len(word2id) 83 | word_vec_mat.append(np.random.normal(loc = 0, scale = 0.05, size = word_size)) 84 | word_vec_mat.append(np.zeros(word_size, dtype = np.float32)) 85 | word_vec_mat = np.array(word_vec_mat, dtype = np.float32) 86 | print("Finish building") 87 | 88 | # sorting 89 | print("Sorting data...") 90 | ori_data.sort(key = lambda a: a['head']['id'] + '#' + a['tail']['id'] + '#' + a['relation']) 91 | print("Finish sorting") 92 | 93 | sen_tot = len(ori_data) 94 | sen_word = np.zeros((sen_tot, max_length), dtype = np.int64) 95 | sen_pos1 = np.zeros((sen_tot, max_length), dtype = np.int64) 96 | sen_pos2 = np.zeros((sen_tot, max_length), dtype = np.int64) 97 | sen_mask = np.zeros((sen_tot, max_length, 3), dtype = np.float32) 98 | sen_label = np.zeros((sen_tot), dtype = np.int64) 99 | sen_len = np.zeros((sen_tot), dtype = np.int64) 100 | bag_label = [] 101 | bag_scope = [] 102 | bag_key = [] 103 | for i in range(len(ori_data)): 104 | if i%1000 == 0: 105 | print i 106 | sen = ori_data[i] 107 | # sen_label 108 | if sen['relation'] in rel2id: 109 | sen_label[i] = rel2id[sen['relation']] 110 | else: 111 | sen_label[i] = rel2id['NA'] 112 | words = sen['sentence'].split() 113 | # sen_len 114 | sen_len[i] = min(len(words), max_length) 115 | # sen_word 116 | for j, word in enumerate(words): 117 | if j < max_length: 118 | if word in word2id: 119 | sen_word[i][j] = word2id[word] 120 | else: 121 | sen_word[i][j] = word2id['UNK'] 122 | for j in range(j + 1, max_length): 123 | sen_word[i][j] = word2id['BLANK'] 124 | 125 | pos1, pos2 = find_pos(sen['sentence'], sen['head']['word'], sen['tail']['word']) 126 | if pos1 == -1 or pos2 == -1: 127 | raise Exception("[ERROR] Position error, index = {}, sentence = {}, head = {}, tail = {}".format(i, sen['sentence'], sen['head']['word'], sen['tail']['word'])) 128 | if pos1 >= max_length: 129 | pos1 = max_length - 1 130 | if pos2 >= max_length: 131 | pos2 = max_length - 1 132 | pos_min = min(pos1, pos2) 133 | pos_max = max(pos1, pos2) 134 | for j in range(max_length): 135 | # sen_pos1, sen_pos2 136 | sen_pos1[i][j] = j - pos1 + max_length 137 | sen_pos2[i][j] = j - pos2 + max_length 138 | # sen_mask 139 | if j >= sen_len[i]: 140 | sen_mask[i][j] = [0, 0, 0] 141 | elif j - pos_min <= 0: 142 | sen_mask[i][j] = [100, 0, 0] 143 | elif j - pos_max <= 0: 144 | sen_mask[i][j] = [0, 100, 0] 145 | else: 146 | sen_mask[i][j] = [0, 0, 100] 147 | # bag_scope 148 | if is_training: 149 | tup = (sen['head']['id'], sen['tail']['id'], sen['relation']) 150 | else: 151 | tup = (sen['head']['id'], sen['tail']['id']) 152 | if bag_key == [] or bag_key[len(bag_key) - 1] != tup: 153 | bag_key.append(tup) 154 | bag_scope.append([i, i]) 155 | bag_scope[len(bag_scope) - 1][1] = i 156 | 157 | print("Processing bag label...") 158 | # bag_label 159 | if is_training: 160 | for i in bag_scope: 161 | bag_label.append(sen_label[i[0]]) 162 | else: 163 | for i in bag_scope: 164 | multi_hot = np.zeros(len(rel2id), dtype = np.int64) 165 | for j in range(i[0], i[1]+1): 166 | multi_hot[sen_label[j]] = 1 167 | bag_label.append(multi_hot) 168 | print("Finish processing") 169 | # ins_scope 170 | ins_scope = np.stack([list(range(len(ori_data))), list(range(len(ori_data)))], axis = 1) 171 | print("Processing instance label...") 172 | # ins_label 173 | if is_training: 174 | 
ins_label = sen_label 175 | else: 176 | ins_label = [] 177 | for i in sen_label: 178 | one_hot = np.zeros(len(rel2id), dtype = np.int64) 179 | one_hot[i] = 1 180 | ins_label.append(one_hot) 181 | ins_label = np.array(ins_label, dtype = np.int64) 182 | print("Finishing processing") 183 | bag_scope = np.array(bag_scope, dtype = np.int64) 184 | bag_label = np.array(bag_label, dtype = np.int64) 185 | ins_scope = np.array(ins_scope, dtype = np.int64) 186 | ins_label = np.array(ins_label, dtype = np.int64) 187 | 188 | # saving 189 | print("Saving files") 190 | if is_training: 191 | name_prefix = "train" 192 | else: 193 | name_prefix = "test" 194 | np.save(os.path.join(out_path, 'vec.npy'), word_vec_mat) 195 | np.save(os.path.join(out_path, name_prefix + '_word.npy'), sen_word) 196 | np.save(os.path.join(out_path, name_prefix + '_pos1.npy'), sen_pos1) 197 | np.save(os.path.join(out_path, name_prefix + '_pos2.npy'), sen_pos2) 198 | np.save(os.path.join(out_path, name_prefix + '_mask.npy'), sen_mask) 199 | np.save(os.path.join(out_path, name_prefix + '_bag_label.npy'), bag_label) 200 | np.save(os.path.join(out_path, name_prefix + '_bag_scope.npy'), bag_scope) 201 | np.save(os.path.join(out_path, name_prefix + '_ins_label.npy'), ins_label) 202 | np.save(os.path.join(out_path, name_prefix + '_ins_scope.npy'), ins_scope) 203 | print("Finish saving") 204 | 205 | init(train_file_name, word_file_name, rel_file_name, max_length = 120, case_sensitive = False, is_training = True) 206 | init(test_file_name, word_file_name, rel_file_name, max_length = 120, case_sensitive = False, is_training = False) 207 | -------------------------------------------------------------------------------- /models/CNN_ATT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class CNN_ATT(Model): 14 | def __init__(self, config): 15 | super(CNN_ATT, self).__init__(config) 16 | self.encoder = CNN(config) 17 | self.selector = Attention(config, config.hidden_size) 18 | -------------------------------------------------------------------------------- /models/CNN_AVE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class CNN_AVE(Model): 14 | def __init__(self, config): 15 | super(CNN_AVE, self).__init__(config) 16 | self.encoder = CNN(config) 17 | self.selector = Average(config, config.hidden_size) 18 | -------------------------------------------------------------------------------- /models/CNN_ONE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder 
import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class CNN_ONE(Model): 14 | def __init__(self, config): 15 | super(CNN_ONE, self).__init__(config) 16 | self.encoder = CNN(config) 17 | self.selector = One(config, config.hidden_size) 18 | -------------------------------------------------------------------------------- /models/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | 12 | class Model(nn.Module): 13 | def __init__(self, config): 14 | super(Model, self).__init__() 15 | self.config = config 16 | self.embedding = Embedding(config) 17 | self.encoder = None 18 | self.selector = None 19 | self.classifier = Classifier(config) 20 | def forward(self): 21 | embedding = self.embedding() 22 | sen_embedding = self.encoder(embedding) 23 | logits = self.selector(sen_embedding) 24 | return self.classifier(logits) 25 | def test(self): 26 | embedding = self.embedding() 27 | sen_embedding = self.encoder(embedding) 28 | return self.selector.test(sen_embedding) 29 | -------------------------------------------------------------------------------- /models/PCNN_ATT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class PCNN_ATT(Model): 14 | def __init__(self, config): 15 | super(PCNN_ATT, self).__init__(config) 16 | self.encoder = PCNN(config) 17 | self.selector = Attention(config, config.hidden_size * 3) 18 | -------------------------------------------------------------------------------- /models/PCNN_AVE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class PCNN_AVE(Model): 14 | def __init__(self, config): 15 | super(PCNN_AVE, self).__init__(config) 16 | self.encoder = PCNN(config) 17 | self.selector = Average(config, config.hidden_size * 3) 18 | -------------------------------------------------------------------------------- /models/PCNN_ONE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from networks.embedding import * 8 | from networks.encoder import * 9 | from networks.selector import * 10 | from networks.classifier import * 11 | from .Model import Model 12 | 13 | class PCNN_ONE(Model): 14 | def __init__(self, config): 
15 | super(PCNN_ONE, self).__init__(config) 16 | self.encoder = PCNN(config) 17 | self.selector = One(config, config.hidden_size * 3) 18 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .Model import Model 2 | from .CNN_ATT import CNN_ATT 3 | from .PCNN_ATT import PCNN_ATT 4 | from .CNN_AVE import CNN_AVE 5 | from .PCNN_AVE import PCNN_AVE 6 | from .CNN_ONE import CNN_ONE 7 | from .PCNN_ONE import PCNN_ONE 8 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import * 2 | from .encoder import * 3 | from .selector import * 4 | from .classifier import * 5 | -------------------------------------------------------------------------------- /networks/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | 8 | class Classifier(nn.Module): 9 | def __init__(self, config): 10 | super(Classifier, self).__init__() 11 | self.config = config 12 | self.label = None 13 | self.loss = nn.CrossEntropyLoss() 14 | def forward(self, logits): 15 | loss = self.loss(logits, self.label) 16 | _, output = torch.max(logits, dim = 1) 17 | return loss, output.data 18 | -------------------------------------------------------------------------------- /networks/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | 8 | class Embedding(nn.Module): 9 | def __init__(self, config): 10 | super(Embedding, self).__init__() 11 | self.config = config 12 | self.word_embedding = nn.Embedding(self.config.data_word_vec.shape[0], self.config.data_word_vec.shape[1]) 13 | self.pos1_embedding = nn.Embedding(self.config.pos_num, self.config.pos_size, padding_idx = 0) 14 | self.pos2_embedding = nn.Embedding(self.config.pos_num, self.config.pos_size, padding_idx = 0) 15 | self.init_word_weights() 16 | self.init_pos_weights() 17 | self.word = None 18 | self.pos1 = None 19 | self.pos2 = None 20 | 21 | def init_word_weights(self): 22 | self.word_embedding.weight.data.copy_(torch.from_numpy(self.config.data_word_vec)) 23 | 24 | def init_pos_weights(self): 25 | nn.init.xavier_uniform(self.pos1_embedding.weight.data) 26 | if self.pos1_embedding.padding_idx is not None: 27 | self.pos1_embedding.weight.data[self.pos1_embedding.padding_idx].fill_(0) 28 | nn.init.xavier_uniform(self.pos2_embedding.weight.data) 29 | if self.pos2_embedding.padding_idx is not None: 30 | self.pos2_embedding.weight.data[self.pos2_embedding.padding_idx].fill_(0) 31 | def forward(self): 32 | word = self.word_embedding(self.word) 33 | pos1 = self.pos1_embedding(self.pos1) 34 | pos2 = self.pos2_embedding(self.pos2) 35 | embedding = torch.cat((word, pos1, pos2), dim = 2) 36 | return embedding 37 | -------------------------------------------------------------------------------- /networks/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as 
autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | 8 | class _CNN(nn.Module): 9 | def __init__(self, config): 10 | super(_CNN, self).__init__() 11 | self.config = config 12 | self.in_channels = 1 13 | self.in_height = self.config.max_length 14 | self.in_width = self.config.word_size + 2 * self.config.pos_size 15 | self.kernel_size = (self.config.window_size, self.in_width) 16 | self.out_channels = self.config.hidden_size 17 | self.stride = (1, 1) 18 | self.padding = (1, 0) 19 | self.cnn = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding) 20 | def forward(self, embedding): 21 | return self.cnn(embedding) 22 | 23 | class _PiecewisePooling(nn.Module): 24 | def __init__(self): 25 | super(_PiecewisePooling, self).__init__() 26 | def forward(self, x, mask, hidden_size): 27 | mask = torch.unsqueeze(mask, 1) 28 | x, _ = torch.max(mask + x, dim = 2)  # the broadcast add raises each token's features by 100 in its own piece, so this max pools over each piece separately 29 | x = x - 100  # remove the offset 30 | return x.view(-1, hidden_size * 3) 31 | 32 | class _MaxPooling(nn.Module): 33 | def __init__(self): 34 | super(_MaxPooling, self).__init__() 35 | def forward(self, x, hidden_size): 36 | x, _ = torch.max(x, dim = 2) 37 | return x.view(-1, hidden_size) 38 | 39 | class PCNN(nn.Module): 40 | def __init__(self, config): 41 | super(PCNN, self).__init__() 42 | self.config = config 43 | self.mask = None 44 | self.cnn = _CNN(config) 45 | self.pooling = _PiecewisePooling() 46 | self.activation = nn.ReLU() 47 | def forward(self, embedding): 48 | embedding = torch.unsqueeze(embedding, dim = 1) 49 | x = self.cnn(embedding) 50 | x = self.pooling(x, self.mask, self.config.hidden_size) 51 | return self.activation(x) 52 | 53 | class CNN(nn.Module): 54 | def __init__(self, config): 55 | super(CNN, self).__init__() 56 | self.config = config 57 | self.cnn = _CNN(config) 58 | self.pooling = _MaxPooling() 59 | self.activation = nn.ReLU() 60 | def forward(self, embedding): 61 | embedding = torch.unsqueeze(embedding, dim = 1) 62 | x = self.cnn(embedding) 63 | x = self.pooling(x, self.config.hidden_size) 64 | return self.activation(x) 65 | -------------------------------------------------------------------------------- /networks/selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | 8 | class Selector(nn.Module): 9 | def __init__(self, config, relation_dim): 10 | super(Selector, self).__init__() 11 | self.config = config 12 | self.relation_matrix = nn.Embedding(self.config.num_classes, relation_dim) 13 | self.bias = nn.Parameter(torch.Tensor(self.config.num_classes)) 14 | self.attention_matrix = nn.Embedding(self.config.num_classes, relation_dim) 15 | self.init_weights() 16 | self.scope = None 17 | self.attention_query = None 18 | self.label = None 19 | self.dropout = nn.Dropout(self.config.drop_prob) 20 | def init_weights(self): 21 | nn.init.xavier_uniform(self.relation_matrix.weight.data) 22 | nn.init.normal(self.bias) 23 | nn.init.xavier_uniform(self.attention_matrix.weight.data) 24 | def get_logits(self, x): 25 | logits = torch.matmul(x, torch.transpose(self.relation_matrix.weight, 0, 1)) + self.bias 26 | return logits 27 | def forward(self, x): 28 | raise NotImplementedError 29 | def test(self, x): 30 | raise NotImplementedError 31 | 32 | class Attention(Selector): 33 | def
_attention_train_logit(self, x): 34 | relation_query = self.relation_matrix(self.attention_query) 35 | attention = self.attention_matrix(self.attention_query) 36 | attention_logit = torch.sum(x * attention * relation_query, 1, True) 37 | return attention_logit 38 | def _attention_test_logit(self, x): 39 | attention_logit = torch.matmul(x, torch.transpose(self.attention_matrix.weight * self.relation_matrix.weight, 0, 1)) 40 | return attention_logit 41 | def forward(self, x): 42 | attention_logit = self._attention_train_logit(x) 43 | tower_repre = [] 44 | for i in range(len(self.scope) - 1): 45 | sen_matrix = x[self.scope[i] : self.scope[i + 1]] 46 | attention_score = F.softmax(torch.transpose(attention_logit[self.scope[i] : self.scope[i + 1]], 0, 1), 1) 47 | final_repre = torch.squeeze(torch.matmul(attention_score, sen_matrix)) 48 | tower_repre.append(final_repre) 49 | stack_repre = torch.stack(tower_repre) 50 | stack_repre = self.dropout(stack_repre) 51 | logits = self.get_logits(stack_repre) 52 | return logits 53 | def test(self, x): 54 | attention_logit = self._attention_test_logit(x) 55 | tower_output = [] 56 | for i in range(len(self.scope) - 1): 57 | sen_matrix = x[self.scope[i] : self.scope[i + 1]] 58 | attention_score = F.softmax(torch.transpose(attention_logit[self.scope[i] : self.scope[i + 1]], 0, 1), 1) 59 | final_repre = torch.matmul(attention_score, sen_matrix) 60 | logits = self.get_logits(final_repre) 61 | tower_output.append(torch.diag(F.softmax(logits, 1))) 62 | stack_output = torch.stack(tower_output) 63 | return list(stack_output.data.cpu().numpy()) 64 | 65 | class One(Selector): 66 | def forward(self, x): 67 | tower_logits = [] 68 | for i in range(len(self.scope) - 1): 69 | sen_matrix = x[self.scope[i] : self.scope[i + 1]] 70 | sen_matrix = self.dropout(sen_matrix) 71 | logits = self.get_logits(sen_matrix) 72 | score = F.softmax(logits, 1) 73 | _, k = torch.max(score, dim = 0) 74 | k = k[self.label[i]] 75 | tower_logits.append(logits[k]) 76 | return torch.cat(tower_logits, 0) 77 | def test(self, x): 78 | tower_score = [] 79 | for i in range(len(self.scope) - 1): 80 | sen_matrix = x[self.scope[i] : self.scope[i + 1]] 81 | logits = self.get_logits(sen_matrix) 82 | score = F.softmax(logits, 1) 83 | score, _ = torch.max(score, 0) 84 | tower_score.append(score) 85 | tower_score = torch.stack(tower_score) 86 | return list(tower_score.data.cpu().numpy()) 87 | 88 | class Average(Selector): 89 | def forward(self, x): 90 | tower_repre = [] 91 | for i in range(len(self.scope) - 1): 92 | sen_matrix = x[self.scope[i] : self.scope[i+ 1]] 93 | final_repre = torch.mean(sen_matrix, 0) 94 | tower_repre.append(final_repre) 95 | stack_repre = torch.stack(tower_repre) 96 | stack_repre = self.dropout(stack_repre) 97 | logits = self.get_logits(stack_repre) 98 | return logits 99 | def test(self, x): 100 | tower_repre = [] 101 | for i in range(len(self.scope) - 1): 102 | sen_matrix = x[self.scope[i] : self.scope[i + 1]] 103 | final_repre = torch.mean(sen_matrix, 0) 104 | tower_repre.append(final_repre) 105 | stack_repre = torch.stack(tower_repre) 106 | logits = self.get_logits(stack_repre) 107 | score = F.softmax(logits, 1) 108 | return list(score.data.cpu().numpy()) 109 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import config 2 | import models 3 | import numpy as np 4 | import os 5 | import time 6 | import datetime 7 | import json 8 | from sklearn.metrics import 
average_precision_score 9 | import sys 10 | import os 11 | import argparse 12 | 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--model_name', type = str, default = 'pcnn_att', help = 'name of the model') 16 | args = parser.parse_args() 17 | model = { 18 | 'pcnn_att': models.PCNN_ATT, 19 | 'pcnn_one': models.PCNN_ONE, 20 | 'pcnn_ave': models.PCNN_AVE, 21 | 'cnn_att': models.CNN_ATT, 22 | 'cnn_one': models.CNN_ONE, 23 | 'cnn_ave': models.CNN_AVE 24 | } 25 | con = config.Config() 26 | con.set_max_epoch(15) 27 | con.load_test_data() 28 | con.set_test_model(model[args.model_name]) 29 | con.set_epoch_range([7,12]) 30 | con.test() 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import config 2 | import models 3 | import numpy as np 4 | import os 5 | import time 6 | import datetime 7 | import json 8 | from sklearn.metrics import average_precision_score 9 | import sys 10 | import os 11 | import argparse 12 | 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--model_name', type = str, default = 'pcnn_att', help = 'name of the model') 16 | args = parser.parse_args() 17 | model = { 18 | 'pcnn_att': models.PCNN_ATT, 19 | 'pcnn_one': models.PCNN_ONE, 20 | 'pcnn_ave': models.PCNN_AVE, 21 | 'cnn_att': models.CNN_ATT, 22 | 'cnn_one': models.CNN_ONE, 23 | 'cnn_ave': models.CNN_AVE 24 | } 25 | con = config.Config() 26 | con.set_max_epoch(15) 27 | con.load_train_data() 28 | con.load_test_data() 29 | con.set_train_model(model[args.model_name]) 30 | con.train() 31 | --------------------------------------------------------------------------------
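Both `train.py` and `test.py` rely on the defaults set in `config/Config.py`; every hyper-parameter can be overridden through the `set_*` methods before data loading and training start. Below is a minimal sketch of a customized training script, assuming the same layout as `train.py` — all setters shown are defined in `config/Config.py`, and the chosen values are illustrative only:

```python
import config
import models

con = config.Config()
con.set_max_epoch(20)           # default is 15
con.set_batch_size(160)         # matches the default bag-level batch size
con.set_opt_method('Adam')      # default optimizer is SGD
con.set_learning_rate(0.001)    # Adam usually wants a smaller lr than SGD's default 0.5
con.set_checkpoint_dir('./checkpoint')
con.load_train_data()
con.load_test_data()            # test data is needed for the per-epoch evaluation
con.set_train_model(models.PCNN_ATT)
con.train()
```

Note that `Config.train` evaluates on the test set every `test_epoch` epochs and stores the best checkpoint's PR-curve data in `./test_result`, so `draw_plot.py` can be run directly afterwards.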