├── network ├── __init__.py ├── __pycache__ │ ├── model.cpython-35.pyc │ ├── utils.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ ├── encoder.cpython-35.pyc │ └── selector.cpython-35.pyc ├── selector.py ├── utils.py ├── encoder.py └── model.py ├── data ├── ddi │ ├── label2id.json │ ├── data_prepare.py │ └── config.py └── dti │ ├── transform.py │ ├── data_prepare.py │ └── config.py ├── plot_pr.py ├── README.md ├── test_ddi.py ├── test_dti.py ├── predict.py ├── train_ddi.py ├── train_dti.py ├── visualize.py ├── dataset.py └── LICENSE /network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/ddi/label2id.json: -------------------------------------------------------------------------------- 1 | { 2 | "NA": 0, 3 | "advise": 1, 4 | "effect": 2, 5 | "mechanism": 3, 6 | "int": 4 7 | } -------------------------------------------------------------------------------- /network/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haiya1994/BERE/HEAD/network/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /network/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haiya1994/BERE/HEAD/network/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haiya1994/BERE/HEAD/network/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /network/__pycache__/encoder.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/haiya1994/BERE/HEAD/network/__pycache__/encoder.cpython-35.pyc -------------------------------------------------------------------------------- /network/__pycache__/selector.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haiya1994/BERE/HEAD/network/__pycache__/selector.cpython-35.pyc -------------------------------------------------------------------------------- /data/dti/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Transforming (for prediction): 3 | 1. Load the statistic file (vocab.pt). 4 | 2. Replace each word with an ID. 5 | 3. Transform the file with '.json' format into '.pt' format. 6 | """ 7 | 8 | import sys 9 | 10 | import config 11 | 12 | from dataset import * 13 | 14 | sys.path.append("../..") 15 | 16 | logging.basicConfig(level=logging.INFO, 17 | format='%(asctime)s %(levelname)-8s %(message)s') 18 | 19 | vocab = torch.load('vocab.pt') 20 | logging.info('Number of classes: {}'.format(vocab.class_num)) 21 | 22 | if config.BAG_MODE: 23 | DatasetClass = REDataset_BAG 24 | else: 25 | DatasetClass = REDataset_INS 26 | 27 | 28 | def dump_dataset(data_name): 29 | dataset = DatasetClass(vocab, data_dir='.', data_name=data_name + '.json', max_length=config.MAX_LENGTH) 30 | torch.save(dataset, data_name + '.pt') 31 | 32 | 33 | dump_dataset('pmc_nintedanib') 34 | -------------------------------------------------------------------------------- /data/ddi/data_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data preprocessing (For training and test): 3 | 1. Load the pretrained word vectors. 4 | 2. Replace each word with an ID. 5 | 3. Count the basic statistics of the input data. 6 | 4. Transform the file with '.json' format into '.pt' format. 
7 | """ 8 | 9 | 10 | import sys 11 | 12 | import config 13 | 14 | from dataset import * 15 | 16 | sys.path.append("../..") 17 | 18 | logging.basicConfig(level=logging.INFO, 19 | format='%(asctime)s %(levelname)-8s %(message)s') 20 | 21 | vocab = Vocab(label_path='label2id.json', emb_path='../PubMed-and-PMC-w2v.bin') 22 | 23 | logging.info('Number of classes: {}'.format(vocab.class_num)) 24 | 25 | if config.BAG_MODE: 26 | DatasetClass = REDataset_BAG 27 | else: 28 | DatasetClass = REDataset_INS 29 | 30 | 31 | def dump_dataset(data_name): 32 | dataset = DatasetClass(vocab, data_dir='.', data_name=data_name + '.json', max_length=config.MAX_LENGTH) 33 | torch.save(dataset, data_name + '.pt') 34 | 35 | 36 | dump_dataset('train') 37 | dump_dataset('valid') 38 | dump_dataset('test') 39 | 40 | vocab.post_process() 41 | logging.info('Used pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 42 | torch.save(vocab, 'vocab.pt') 43 | -------------------------------------------------------------------------------- /data/dti/data_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data preprocessing (For training and test): 3 | 1. Load the pretrained word vectors. 4 | 2. Replace each word with an ID. 5 | 3. Count the basic statistics of the input data. 6 | 4. Transform the file with '.json' format into '.pt' format. 
7 | """ 8 | 9 | 10 | import sys 11 | 12 | import config 13 | 14 | from dataset import * 15 | 16 | sys.path.append("../..") 17 | 18 | logging.basicConfig(level=logging.INFO, 19 | format='%(asctime)s %(levelname)-8s %(message)s') 20 | 21 | vocab = Vocab(label_path='label2id.json', emb_path='../PubMed-and-PMC-w2v.bin') 22 | 23 | logging.info('Number of classes: {}'.format(vocab.class_num)) 24 | 25 | if config.BAG_MODE: 26 | DatasetClass = REDataset_BAG 27 | else: 28 | DatasetClass = REDataset_INS 29 | 30 | 31 | def dump_dataset(data_name): 32 | dataset = DatasetClass(vocab, data_dir='.', data_name=data_name + '.json', max_length=config.MAX_LENGTH) 33 | torch.save(dataset, data_name + '.pt') 34 | 35 | 36 | dump_dataset('train') 37 | dump_dataset('valid') 38 | dump_dataset('test') 39 | 40 | vocab.post_process() 41 | logging.info('Used pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 42 | torch.save(vocab, 'vocab.pt') 43 | -------------------------------------------------------------------------------- /network/selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | class BagAttention(nn.Module): 8 | def __init__(self, in_dim): 9 | super(BagAttention, self).__init__() 10 | self.scale = in_dim ** -0.5 11 | 12 | self.attn_w = nn.Parameter(torch.FloatTensor(in_dim)) 13 | 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | init.normal_(self.attn_w.data, mean=0, std=0.01) 18 | 19 | def forward(self, x, scope): 20 | attn = (self.attn_w * x).sum(-1) 21 | attn = self.scale * attn # B 22 | 23 | bag_logits = [] 24 | bag_attns = [] 25 | start_offset = 0 26 | for i in range(len(scope)): 27 | end_offset = scope[i] 28 | bag_x = x[start_offset:end_offset] # n*H 29 | 30 | bag_attn = F.softmax(attn[start_offset:end_offset], -1) # n 31 | 32 | bag_attns.append(bag_attn) 33 | 34 | 
bag_logits.append(torch.matmul(bag_attn, bag_x)) # (n') x (n', hidden_size) = (hidden_size) 35 | 36 | start_offset = end_offset 37 | 38 | bag_logits = torch.stack(bag_logits) 39 | 40 | return bag_logits, bag_attns 41 | -------------------------------------------------------------------------------- /data/dti/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | logging.basicConfig(level=logging.INFO, 5 | format='%(asctime)s %(levelname)-8s %(message)s') 6 | 7 | ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | SAVE_DIR = 'checkpoint' 10 | RESULT_DIR = 'result' 11 | OUTPUT_DIR = 'output' 12 | DATA_SET = 'dti' 13 | 14 | BAG_MODE = True 15 | LOSS_WEIGHT = None 16 | 17 | EMBEDDING_FINE_TUNE = True 18 | BIDIRECTIONAL = True 19 | 20 | MAX_LENGTH = 60 21 | TAG_DIM = 50 22 | HIDDEN_DIM = 250 23 | 24 | DROP_PROB = 0.5 25 | L2_REG = 0 26 | 27 | LEARNING_RATE = 0.0001 28 | BATCH_SIZE = 64 29 | MAX_EPOCHS = 25 30 | 31 | 32 | def log(): 33 | logging.info('Loading config of {}'.format(ROOT_DIR)) 34 | 35 | logging.info('BAG_MODE {}'.format('✔' if BAG_MODE else '×')) 36 | logging.info('LOSS_WEIGHT: {}'.format(LOSS_WEIGHT)) 37 | logging.info('EMBEDDING_FINE_TUNE {}'.format('✔' if EMBEDDING_FINE_TUNE else '×')) 38 | logging.info('BIDIRECTIONAL {}'.format('✔' if BIDIRECTIONAL else '×')) 39 | 40 | logging.info('MAX_LENGTH: {}'.format(MAX_LENGTH)) 41 | logging.info('TAG_DIM: {}'.format(TAG_DIM)) 42 | logging.info('HIDDEN_DIM: {}'.format(HIDDEN_DIM)) 43 | 44 | logging.info('DROP_PROB: {}'.format(DROP_PROB)) 45 | logging.info('L2_REG: {}'.format(L2_REG)) 46 | 47 | logging.info('LEARNING_RATE: {}'.format(LEARNING_RATE)) 48 | logging.info('BATCH_SIZE: {}'.format(BATCH_SIZE)) 49 | logging.info('MAX_EPOCHS: {}'.format(MAX_EPOCHS)) 50 | -------------------------------------------------------------------------------- /data/ddi/config.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | logging.basicConfig(level=logging.INFO, 5 | format='%(asctime)s %(levelname)-8s %(message)s') 6 | 7 | ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | SAVE_DIR = 'checkpoint' 10 | RESULT_DIR = 'result' 11 | OUTPUT_DIR = 'output' 12 | DATA_SET = 'ddi' 13 | 14 | BAG_MODE = False 15 | LOSS_WEIGHT = None 16 | 17 | EMBEDDING_FINE_TUNE = True 18 | BIDIRECTIONAL = True 19 | 20 | MAX_LENGTH = 60 21 | TAG_DIM = 50 22 | HIDDEN_DIM = 250 23 | 24 | DROP_PROB = 0.5 25 | L2_REG = 0 26 | 27 | LEARNING_RATE = 0.0001 28 | BATCH_SIZE = 128 29 | MAX_EPOCHS = 50 30 | 31 | 32 | def log(): 33 | logging.info('Loading config of {}'.format(ROOT_DIR)) 34 | 35 | logging.info('BAG_MODE {}'.format('✔' if BAG_MODE else '×')) 36 | logging.info('LOSS_WEIGHT: {}'.format(LOSS_WEIGHT)) 37 | logging.info('EMBEDDING_FINE_TUNE {}'.format('✔' if EMBEDDING_FINE_TUNE else '×')) 38 | logging.info('BIDIRECTIONAL {}'.format('✔' if BIDIRECTIONAL else '×')) 39 | 40 | logging.info('MAX_LENGTH: {}'.format(MAX_LENGTH)) 41 | logging.info('TAG_DIM: {}'.format(TAG_DIM)) 42 | logging.info('HIDDEN_DIM: {}'.format(HIDDEN_DIM)) 43 | 44 | logging.info('DROP_PROB: {}'.format(DROP_PROB)) 45 | logging.info('L2_REG: {}'.format(L2_REG)) 46 | 47 | logging.info('LEARNING_RATE: {}'.format(LEARNING_RATE)) 48 | logging.info('BATCH_SIZE: {}'.format(BATCH_SIZE)) 49 | logging.info('MAX_EPOCHS: {}'.format(MAX_EPOCHS)) 50 | -------------------------------------------------------------------------------- /plot_pr.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | 5 | matplotlib.use('agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from sklearn import metrics 9 | 10 | 11 | def plotPR(config): 12 | result_dir = os.path.join(config.RESULT_DIR, config.DATA_SET) 13 | 14 | recall = 
np.load(os.path.join(result_dir, "BERE_x.npy")) 15 | precision = np.load(os.path.join(result_dir, "BERE_y.npy")) 16 | auc = metrics.auc(x=recall, y=precision) 17 | f1 = (2 * recall * precision / (recall + precision + 1e-20)).max() 18 | 19 | print('f1: {:.3}'.format(f1)) 20 | print('Area under the curve: {:.3}'.format(auc)) 21 | 22 | plt.plot(recall[:], precision[:], label='BERE' + ': AUPR={0:0.3f}, F1={1:0.3f}'.format(auc,f1), color='red', lw=1, marker='o', 23 | markevery=0.1, ms=6) 24 | plt.xlim(0, 1) 25 | plt.ylim(0, 1) 26 | 27 | base_list = ['BiGRU+2ATT', 'BiGRU+ATT', 'PCNN+ATT', 'PCNN'] 28 | color = ['purple', 'darkorange', 'green', 'xkcd:azure'] 29 | marker = ['d', 's', '^', '*'] 30 | 31 | for i, baseline in enumerate(base_list): 32 | recall = np.load(os.path.join(result_dir, baseline + '_x.npy')) 33 | precision = np.load(os.path.join(result_dir, baseline + '_y.npy')) 34 | auc = metrics.auc(x=recall, y=precision) 35 | f1 = (2 * recall * precision / (recall + precision + 1e-20)).max() 36 | 37 | print("\n[{0}] auc: {1:0.3f} f1: {2:0.3f}".format(baseline, auc, f1)) 38 | # plt.plot(recall, precision, color=color[i], label=baseline, lw=1, marker=marker[i], markevery=0.1, ms=6) 39 | plt.plot(recall, precision, label=baseline + ': AUPR={0:0.3f}, F1={1:0.3f}'.format(auc,f1), color=color[i], lw=1, marker=marker[i], 40 | markevery=0.1, ms=6) 41 | 42 | plt.xlabel('Recall', fontsize=14) 43 | plt.ylabel('Precision', fontsize=14) 44 | plt.legend(loc="upper right", prop={'size': 12}) 45 | plt.grid(True) 46 | plt.tight_layout() 47 | plt.show() 48 | 49 | plot_path = os.path.join(result_dir, "pr.pdf") 50 | plt.savefig(plot_path) 51 | print('Precision-Recall plot saved at: {}'.format(plot_path)) 52 | 53 | 54 | if __name__ == '__main__': 55 | from data.dti import config 56 | 57 | plotPR(config) 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
BERE 2 | Implementation of the paper [A novel machine learning framework for automated biomedical relation extraction from large-scale literature repositories](https://www.nature.com/articles/s42256-020-0189-y). 3 | 4 | ## Environments 5 | Tested on a Linux server with a GeForce GTX 1080, using the following environment: 6 | 7 | - Python 3.5.2 8 | 9 | - PyTorch 1.0.0 10 | 11 | - sklearn 0.20.2 12 | 13 | - numpy 1.15.4 14 | 15 | - cuda 9.0 16 | 17 | ## Installation Guide 18 | 1. Download the pretrained word embedding `PubMed-and-PMC-w2v.bin` from http://evexdb.org/pmresources/vec-space-models/ and put it in `./data/`. 19 | 20 | 2. Download the complete DTI dataset from https://www.aliyundrive.com/s/QqSG9H3guEP and put it in `./data/dti/`. 21 | 22 | ## How to Run 23 | **DDI Experiment** (less than 1 h per training run) 24 | 1. Run `./data/ddi/data_prepare.py` from inside `./data/ddi/` (it uses paths relative to that directory) to preprocess the DDI dataset. 25 | 26 | 2. Run `./train_ddi.py` to train BERE with different learning rates. 27 | 28 | 3. Run `./test_ddi.py` to test BERE with the best model. 29 | 30 |   31 | 32 | **DTI Experiment** (10–20 h until convergence) 33 | 1. Run `./data/dti/data_prepare.py` from inside `./data/dti/` (it uses paths relative to that directory) to preprocess the DTI dataset. 34 | 35 | 2. Run `./train_dti.py` to train BERE with different learning rates. 36 | 37 | 3. Run `./test_dti.py` to test BERE with the best model. 38 | 39 |   40 | 41 | **Demo of DTI Prediction** 42 | 43 | 1. Train the model on the DTI dataset. 44 | 45 | 2. Run `./data/dti/transform.py` from inside `./data/dti/` to preprocess the `pmc_nintedanib` dataset. 46 | 47 | 3. Run `./predict.py` to predict the targets of the drug nintedanib with the trained model. 48 | 49 | ## Data Description 50 | - `PubMed-and-PMC-w2v.bin`: The pretrained word embedding. 51 | - `train.json`, `valid.json`, `test.json`: The original dataset splits. 52 | - `label2id.json`: The label file. 53 | - `pmc_nintedanib.json`: The data for the DTI prediction demo. 54 | - `tree_examples.json`: The data for the visualization demo.
55 | - `config.py`: The hyper-parameter settings. 56 | 57 | ## File Description 58 | - `./data/`: This directory contains the DDI dataset, the DTI dataset and the pretrained word embedding. 59 | 60 | - `./network/`: This directory contains the code of our model. 61 | 62 | - `./checkpoint/` (generated): This directory contains the model checkpoints saved during training. 63 | 64 | - `./result/` (generated): This directory contains the test results and prediction results. 65 | 66 | - `./output/` (generated): This directory contains the prediction results in a more human-readable form. 67 | 68 | - `./train_ddi.py`: This is a demo for training BERE on the DDI dataset. 69 | 70 | - `./train_dti.py`: This is a demo for training BERE on the DTI dataset. 71 | 72 | - `./test_ddi.py`: This is a demo for testing BERE on the DDI dataset. 73 | 74 | - `./test_dti.py`: This is a demo for testing BERE on the DTI dataset. 75 | 76 | - `./predict.py`: This is a demo for predicting the targets of the drug nintedanib. 77 | 78 | - `./plot_pr.py`: This file is used to plot the precision-recall curve of the results in `./result/`. 79 | 80 | - `./visualize.py` (optional): This file is used for the visualization of word attention, sentence attention and sentence tree structures. 81 | 82 | ## Notes 83 | - The full datasets for discovering novel DTIs are available from the corresponding authors upon request. 84 | - If you have any other questions or comments, please feel free to email Lixiang Hong (honglx17[at]mails[dot]tsinghua[dot]edu[dot]cn) and/or Jianyang Zeng (zengjy321[at]tsinghua[dot]edu[dot]cn).
85 | -------------------------------------------------------------------------------- /test_ddi.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | from dataset import * 4 | from network.model import * 5 | 6 | logging.basicConfig(level=logging.INFO, 7 | format='%(asctime)s %(levelname)-8s %(message)s') 8 | 9 | DEVICE = 'cuda:0' 10 | 11 | 12 | def test(config, model_name=None): 13 | if config.BAG_MODE: 14 | REModel = REModel_BAG 15 | DataLoader = DataLoader_BAG 16 | 17 | else: 18 | REModel = REModel_INS 19 | DataLoader = DataLoader_INS 20 | 21 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 22 | 23 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 24 | logging.info('Number of classes: {}'.format(vocab.class_num)) 25 | 26 | test_dataset = torch.load(os.path.join(config.ROOT_DIR, 'test.pt')) 27 | test_loader = DataLoader(test_dataset, config.BATCH_SIZE, collate_fn=test_dataset.collate, shuffle=False) 28 | 29 | logging.info('Number of test pair: {}'.format(len(test_dataset))) 30 | 31 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 32 | max_length=config.MAX_LENGTH, 33 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 34 | bidirectional=config.BIDIRECTIONAL) 35 | 36 | num_params = sum(np.prod(p.size()) for p in model.parameters()) 37 | num_embedding_params = np.prod(model.word_emb.weight.size()) + np.prod(model.tag_emb.weight.size()) 38 | print('# of parameters: {}'.format(num_params)) 39 | print('# of word embedding parameters: {}'.format(num_embedding_params)) 40 | print('# of parameters (excluding embeddings): {}'.format(num_params - num_embedding_params)) 41 | 42 | if model_name is None: 43 | model_path = utils.best_model_path(config.SAVE_DIR, config.DATA_SET, i=0) 44 | logging.info('Loading the best model on validation set: {}'.format(model_path)) 45 | model.load_state_dict(torch.load(model_path, map_location='cpu')) 46 
| else: 47 | model_path = os.path.join(config.SAVE_DIR, config.DATA_SET, model_name) 48 | logging.info('Loading the model: {}'.format(model_path)) 49 | model.load_state_dict( 50 | torch.load(model_path, map_location='cpu')) 51 | model.eval() 52 | model.to(DEVICE) 53 | model.display() 54 | 55 | torch.set_grad_enabled(False) 56 | 57 | def run_iter(batch): 58 | sent = batch['sent'].to(DEVICE) 59 | tag = batch['tag'].to(DEVICE) 60 | pos1 = batch['pos1'].to(DEVICE) 61 | pos2 = batch['pos2'].to(DEVICE) 62 | length = batch['length'].to(DEVICE) 63 | 64 | label = batch['label'] 65 | id = batch['id'] 66 | scope = batch['scope'] 67 | 68 | logits = model(sent, tag, length) 69 | label_pred = logits.max(1)[1] 70 | 71 | return label_pred.cpu() 72 | 73 | test_labels = [] 74 | test_preds = [] 75 | 76 | for test_batch in test_loader: 77 | test_pred = run_iter(batch=test_batch) 78 | 79 | test_labels.extend(test_batch['label']) 80 | test_preds.extend(test_pred) 81 | 82 | test_p, test_r, test_f1, _ = metrics.precision_recall_fscore_support(test_labels, test_preds, 83 | labels=[1, 2, 3, 4], 84 | average='micro') 85 | 86 | logging.info( 87 | 'precision = {:.4f}: recall = {:.4f}, fscore = {:.4f}'.format(test_p, test_r, test_f1)) 88 | 89 | 90 | if __name__ == '__main__': 91 | from data.ddi import config 92 | 93 | test(config) 94 | -------------------------------------------------------------------------------- /test_dti.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from sklearn import metrics 3 | 4 | from dataset import * 5 | from network.model import * 6 | 7 | logging.basicConfig(level=logging.INFO, 8 | format='%(asctime)s %(levelname)-8s %(message)s') 9 | 10 | DEVICE = 'cuda:0' 11 | 12 | 13 | def test(config, model_name=None): 14 | if config.BAG_MODE: 15 | REModel = REModel_BAG 16 | DataLoader = DataLoader_BAG 17 | 18 | else: 19 | REModel = REModel_INS 20 | DataLoader = DataLoader_INS 21 | 22 | vocab = 
torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 23 | 24 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 25 | logging.info('Number of classes: {}'.format(vocab.class_num)) 26 | 27 | test_dataset = torch.load(os.path.join(config.ROOT_DIR, 'test.pt')) 28 | test_loader = DataLoader(test_dataset, config.BATCH_SIZE, collate_fn=test_dataset.collate, shuffle=False) 29 | 30 | test_labels = numpy.array(test_dataset.get_labels()) 31 | test_rel_num = sum(test_labels != vocab.NA_id) 32 | 33 | logging.info('Number of test pair: {}'.format(len(test_dataset))) 34 | 35 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 36 | max_length=config.MAX_LENGTH, 37 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 38 | bidirectional=config.BIDIRECTIONAL) 39 | 40 | num_params = sum(np.prod(p.size()) for p in model.parameters()) 41 | num_embedding_params = np.prod(model.word_emb.weight.size()) + np.prod(model.tag_emb.weight.size()) 42 | print('# of parameters: {}'.format(num_params)) 43 | print('# of word embedding parameters: {}'.format(num_embedding_params)) 44 | print('# of parameters (excluding embeddings): {}'.format(num_params - num_embedding_params)) 45 | 46 | if model_name is None: 47 | model_path = utils.best_model_path(config.SAVE_DIR, config.DATA_SET, i=0) 48 | logging.info('Loading the best model on validation set: {}'.format(model_path)) 49 | model.load_state_dict(torch.load(model_path, map_location='cpu')) 50 | else: 51 | model_path = os.path.join(config.SAVE_DIR, config.DATA_SET, model_name) 52 | logging.info('Loading the model: {}'.format(model_path)) 53 | model.load_state_dict( 54 | torch.load(model_path, map_location='cpu')) 55 | model.eval() 56 | model.to(DEVICE) 57 | model.display() 58 | 59 | torch.set_grad_enabled(False) 60 | 61 | def run_iter(batch): 62 | sent = batch['sent'].to(DEVICE) 63 | tag = batch['tag'].to(DEVICE) 64 | 65 | length = batch['length'].to(DEVICE) 66 | scope = batch['scope'] 67 | 68 | logit 
= model(sent, tag, length, scope) 69 | 70 | return logit.cpu() 71 | 72 | test_result = [] 73 | test_preds = [] 74 | 75 | for test_batch in test_loader: 76 | test_logit = run_iter(batch=test_batch) 77 | test_pred = test_logit.max(1)[1] 78 | test_preds.extend(test_pred) 79 | for idx in range(len(test_logit)): 80 | for rel in range(1, vocab.class_num): 81 | test_result.append( 82 | {'score': test_logit[idx][rel], 'flag': test_batch['label'][idx] == rel}) 83 | 84 | sorted_test_result = sorted(test_result, key=lambda x: x['score']) 85 | 86 | prec = [] 87 | recall = [] 88 | correct = 0 89 | for i, item in enumerate(sorted_test_result[::-1]): 90 | correct += int(item['flag']) 91 | prec.append(float(correct) / (i + 1)) 92 | recall.append(float(correct) / test_rel_num) 93 | 94 | x, y = np.array(recall), np.array(prec) 95 | 96 | auc = metrics.auc(x=x, y=y) 97 | 98 | logging.info('auc = {:.4f}'.format(auc)) 99 | test_preds = [int(t) for t in test_preds] 100 | 101 | test_p, test_r, test_f1, _ = metrics.precision_recall_fscore_support(test_labels, test_preds, 102 | labels=[1, 2, 3, 4, 5], 103 | average='micro') 104 | 105 | logging.info( 106 | 'precision = {:.4f}: recall = {:.4f}, fscore = {:.4f}'.format(test_p, test_r, test_f1)) 107 | 108 | result_dir = os.path.join(config.RESULT_DIR, config.DATA_SET) 109 | if not os.path.isdir(result_dir): 110 | os.makedirs(result_dir) 111 | 112 | np.save(os.path.join(result_dir, "BERE_x.npy"), x) 113 | np.save(os.path.join(result_dir, "BERE_y.npy"), y) 114 | 115 | 116 | if __name__ == '__main__': 117 | from data.dti import config 118 | 119 | test(config) 120 | -------------------------------------------------------------------------------- /network/utils.py: -------------------------------------------------------------------------------- 1 | """Basic or helper implementation.""" 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from torch.nn import functional 8 | 9 | 10 | def convert_to_one_hot(indices, num_classes): 11 | """ 12 | 
Args: 13 | indices (tensor): A vector containing indices, 14 | whose size is (batch_size,). 15 | num_classes (tensor): The number of classes, which would be 16 | the second dimension of the resulting one-hot matrix. 17 | 18 | Returns: 19 | result: The one-hot matrix of size (batch_size, num_classes). 20 | """ 21 | 22 | batch_size = indices.size(0) 23 | indices = indices.unsqueeze(1) 24 | one_hot = indices.new_zeros(batch_size, num_classes).scatter_(1, indices, 1) 25 | return one_hot 26 | 27 | 28 | def masked_softmax(logits, mask=None): 29 | eps = 1e-20 30 | probs = functional.softmax(logits, dim=1) 31 | if mask is not None: 32 | mask = mask.float() 33 | probs = probs * mask + eps 34 | probs = probs / probs.sum(1, keepdim=True) 35 | return probs 36 | 37 | 38 | def greedy_select(logits, mask=None): 39 | probs = masked_softmax(logits=logits, mask=mask) 40 | one_hot = convert_to_one_hot(indices=probs.max(1)[1], 41 | num_classes=logits.size(1)) 42 | return one_hot 43 | 44 | 45 | def st_gumbel_softmax(logits, temperature=1.0, mask=None): 46 | """ 47 | Return the result of Straight-Through Gumbel-Softmax Estimation. 48 | It approximates the discrete sampling via Gumbel-Softmax trick 49 | and applies the biased ST estimator. 50 | In the forward propagation, it emits the discrete one-hot result, 51 | and in the backward propagation it approximates the categorical 52 | distribution via smooth Gumbel-Softmax distribution. 53 | 54 | Args: 55 | logits (tensor): A un-normalized probability values, 56 | which has the size (batch_size, num_classes) 57 | temperature (float): A temperature parameter. The higher 58 | the value is, the smoother the distribution is. 59 | mask (tensor, optional): If given, it masks the softmax 60 | so that indices of '0' mask values are not selected. 61 | The size is (batch_size, num_classes). 62 | 63 | Returns: 64 | y: The sampled output, which has the property explained above. 
65 | """ 66 | 67 | eps = 1e-20 68 | u = logits.data.new(*logits.size()).uniform_() 69 | gumbel_noise = -torch.log(-torch.log(u + eps) + eps) 70 | y = logits + gumbel_noise 71 | y = masked_softmax(logits=y / temperature, mask=mask) 72 | y_argmax = y.max(1)[1] 73 | y_hard = convert_to_one_hot(indices=y_argmax, num_classes=y.size(1)).float() 74 | y = (y_hard - y).detach() + y 75 | return y 76 | 77 | 78 | def sequence_mask(length, max_length=None): 79 | if max_length is None: 80 | max_length = length.max() 81 | batch_size = length.size(0) 82 | seq_range = torch.arange(0, max_length).long() 83 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_length) 84 | seq_range_expand = seq_range_expand.to(length) 85 | seq_length_expand = length.unsqueeze(1).expand_as(seq_range_expand) 86 | return seq_range_expand < seq_length_expand 87 | 88 | 89 | def reverse_padded_sequence(inputs, lengths, batch_first=True): 90 | """Reverses sequences according to their lengths. 91 | Inputs should have size ``T x B x *`` if ``batch_first`` is False, or 92 | ``B x T x *`` if True. T is the length of the longest sequence (or larger), 93 | B is the batch size, and * is any number of dimensions (including 0). 94 | Arguments: 95 | inputs (tensor): padded batch of variable length sequences. 96 | lengths (list[int]): list of sequence lengths 97 | batch_first (bool, optional): if True, inputs should be B x T x *. 98 | Returns: 99 | A tensor with the same size as inputs, but with each sequence 100 | reversed according to its length. 
101 | """ 102 | 103 | if not batch_first: 104 | inputs = inputs.transpose(0, 1) 105 | if inputs.size(0) != len(lengths): 106 | raise ValueError('inputs incompatible with lengths.') 107 | reversed_indices = [list(range(inputs.size(1))) 108 | for _ in range(inputs.size(0))] 109 | for i, length in enumerate(lengths): 110 | if length > 0: 111 | reversed_indices[i][:length] = reversed_indices[i][length - 1::-1] 112 | reversed_indices = (torch.LongTensor(reversed_indices).unsqueeze(2) 113 | .expand_as(inputs)) 114 | reversed_indices = reversed_indices.to(inputs.device) 115 | reversed_inputs = torch.gather(inputs, 1, reversed_indices) 116 | if not batch_first: 117 | reversed_inputs = reversed_inputs.transpose(0, 1) 118 | return reversed_inputs 119 | 120 | 121 | def non_padding_mask(seq): 122 | return seq.ne(0).type(torch.float).unsqueeze(-1) 123 | 124 | 125 | def value_mask(seq, value): 126 | return seq.eq(value).type(torch.float).unsqueeze(-1) 127 | 128 | 129 | def padding_mask(seq): 130 | len_q = seq.size(1) 131 | pad_mask = seq.eq(0).unsqueeze(1).expand(-1, len_q, -1) # shape [B, L_q, L_k] 132 | return pad_mask 133 | 134 | 135 | def best_model_path(model_dir, model_name, i = 0): 136 | paths = glob.glob(os.path.join(model_dir, model_name, model_name + '*')) 137 | paths.sort(reverse=True) 138 | return paths[i] 139 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | from dataset import * 4 | from network.model import * 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | logging.basicConfig(level=logging.INFO, 9 | format='%(asctime)s %(levelname)-8s %(message)s') 10 | 11 | DEVICE = 'cuda:0' 12 | 13 | 14 | def predict(config, model_name, data_name): 15 | if config.BAG_MODE: 16 | REModel = REModel_BAG 17 | DataLoader = DataLoader_BAG 18 | 19 | else: 20 | REModel = REModel_INS 21 | DataLoader = 
DataLoader_INS 22 | 23 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 24 | 25 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 26 | logging.info('Number of classes: {}'.format(vocab.class_num)) 27 | 28 | predict_dataset = torch.load(os.path.join(config.ROOT_DIR, data_name + '.pt')) 29 | predict_loader = DataLoader(predict_dataset, batch_size=config.BATCH_SIZE, collate_fn=predict_dataset.collate, 30 | shuffle=False) 31 | 32 | logging.info('Number of predict pair: {}'.format(len(predict_dataset))) 33 | 34 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 35 | max_length=config.MAX_LENGTH, 36 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 37 | bidirectional=config.BIDIRECTIONAL) 38 | 39 | num_params = sum(np.prod(p.size()) for p in model.parameters()) 40 | num_embedding_params = np.prod(model.word_emb.weight.size()) + np.prod(model.tag_emb.weight.size()) 41 | print('# of parameters: {}'.format(num_params)) 42 | print('# of word embedding parameters: {}'.format(num_embedding_params)) 43 | print('# of parameters (excluding embeddings): {}'.format(num_params - num_embedding_params)) 44 | 45 | model.load_state_dict( 46 | torch.load(os.path.join(config.SAVE_DIR, config.DATA_SET, model_name), map_location='cpu')) 47 | model.eval() 48 | model.to(DEVICE) 49 | model.display() 50 | 51 | torch.set_grad_enabled(False) 52 | 53 | logging.info('Using device {}'.format(DEVICE)) 54 | 55 | predict_ids = [] 56 | predict_labels = [] 57 | predict_logits = [] 58 | predict_preds = [] 59 | 60 | predict_result = [] 61 | 62 | def run_iter(batch): 63 | sent = batch['sent'].to(DEVICE) 64 | tag = batch['tag'].to(DEVICE) 65 | 66 | length = batch['length'].to(DEVICE) 67 | 68 | label = batch['label'] 69 | id = batch['id'] 70 | scope = batch['scope'] 71 | 72 | logits = model(sent, tag, length, scope) 73 | logits = F.softmax(logits, dim=1) 74 | label_pred = logits.max(1)[1] 75 | 76 | return id, label, logits.detach().cpu(), 
label_pred.detach().cpu() 77 | 78 | for batch in tqdm(predict_loader): 79 | id, label, logits, label_pred = run_iter(batch) 80 | 81 | predict_ids.extend(id) 82 | predict_labels.extend(label) 83 | predict_logits.extend(logits) 84 | predict_preds.extend(label_pred) 85 | 86 | result = metrics.precision_recall_fscore_support(predict_labels, predict_preds, labels=[1], average='micro') 87 | 88 | for i in range(len(predict_dataset)): 89 | j = np.argmax(predict_logits[i]) 90 | if j > 0: 91 | predict_result.append({'pair_id': predict_ids[i], 'score': float(predict_logits[i][j]), 92 | 'relation': int(j)}) 93 | 94 | logging.info( 95 | 'precision = {:.4f}: recall = {:.4f}, fscore = {:.4f}'.format(result[0], result[1], result[2])) 96 | 97 | predict_result.sort(key=lambda x: x['score'], reverse=True) 98 | if not os.path.isdir(config.RESULT_DIR): 99 | os.makedirs(config.RESULT_DIR) 100 | logging.info('Save result to {}'.format(config.RESULT_DIR)) 101 | json.dump(predict_result, open(os.path.join(config.RESULT_DIR, config.DATA_SET + '_' + data_name + '.json'), 'w')) 102 | 103 | 104 | def output(data_name): 105 | output_data = OrderedDict() 106 | 107 | predict_data = json.load(open(os.path.join(config.RESULT_DIR, config.DATA_SET + '_' + data_name + '.json'), 'r')) 108 | origin_data = json.load(open(os.path.join(config.ROOT_DIR, data_name + '.json'), 'r')) 109 | label2id = json.load(open(os.path.join(config.ROOT_DIR, 'label2id'+ '.json'), 'r')) 110 | id2label = {v: k for k, v in label2id.items()} 111 | 112 | for item in predict_data: 113 | pair_id = item['pair_id'].split('#') 114 | drug_id = pair_id[0] 115 | target_id = pair_id[1] 116 | rel = item['relation'] 117 | score = item['score'] 118 | output_data[(drug_id, target_id)] = {'drug_id': drug_id, 'target_id': target_id, 'relation': id2label[rel], 119 | 'score': score, 'supporting_entry': []} 120 | 121 | for item in origin_data: 122 | drug_id = item['head']['id'] 123 | target_id = item['tail']['id'] 124 | 125 | if (drug_id, 
target_id) in output_data: 126 | try: 127 | pmid = item['pmid'] 128 | except: 129 | pmid = None 130 | drug_name = item['head']['word'] 131 | target_name = item['tail']['word'] 132 | sentence = item['sentence'] 133 | output_data[(drug_id, target_id)]['drugbank_relation'] = item['relation'] 134 | output_data[(drug_id, target_id)]['supporting_entry'].append( 135 | {'pmid': pmid, 'sentence': sentence, 'drug': drug_name, 'target': target_name}) 136 | 137 | if not os.path.isdir(config.OUTPUT_DIR): 138 | os.makedirs(config.OUTPUT_DIR) 139 | logging.info('Save result to {}'.format(config.OUTPUT_DIR)) 140 | result = list(output_data.values()) 141 | json.dump(result, 142 | open(os.path.join(config.OUTPUT_DIR, config.DATA_SET + '_' + data_name + '.json'), 'w')) 143 | 144 | 145 | if __name__ == '__main__': 146 | from data.dti import config 147 | 148 | predict(config, 'dti-0.5419.pkl', 'pmc_nintedanib') 149 | 150 | output(data_name='pmc_nintedanib') 151 | -------------------------------------------------------------------------------- /train_ddi.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | from torch import optim 3 | from torch.nn.utils import clip_grad_norm_ 4 | 5 | from dataset import * 6 | from network.model import * 7 | 8 | logging.basicConfig(level=logging.INFO, 9 | format='%(asctime)s %(levelname)-8s %(message)s') 10 | 11 | DEVICE = 'cuda:1' 12 | VALID_TIMES = 20 13 | 14 | 15 | def train(config, log_path): 16 | """ 17 | Before training, the input with '.json' format must be transformed into '.pt' 18 | format by 'data_prepare.py'. This process will also generate the 'vocab.pt' 19 | file which contains the basic statistics of the corpus. 
20 | """ 21 | log_f = open(log_path, 'a') 22 | 23 | if config.BAG_MODE: 24 | REModel = REModel_BAG 25 | DataLoader = DataLoader_BAG 26 | 27 | else: 28 | REModel = REModel_INS 29 | DataLoader = DataLoader_INS 30 | 31 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 32 | 33 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 34 | logging.info('Number of classes: {}'.format(vocab.class_num)) 35 | 36 | train_dataset = torch.load(os.path.join(config.ROOT_DIR, 'train.pt')) 37 | train_loader = DataLoader(train_dataset, config.BATCH_SIZE, collate_fn=train_dataset.collate, shuffle=True) 38 | 39 | valid_dataset = torch.load(os.path.join(config.ROOT_DIR, 'valid.pt')) 40 | valid_loader = DataLoader(valid_dataset, config.BATCH_SIZE, collate_fn=valid_dataset.collate, shuffle=False) 41 | 42 | logging.info('Number of train pair: {}'.format(len(train_dataset))) 43 | logging.info('Number of valid pair: {}'.format(len(valid_dataset))) 44 | 45 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 46 | max_length=config.MAX_LENGTH, 47 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 48 | bidirectional=config.BIDIRECTIONAL) 49 | 50 | if not config.EMBEDDING_FINE_TUNE: 51 | model.word_emb.weight.requires_grad = False 52 | 53 | logging.info('Using device {}'.format(DEVICE)) 54 | 55 | model.to(DEVICE) 56 | model.display() 57 | 58 | weight = torch.FloatTensor(config.LOSS_WEIGHT) if config.LOSS_WEIGHT else None 59 | 60 | criterion = nn.CrossEntropyLoss(weight=weight, reduction='mean').to(DEVICE) 61 | 62 | params = [p for p in model.parameters() if p.requires_grad] 63 | optimizer = optim.Adam(params, lr=config.LEARNING_RATE, weight_decay=config.L2_REG) 64 | 65 | validate_every = len(train_loader) // VALID_TIMES 66 | 67 | def run_iter(batch, is_training): 68 | model.train(is_training) 69 | 70 | sent = batch['sent'].to(DEVICE) 71 | tag = batch['tag'].to(DEVICE) 72 | pos1 = batch['pos1'].to(DEVICE) 73 | pos2 = 
batch['pos2'].to(DEVICE) 74 | length = batch['length'].to(DEVICE) 75 | 76 | label = batch['label'].to(DEVICE) 77 | id = batch['id'] 78 | scope = batch['scope'] 79 | 80 | logits = model(sent, tag, length) 81 | 82 | loss = criterion(input=logits, target=label) 83 | 84 | label_pred = logits.max(1)[1] 85 | 86 | if is_training: 87 | optimizer.zero_grad() 88 | loss.backward() 89 | clip_grad_norm_(parameters=params, max_norm=5) 90 | optimizer.step() 91 | 92 | return loss, label_pred.cpu() 93 | 94 | save_dir = os.path.join(config.SAVE_DIR, config.DATA_SET) 95 | if not os.path.isdir(save_dir): 96 | os.makedirs(save_dir) 97 | 98 | best_f1 = 0 99 | 100 | for epoch_num in range(config.MAX_EPOCHS): 101 | logging.info('Epoch {}: start'.format(epoch_num)) 102 | 103 | train_labels = [] 104 | train_preds = [] 105 | 106 | for batch_iter, train_batch in enumerate(train_loader): 107 | train_loss, train_pred = run_iter(batch=train_batch, is_training=True) 108 | 109 | train_labels.extend(train_batch['label']) 110 | train_preds.extend(train_pred) 111 | 112 | if (batch_iter + 1) % validate_every == 0: 113 | 114 | torch.set_grad_enabled(False) 115 | 116 | valid_loss_sum = 0 117 | 118 | valid_labels = [] 119 | valid_preds = [] 120 | 121 | for valid_batch in valid_loader: 122 | valid_loss, valid_pred = run_iter(batch=valid_batch, is_training=False) 123 | 124 | valid_loss_sum += valid_loss.item() 125 | 126 | valid_labels.extend(valid_batch['label']) 127 | valid_preds.extend(valid_pred) 128 | 129 | torch.set_grad_enabled(True) 130 | 131 | valid_loss = valid_loss_sum / len(valid_loader) 132 | valid_p, valid_r, valid_f1, _ = metrics.precision_recall_fscore_support(valid_labels, valid_preds, 133 | labels=[1, 2, 3, 4], 134 | average='micro') 135 | 136 | train_f1 = metrics.f1_score(train_labels, train_preds, [1, 2, 3, 4], average='micro') 137 | 138 | progress = epoch_num + (batch_iter + 1) / len(train_loader) 139 | 140 | logging.info( 141 | 'Epoch {:.2f}: train loss = {:.4f}, train f1 = {:.4f}, 
valid loss = {:.4f}, valid f1 = {:.4f}'.format( 142 | progress, train_loss, train_f1, valid_loss, valid_f1)) 143 | 144 | if valid_f1 > best_f1: 145 | best_f1 = valid_f1 146 | model_filename = ('{}-{:.4f}.pkl'.format(config.DATA_SET, valid_f1)) 147 | model_path = os.path.join(save_dir, model_filename) 148 | torch.save(model.state_dict(), model_path) 149 | print('Saved the new best model to {}'.format(model_path)) 150 | 151 | log_f.write('{}\tlr={}\n'.format(model_filename, config.LEARNING_RATE)) 152 | log_f.flush() 153 | 154 | return best_f1 155 | 156 | 157 | if __name__ == '__main__': 158 | from data.ddi import config 159 | 160 | for lr in range(1, 11): 161 | config.LEARNING_RATE = lr / 10000.0 162 | config.log() 163 | F = train(config, 'ddi.log') 164 | -------------------------------------------------------------------------------- /train_dti.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | from torch import optim 3 | from torch.nn.utils import clip_grad_norm_ 4 | 5 | from dataset import * 6 | from network.model import * 7 | 8 | logging.basicConfig(level=logging.INFO, 9 | format='%(asctime)s %(levelname)-8s %(message)s') 10 | 11 | DEVICE = 'cuda:1' 12 | VALID_TIMES = 20 13 | 14 | 15 | def train(config, log_path): 16 | """ 17 | Before training, the input with '.json' format must be transformed into '.pt' 18 | format by 'data_prepare.py'. This process will also generate the 'vocab.pt' 19 | file which contains the basic statistics of the corpus. 
20 | """ 21 | log_f = open(log_path, 'a') 22 | 23 | if config.BAG_MODE: 24 | REModel = REModel_BAG 25 | DataLoader = DataLoader_BAG 26 | 27 | else: 28 | REModel = REModel_INS 29 | DataLoader = DataLoader_INS 30 | 31 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 32 | 33 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 34 | logging.info('Number of classes: {}'.format(vocab.class_num)) 35 | 36 | train_dataset = torch.load(os.path.join(config.ROOT_DIR, 'train.pt')) 37 | train_loader = DataLoader(train_dataset, config.BATCH_SIZE, collate_fn=train_dataset.collate, shuffle=True) 38 | 39 | valid_dataset = torch.load(os.path.join(config.ROOT_DIR, 'valid.pt')) 40 | valid_loader = DataLoader(valid_dataset, config.BATCH_SIZE, collate_fn=valid_dataset.collate, shuffle=False) 41 | 42 | valid_labels = np.array(valid_dataset.get_labels()) 43 | valid_rel_num = sum(valid_labels != vocab.NA_id) 44 | 45 | logging.info('Number of train pair: {}'.format(len(train_dataset))) 46 | logging.info('Number of valid pair: {}'.format(len(valid_dataset))) 47 | 48 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 49 | max_length=config.MAX_LENGTH, 50 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 51 | bidirectional=config.BIDIRECTIONAL) 52 | 53 | if not config.EMBEDDING_FINE_TUNE: 54 | model.word_emb.weight.requires_grad = False 55 | 56 | logging.info('Using device {}'.format(DEVICE)) 57 | 58 | model.to(DEVICE) 59 | model.display() 60 | 61 | weight = torch.FloatTensor(config.LOSS_WEIGHT) if config.LOSS_WEIGHT else None 62 | 63 | criterion = nn.CrossEntropyLoss(weight=weight, reduction='mean').to(DEVICE) 64 | 65 | params = [p for p in model.parameters() if p.requires_grad] 66 | optimizer = optim.Adam(params, lr=config.LEARNING_RATE, weight_decay=config.L2_REG) 67 | 68 | validate_every = len(train_loader) // VALID_TIMES 69 | 70 | def run_iter(batch, is_training): 71 | model.train(is_training) 72 | 73 | sent = 
batch['sent'].to(DEVICE) 74 | 75 | tag = batch['tag'].to(DEVICE) 76 | pos1 = batch['pos1'].to(DEVICE) 77 | pos2 = batch['pos2'].to(DEVICE) 78 | length = batch['length'].to(DEVICE) 79 | 80 | label = batch['label'].to(DEVICE) 81 | id = batch['id'] 82 | scope = batch['scope'] 83 | 84 | logit = model(sent, tag, length, scope) 85 | 86 | loss = criterion(input=logit, target=label) 87 | 88 | label_pred = logit.max(1)[1] 89 | 90 | if is_training: 91 | optimizer.zero_grad() 92 | loss.backward() 93 | clip_grad_norm_(parameters=params, max_norm=5) 94 | optimizer.step() 95 | 96 | return loss, logit.cpu(), label_pred.cpu() 97 | 98 | def test(): 99 | torch.set_grad_enabled(False) 100 | 101 | valid_loss_sum = 0 102 | 103 | test_result = [] 104 | 105 | for valid_batch in valid_loader: 106 | valid_loss, valid_logit, valid_pred = run_iter(batch=valid_batch, is_training=False) 107 | 108 | valid_loss_sum += valid_loss.item() 109 | 110 | for idx in range(len(valid_logit)): 111 | for rel in range(1, vocab.class_num): 112 | test_result.append( 113 | {'score': valid_logit[idx][rel], 'flag': valid_batch['label'][idx] == rel}) 114 | 115 | torch.set_grad_enabled(True) 116 | 117 | sorted_test_result = sorted(test_result, key=lambda x: x['score']) 118 | prec = [] 119 | recall = [] 120 | correct = 0 121 | for i, item in enumerate(sorted_test_result[::-1]): 122 | correct += int(item['flag']) 123 | prec.append(float(correct) / (i + 1)) 124 | recall.append(float(correct) / valid_rel_num) 125 | 126 | x, y = np.array(recall), np.array(prec) 127 | 128 | auc = metrics.auc(x=x, y=y) 129 | loss = valid_loss_sum / len(valid_loader) 130 | 131 | return auc, loss 132 | 133 | save_dir = os.path.join(config.SAVE_DIR, config.DATA_SET) 134 | if not os.path.isdir(save_dir): 135 | os.makedirs(save_dir) 136 | 137 | best_metric = 0 138 | 139 | for epoch_num in range(config.MAX_EPOCHS): 140 | logging.info('Epoch {}: start'.format(epoch_num)) 141 | 142 | train_labels = [] 143 | train_preds = [] 144 | 145 | for 
batch_iter, train_batch in enumerate(train_loader): 146 | train_loss, train_logit, train_pred = run_iter(batch=train_batch, is_training=True) 147 | 148 | train_labels.extend(train_batch['label']) 149 | train_preds.extend(train_pred) 150 | 151 | if (batch_iter + 1) % validate_every == 0: 152 | 153 | valid_auc, valid_loss = test() 154 | 155 | train_f1 = metrics.f1_score(train_labels, train_preds, [1, 2, 3, 4, 5], average='micro') 156 | 157 | progress = epoch_num + (batch_iter + 1) / len(train_loader) 158 | 159 | logging.info( 160 | 'Epoch {:.2f}: train loss = {:.4f}, train f1 = {:.4f}, valid loss = {:.4f}, valid auc = {:.4f}'.format( 161 | progress, train_loss, train_f1, valid_loss, valid_auc)) 162 | 163 | if valid_auc > best_metric: 164 | best_metric = valid_auc 165 | model_filename = ('{}-{:.4f}.pkl'.format(config.DATA_SET, valid_auc)) 166 | model_path = os.path.join(save_dir, model_filename) 167 | torch.save(model.state_dict(), model_path) 168 | print('Saved the new best model to {}'.format(model_path)) 169 | 170 | log_f.write('{}\tlr={}\n'.format(model_filename, config.LEARNING_RATE)) 171 | log_f.flush() 172 | 173 | return best_metric 174 | 175 | 176 | if __name__ == '__main__': 177 | from data.dti import config 178 | 179 | for lr in range(1, 11): 180 | config.LEARNING_RATE = lr / 10000.0 181 | config.log() 182 | F = train(config, 'dti.log') 183 | -------------------------------------------------------------------------------- /network/encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import init 6 | 7 | from network import utils as utils 8 | 9 | class MultiAttn(nn.Module): 10 | def __init__(self, in_dim, head_num=10): 11 | super(MultiAttn, self).__init__() 12 | 13 | self.head_dim = in_dim // head_num 14 | self.head_num = head_num 15 | 16 | # scaled dot product attention 17 | self.scale = self.head_dim 
** -0.5 18 | 19 | self.w_qs = nn.Linear(in_dim, head_num * self.head_dim, bias=True) 20 | self.w_ks = nn.Linear(in_dim, head_num * self.head_dim, bias=True) 21 | self.w_vs = nn.Linear(in_dim, head_num * self.head_dim, bias=True) 22 | 23 | self.w_os = nn.Linear(head_num * self.head_dim, in_dim, bias=True) 24 | 25 | self.gamma = nn.Parameter(torch.FloatTensor([0])) 26 | 27 | self.softmax = nn.Softmax(dim=-1) 28 | 29 | def forward(self, x, attn_mask, non_pad_mask): 30 | B, L, H = x.size() 31 | head_num = self.head_num 32 | head_dim = self.head_dim 33 | 34 | q = self.w_qs(x).view(B * head_num, L, head_dim) 35 | k = self.w_ks(x).view(B * head_num, L, head_dim) 36 | v = self.w_vs(x).view(B * head_num, L, head_dim) 37 | 38 | attn_mask = attn_mask.repeat(head_num, 1, 1) 39 | 40 | attn = torch.bmm(q, k.transpose(1, 2)) # B*head_num, L, L 41 | attn = self.scale * attn 42 | attn = attn.masked_fill_(attn_mask, -np.inf) 43 | attn = self.softmax(attn) 44 | 45 | out = torch.bmm(attn, v) # B*head_num, L, head_dim 46 | 47 | out = out.view(B, L, head_dim * head_num) 48 | 49 | out = self.w_os(out) 50 | 51 | out = non_pad_mask * out 52 | 53 | out = self.gamma * out + x 54 | 55 | return out, attn 56 | 57 | 58 | class PackedGRU(nn.Module): 59 | def __init__(self, in_dim, hid_dim, bidirectional=True): 60 | super(PackedGRU, self).__init__() 61 | 62 | self.gru = nn.GRU(in_dim, hid_dim, batch_first=True, bidirectional=bidirectional) 63 | 64 | def forward(self, x, length): 65 | packed = torch.nn.utils.rnn.pack_padded_sequence(x, length, batch_first=True) 66 | out, _ = self.gru(packed) 67 | out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True) 68 | 69 | return out 70 | 71 | 72 | class LeafRNN(nn.Module): 73 | def __init__(self, in_dim, hid_dim, bidirectional=True): 74 | super(LeafRNN, self).__init__() 75 | self.bidirectional = bidirectional 76 | 77 | self.leaf_rnn = nn.GRU(in_dim, hid_dim, batch_first=True) 78 | 79 | if self.bidirectional: 80 | self.leaf_rnn_bw = 
nn.GRU(in_dim, hid_dim, batch_first=True) 81 | 82 | def forward(self, x, non_pad_mask, length=None): 83 | out, _ = self.leaf_rnn(x) 84 | out = non_pad_mask * out 85 | 86 | if self.bidirectional: 87 | in_bw = utils.reverse_padded_sequence(x, length, batch_first=True) 88 | out_bw, _ = self.leaf_rnn_bw(in_bw) 89 | out_bw = non_pad_mask * out_bw 90 | out_bw = utils.reverse_padded_sequence(out_bw, length, batch_first=True) 91 | out = torch.cat([out, out_bw], -1) 92 | 93 | return out 94 | 95 | 96 | class BinaryTreeGRULayer(nn.Module): 97 | def __init__(self, hidden_dim): 98 | super(BinaryTreeGRULayer, self).__init__() 99 | 100 | self.fc1 = nn.Linear(in_features=2 * hidden_dim, out_features=3 * hidden_dim) 101 | self.fc2 = nn.Linear(in_features=2 * hidden_dim, out_features=hidden_dim) 102 | 103 | def forward(self, hl, hr): 104 | """ 105 | Args: 106 | hl: (batch_size, max_length, hidden_dim). 107 | hr: (batch_size, max_length, hidden_dim). 108 | Returns: 109 | h: (batch_size, max_length, hidden_dim). 
110 | """ 111 | 112 | hlr_cat1 = torch.cat([hl, hr], dim=-1) 113 | treegru_vector = self.fc1(hlr_cat1) 114 | i, f, r = treegru_vector.chunk(chunks=3, dim=-1) 115 | 116 | hlr_cat2 = torch.cat([hl * r.sigmoid(), hr * r.sigmoid()], dim=-1) 117 | 118 | h_hat = self.fc2(hlr_cat2) 119 | 120 | h = (hl + hr) * f.sigmoid() + h_hat.tanh() * i.sigmoid() 121 | 122 | return h 123 | 124 | 125 | class GumbelTreeGRU(nn.Module): 126 | def __init__(self, hidden_dim): 127 | super(GumbelTreeGRU, self).__init__() 128 | self.hidden_dim = hidden_dim 129 | 130 | self.gumbel_temperature = nn.Parameter(torch.FloatTensor([1])) 131 | 132 | self.treegru_layer = BinaryTreeGRULayer(hidden_dim) 133 | 134 | self.comp_query = nn.Parameter(torch.FloatTensor(hidden_dim)) 135 | init.normal_(self.comp_query.data, mean=0, std=0.01) 136 | 137 | self.query_layer = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 10, bias=True), nn.Tanh(), 138 | nn.Linear(hidden_dim // 10, 1, bias=True)) 139 | 140 | @staticmethod 141 | def update_state(old_h, new_h, done_mask): 142 | done_mask = done_mask.float().unsqueeze(1).unsqueeze(2) 143 | h = done_mask * new_h + (1 - done_mask) * old_h[:, :-1, :] 144 | return h 145 | 146 | def select_composition(self, old_h, new_h, mask): 147 | old_h_left, old_h_right = old_h[:, :-1, :], old_h[:, 1:, :] 148 | 149 | comp_weights = self.query_layer(new_h).squeeze(2) 150 | 151 | 152 | if self.training: 153 | select_mask = utils.st_gumbel_softmax( 154 | logits=comp_weights, temperature=self.gumbel_temperature, 155 | mask=mask) 156 | else: 157 | select_mask = utils.greedy_select(logits=comp_weights, mask=mask).float() 158 | 159 | select_mask_cumsum = select_mask.cumsum(1) 160 | left_mask = 1 - select_mask_cumsum 161 | right_mask = select_mask_cumsum - select_mask 162 | 163 | new_h = (select_mask.unsqueeze(2) * new_h 164 | + left_mask.unsqueeze(2) * old_h_left 165 | + right_mask.unsqueeze(2) * old_h_right) 166 | 167 | return new_h, select_mask 168 | 169 | def forward(self, input, length): 
170 | max_depth = input.size(1) 171 | length_mask = utils.sequence_mask(length=length, max_length=max_depth) 172 | select_masks = [] 173 | 174 | h = input 175 | 176 | for i in range(max_depth - 1): 177 | hl = h[:, :-1, :] 178 | hr = h[:, 1:, :] 179 | new_h = self.treegru_layer(hl, hr) 180 | if i < max_depth - 2: 181 | # We don't need to greedily select the composition in the 182 | # last iteration, since it has only one option left. 183 | new_h, select_mask = self.select_composition( 184 | old_h=h, new_h=new_h, 185 | mask=length_mask[:, i + 1:]) 186 | 187 | select_masks.append(select_mask) 188 | 189 | done_mask = length_mask[:, i + 1] 190 | 191 | h = self.update_state(old_h=h, new_h=new_h, 192 | done_mask=done_mask) 193 | 194 | out = h.squeeze(1) 195 | 196 | return out, select_masks 197 | -------------------------------------------------------------------------------- /network/model.py: -------------------------------------------------------------------------------- 1 | from network.encoder import * 2 | from network.selector import * 3 | 4 | 5 | class REModel_INS(nn.Module): 6 | """ 7 | The relation extraction model with INS mode, which will classify each 8 | sentence instance into an individual class. 9 | 10 | Args: 11 | vocab (object): The vocab object which contains the 12 | basic statics of the corpus. See dataset.py for 13 | more details. 14 | tag_dim (int): The dimension of POS (part-of-speech) 15 | embedding. 16 | max_length (int): All the sentences will be cropped 17 | or padded to the max_length. 18 | hidden_dim (int): The dimension of hidden unit in 19 | GRU. 20 | dropout_prob (float): Probability of an element 21 | to be zeroed. 22 | bidirectional (bool): If true, bi-directional GRU 23 | will be used. 24 | 25 | Inputs: sent, tag, length, verbose_output 26 | - **sent** of shape `(batch, seq_len, word_dim)`. 27 | - **tag** of shape `(batch, seq_len, tag_dim)`. 28 | - **length** of shape `(batch)`. 
29 | 30 | Outputs: logit 31 | - **logit** of shape `(batch, class_num)`. 32 | """ 33 | 34 | def __init__(self, vocab, tag_dim, max_length, hidden_dim, 35 | dropout_prob, 36 | bidirectional=True): 37 | super(REModel_INS, self).__init__() 38 | 39 | self.vocab = vocab 40 | 41 | class_num = vocab.class_num 42 | word_num = vocab.word_num 43 | 44 | word_dim = vocab.word_dim 45 | tag_num = vocab.tag_num 46 | 47 | self.ent1_id = vocab.ent1_id 48 | self.ent2_id = vocab.ent2_id 49 | 50 | self.max_length = max_length 51 | self.word_emb = nn.Embedding(word_num, word_dim, padding_idx=0) 52 | self.tag_emb = nn.Embedding(tag_num, tag_dim, padding_idx=0) 53 | 54 | self.word_emb.weight.data.set_(vocab.vectors) 55 | 56 | in_dim = word_dim + tag_dim 57 | 58 | self.attn = MultiAttn(in_dim) 59 | self.leaf_rnn = PackedGRU(in_dim, hidden_dim, bidirectional=bidirectional) 60 | 61 | if bidirectional: 62 | hidden_dim = 2 * hidden_dim 63 | 64 | self.encoder = GumbelTreeGRU(hidden_dim) 65 | self.selector = BagAttention(3 * hidden_dim) 66 | 67 | feat_dim = 3 * hidden_dim 68 | 69 | self.classifier = nn.Sequential(nn.Linear(feat_dim, feat_dim // 10), nn.ReLU(), 70 | nn.Linear(feat_dim // 10, class_num)) 71 | 72 | self.dropout = nn.Dropout(dropout_prob) 73 | 74 | def display(self): 75 | print(self) 76 | 77 | def forward(self, sent, tag, length, verbose_output=False): 78 | ent1_mask = torch.eq(sent, self.ent1_id).unsqueeze(-1).float() 79 | ent2_mask = torch.eq(sent, self.ent2_id).unsqueeze(-1).float() 80 | 81 | word_embedding = self.dropout(self.word_emb(sent)) 82 | tag_embedding = self.dropout(self.tag_emb(tag)) 83 | 84 | embedding = torch.cat([word_embedding, tag_embedding], dim=-1) 85 | 86 | # -- Prepare masks 87 | attn_mask = utils.padding_mask(sent) 88 | non_pad_mask = utils.non_padding_mask(sent) 89 | 90 | embedding, word_attn = self.attn(embedding, attn_mask, non_pad_mask) 91 | embedding = self.leaf_rnn(embedding, length) 92 | 93 | tree_feat, tree_order = self.encoder(embedding, length) 94 
| 95 | ent1_feat = (embedding * ent1_mask).sum(1) # (B,D) 96 | ent2_feat = (embedding * ent2_mask).sum(1) # (B,D) 97 | 98 | feat = torch.cat([tree_feat, ent1_feat, ent2_feat], -1) 99 | 100 | feat = self.dropout(feat) 101 | logit = self.classifier(feat) 102 | 103 | if verbose_output: 104 | return logit, word_attn, tree_order 105 | 106 | else: 107 | return logit 108 | 109 | 110 | class REModel_BAG(nn.Module): 111 | """ 112 | The relation extraction model with BAG mode, which will classify each 113 | sentence bag into an individual class. 114 | 115 | Args: 116 | vocab (object): The vocab object which contains the 117 | basic statics of the corpus. See dataset.py for 118 | more details. 119 | tag_dim (int): The dimension of POS (part-of-speech) 120 | embedding. 121 | max_length (int): All the sentences will be cropped 122 | or padded to the max_length. 123 | hidden_dim (int): The dimension of hidden unit in 124 | GRU. 125 | dropout_prob (float): Probability of an element 126 | to be zeroed. 127 | bidirectional (bool): If true, bi-directional GRU 128 | will be used. 129 | 130 | Inputs: sent, tag, length, verbose_output 131 | - **sent** of shape `(batch, seq_len, word_dim)`. 132 | - **tag** of shape `(batch, seq_len, tag_dim)`. 133 | - **length** of shape `(batch)`. 134 | 135 | Outputs: logit 136 | - **logit** of shape `(batch, class_num)`. 
137 | """ 138 | 139 | def __init__(self, vocab, tag_dim, max_length, hidden_dim, 140 | dropout_prob, 141 | bidirectional=True): 142 | super(REModel_BAG, self).__init__() 143 | 144 | self.vocab = vocab 145 | 146 | class_num = vocab.class_num 147 | word_num = vocab.word_num 148 | 149 | word_dim = vocab.word_dim 150 | tag_num = vocab.tag_num 151 | 152 | self.ent1_id = vocab.ent1_id 153 | self.ent2_id = vocab.ent2_id 154 | 155 | self.max_length = max_length 156 | self.word_emb = nn.Embedding(word_num, word_dim, padding_idx=0) 157 | self.tag_emb = nn.Embedding(tag_num, tag_dim, padding_idx=0) 158 | 159 | self.word_emb.weight.data.set_(vocab.vectors) 160 | 161 | in_dim = word_dim + tag_dim 162 | 163 | self.attn = MultiAttn(in_dim) 164 | self.leaf_rnn = LeafRNN(in_dim, hidden_dim, bidirectional=bidirectional) 165 | 166 | if bidirectional: 167 | hidden_dim = 2 * hidden_dim 168 | 169 | self.encoder = GumbelTreeGRU(hidden_dim) 170 | 171 | self.selector = BagAttention(3 * hidden_dim) 172 | 173 | feat_dim = 3 * hidden_dim 174 | 175 | self.classifier = nn.Sequential(nn.Linear(feat_dim, feat_dim // 10), nn.ReLU(), 176 | nn.Linear(feat_dim // 10, class_num)) 177 | 178 | self.dropout = nn.Dropout(dropout_prob) 179 | 180 | def display(self): 181 | print(self) 182 | 183 | def forward(self, sent, tag, length, scope, verbose_output=False): 184 | ent1_mask = torch.eq(sent, self.ent1_id).unsqueeze(-1).float() 185 | ent2_mask = torch.eq(sent, self.ent2_id).unsqueeze(-1).float() 186 | 187 | word_embedding = self.dropout(self.word_emb(sent)) 188 | 189 | tag_embedding = self.dropout(self.tag_emb(tag)) 190 | 191 | embedding = torch.cat([word_embedding, tag_embedding], dim=-1) 192 | 193 | # -- Prepare masks 194 | attn_mask = utils.padding_mask(sent) 195 | non_pad_mask = utils.non_padding_mask(sent) 196 | 197 | embedding, word_attn = self.attn(embedding, attn_mask, non_pad_mask) 198 | embedding = self.leaf_rnn(embedding, non_pad_mask, length) 199 | 200 | tree_feat, tree_order = 
self.encoder(embedding, length) 201 | 202 | ent1_feat = (embedding * ent1_mask).sum(1) # (B,D) 203 | ent2_feat = (embedding * ent2_mask).sum(1) # (B,D) 204 | 205 | feat = torch.cat([tree_feat, ent1_feat, ent2_feat], -1) 206 | 207 | feat, sent_attn = self.selector(feat, scope) 208 | 209 | feat = self.dropout(feat) 210 | logit = self.classifier(feat) 211 | 212 | if verbose_output: 213 | return logit, word_attn, tree_order, sent_attn 214 | 215 | else: 216 | return logit 217 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | from dataset import * 6 | from network.model import * 7 | 8 | logging.basicConfig(level=logging.INFO, 9 | format='%(asctime)s %(levelname)-8s %(message)s') 10 | 11 | 12 | def top_k(config, model_name, top_k=10): 13 | if config.BAG_MODE: 14 | REModel = REModel_BAG 15 | DataLoader = DataLoader_BAG 16 | 17 | else: 18 | REModel = REModel_INS 19 | DataLoader = DataLoader_INS 20 | 21 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 22 | 23 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 24 | logging.info('Number of classes: {}'.format(vocab.class_num)) 25 | 26 | test_dataset = torch.load(os.path.join(config.ROOT_DIR, 'test.pt')) 27 | valid_dataset = torch.load(os.path.join(config.ROOT_DIR, 'valid.pt')) 28 | train_dataset = torch.load(os.path.join(config.ROOT_DIR, 'train.pt')) 29 | 30 | all_dataset = test_dataset + valid_dataset + train_dataset 31 | 32 | all_loader = DataLoader(all_dataset, config.BATCH_SIZE, collate_fn=test_dataset.collate, shuffle=False) 33 | 34 | logging.info('Number of total pair: {}'.format(len(all_dataset))) 35 | 36 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 37 | max_length=config.MAX_LENGTH, 38 | hidden_dim=config.HIDDEN_DIM, 
dropout_prob=config.DROP_PROB, 39 | bidirectional=config.BIDIRECTIONAL) 40 | 41 | model.load_state_dict( 42 | torch.load(os.path.join(config.SAVE_DIR, config.DATA_SET, model_name), map_location='cpu')) 43 | model.eval() 44 | model.to(DEVICE) 45 | model.display() 46 | 47 | torch.set_grad_enabled(False) 48 | 49 | def run_iter(batch): 50 | sent = batch['sent'].to(DEVICE) 51 | tag = batch['tag'].to(DEVICE) 52 | 53 | length = batch['length'].to(DEVICE) 54 | scope = batch['scope'] 55 | 56 | label = batch['label'] 57 | id = batch['id'] 58 | 59 | logit = model(sent, tag, length, scope) 60 | 61 | return id, label, logit.cpu() 62 | 63 | def collect_data(config, data_map): 64 | all_data = json.load(open(os.path.join(config.ROOT_DIR, 'train.json'), 'r')) + json.load( 65 | open(os.path.join(config.ROOT_DIR, 'valid.json'), 'r')) + json.load( 66 | open(os.path.join(config.ROOT_DIR, 'test.json'), 'r')) 67 | 68 | data = [] 69 | for item in all_data: 70 | drug_id = item['head']['id'] 71 | target_id = item['tail']['id'] 72 | 73 | if (drug_id, target_id) in data_map: 74 | item['score'] = data_map[(drug_id, target_id)] 75 | data.append(item) 76 | 77 | data = sorted(data, key=lambda x: (x['score'], x['head']['id'], x['tail']['id']), reverse=True) 78 | return data 79 | 80 | result = [] 81 | 82 | for test_batch in all_loader: 83 | all_id, all_label, all_logit = run_iter(batch=test_batch) 84 | 85 | for idx in range(len(all_logit)): 86 | for rel in range(1, 2): 87 | if all_label[idx] == rel: 88 | result.append( 89 | {'id': all_id[idx], 'label': rel, 'score': float(all_logit[idx][rel])}) 90 | 91 | top_result = sorted(result, key=lambda x: x['score'], reverse=True)[:top_k] 92 | 93 | top_map = OrderedDict() 94 | for item in top_result: 95 | pair_id = item['id'].split('#') 96 | drug_id = pair_id[0] 97 | target_id = pair_id[1] 98 | score = item['score'] 99 | top_map[(drug_id, target_id)] = score 100 | 101 | data = collect_data(config, top_map) 102 | 103 | if not os.path.isdir(config.RESULT_DIR): 
104 | os.makedirs(config.RESULT_DIR) 105 | 106 | json.dump(data, 107 | open(os.path.join(config.RESULT_DIR, '{}_top_{}_inhibitor.json'.format(config.DATA_SET, top_k)), 'w')) 108 | 109 | 110 | def visualize(config, model_name, case_name): 111 | if config.BAG_MODE: 112 | REModel = REModel_BAG 113 | 114 | 115 | else: 116 | REModel = REModel_INS 117 | 118 | vocab = torch.load(os.path.join(config.ROOT_DIR, 'vocab.pt')) 119 | 120 | logging.info('Load pretrained vectors: {}*{}'.format(vocab.word_num, vocab.word_dim)) 121 | logging.info('Number of classes: {}'.format(vocab.class_num)) 122 | 123 | case_data = json.load(open(os.path.join(config.ROOT_DIR, case_name), 'r')) 124 | case_dataset = REDataset_INS(vocab, data_dir=config.ROOT_DIR, data_name=case_name, max_length=config.MAX_LENGTH, 125 | sort=False) 126 | case_loader = DataLoader_INS(case_dataset, batch_size=1, collate_fn=case_dataset.collate, shuffle=False) 127 | 128 | logging.info('Number of total pair: {}'.format(len(case_dataset))) 129 | 130 | model = REModel(vocab=vocab, tag_dim=config.TAG_DIM, 131 | max_length=config.MAX_LENGTH, 132 | hidden_dim=config.HIDDEN_DIM, dropout_prob=config.DROP_PROB, 133 | bidirectional=config.BIDIRECTIONAL) 134 | 135 | model.load_state_dict( 136 | torch.load(os.path.join(config.SAVE_DIR, config.DATA_SET, model_name), map_location='cpu')) 137 | print(model.attn.gamma) 138 | 139 | model.eval() 140 | model.to(DEVICE) 141 | model.display() 142 | 143 | torch.set_grad_enabled(False) 144 | 145 | def run_iter(batch): 146 | sent = batch['sent'].to(DEVICE) 147 | tag = batch['tag'].to(DEVICE) 148 | 149 | length = batch['length'].to(DEVICE) 150 | scope = batch['scope'] 151 | 152 | logit, word_attn, tree_order, sent_attn = model(sent, tag, length, scope, verbose_output=True) 153 | 154 | return logit.cpu(), word_attn.cpu(), tree_order, sent_attn 155 | 156 | def plot_attn(word_attn, sent): 157 | plt.matshow(word_attn) 158 | plt.colorbar() 159 | x, y = word_attn.shape 160 | x = np.array(range(x)) 
161 | y = np.array(range(y)) 162 | 163 | plt.xticks(x, sent, rotation=90, fontsize=12) 164 | plt.yticks(y, sent, fontsize=12) 165 | 166 | plt.tight_layout() 167 | plt.show() 168 | 169 | # plot_path = os.path.join(config.RESULT_DIR, "word_attn.pdf") 170 | # plt.savefig(plot_path) 171 | # print('Attention map plot saved at: {}'.format(plot_path)) 172 | 173 | def make_new_sent(sent, name1, name2, pos1, pos2): 174 | assert pos1 <= pos2 175 | new_sent = '{0} {1} {2} {3} {4}'.format(sent[:pos1[0]], name1, sent[pos1[1]:pos2[0]], name2, 176 | sent[pos2[1]:]) 177 | return new_sent 178 | 179 | def get_sentence(item): 180 | sent = item['sentence'] 181 | head_word = item['head']['word'] 182 | tail_word = item['tail']['word'] 183 | 184 | head_pos = sent.index(head_word) 185 | head_pos = [head_pos, head_pos + len(head_word)] 186 | tail_pos = sent.index(tail_word) 187 | tail_pos = [tail_pos, tail_pos + len(tail_word)] 188 | 189 | if head_pos <= tail_pos: 190 | sent = make_new_sent(sent, '', '', head_pos, tail_pos) 191 | else: 192 | sent = make_new_sent(sent, '', '', tail_pos, head_pos) 193 | 194 | sent = re.sub(r"\s+", r" ", sent).strip().split() 195 | 196 | head_idx = sent.index('') 197 | tail_idx = sent.index('') 198 | sent[head_idx] = head_word 199 | sent[tail_idx] = tail_word 200 | return sent 201 | 202 | def get_parse_tree(sent, tree_order): 203 | comp_order = [] 204 | for order in tree_order: 205 | order = order.cpu() 206 | 207 | index = torch.nonzero(order)[0][1] 208 | comp_order.append(index) 209 | comp_word = [] 210 | for order in comp_order: 211 | order = int(order) 212 | comp_word.append((sent[order], sent[order + 1])) 213 | sent[order] = sent[order] + ' ' + sent[order + 1] 214 | sent.pop(order + 1) 215 | comp_word.append((sent[0], sent[1])) 216 | sent[0] = sent[0] + ' ' + sent[1] 217 | sent.pop(1) 218 | return comp_word 219 | 220 | index = 0 221 | for case in case_loader: 222 | item = case_data[index] 223 | print(item) 224 | sent = get_sentence(item) 225 | 226 | 
case_logit, case_word_attn, case_tree_order, case_sent_attn = run_iter(batch=case) 227 | case_word_attn = np.mean(np.array(case_word_attn), axis=0) 228 | 229 | print(case_logit) 230 | plot_attn(case_word_attn, sent) 231 | 232 | parse_tree = get_parse_tree(sent, case_tree_order) 233 | print(sent) 234 | print(parse_tree) 235 | print(case_sent_attn) 236 | 237 | index += 1 238 | 239 | 240 | if __name__ == '__main__': 241 | from data.dti import config 242 | 243 | DEVICE = 'cuda:0' 244 | 245 | visualize(config, 'dti-0.5419.pkl', case_name='tree_example.json') 246 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | import re 6 | from itertools import chain 7 | 8 | import gensim 9 | import numpy as np 10 | import torch 11 | from nltk import pos_tag 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | 15 | logging.basicConfig(level=logging.INFO, 16 | format='%(asctime)s %(levelname)-8s %(message)s') 17 | 18 | 19 | class Vocab(object): 20 | def __init__(self, label_path, emb_path): 21 | self.pad_id = 0 22 | self.unk_id = 1 23 | self.ent1_id = 2 24 | self.ent2_id = 3 25 | 26 | self.word2vec = gensim.models.KeyedVectors.load_word2vec_format(emb_path, 27 | binary=True) 28 | self.word_dim = self.word2vec.vector_size 29 | 30 | self.word2id = {'': self.pad_id, '': self.unk_id, '': self.ent1_id, '': self.ent2_id} 31 | self.tag2id = {'': self.pad_id, '': self.unk_id, '': self.ent1_id, '': self.ent2_id} 32 | 33 | self.vectors = [np.zeros(self.word_dim, dtype=np.float32), np.random.normal(size=self.word_dim), 34 | np.random.normal(size=self.word_dim), np.random.normal(size=self.word_dim)] 35 | 36 | self.label2id = json.load(open(label_path, 'r')) 37 | self.NA_id = self.label2id['NA'] 38 | 39 | self.class_num = len(self.label2id) 40 | self.freeze = False 41 | 42 | def 
get_id2label(self): 43 | id2label = {} 44 | for label, id in self.label2id.items(): 45 | id2label[id] = label 46 | return id2label 47 | 48 | def add_tags(self, tags): 49 | for tag in tags: 50 | if tag not in self.tag2id: 51 | self.tag2id[tag] = len(self.tag2id) 52 | 53 | def add_words(self, words): 54 | for word in words: 55 | if word in self.word2vec and word not in self.word2id: 56 | self.word2id[word] = len(self.word2id) 57 | self.vectors.append(self.word2vec[word]) 58 | 59 | def post_process(self): 60 | self.word2vec = None 61 | self.freeze = True 62 | 63 | self.word_num = len(self.word2id) 64 | self.tag_num = len(self.tag2id) 65 | 66 | self.vectors = torch.FloatTensor(self.vectors) 67 | 68 | 69 | class REDataset(Dataset): 70 | def __init__(self, vocab, data_dir, data_name, max_length, sort=True): 71 | self.max_length = max_length 72 | 73 | self.pad_id = vocab.pad_id 74 | 75 | data_path = os.path.join(data_dir, data_name) 76 | data = json.load(open(data_path, 'r')) 77 | 78 | self._data = [] 79 | 80 | logging.info('Process: {}'.format(data_path)) 81 | self.process(vocab, data) 82 | 83 | if sort: 84 | self._data.sort(key=lambda a: np.max(a[4]), reverse=True) 85 | 86 | def process(self, vocab, data): 87 | raise NotImplementedError 88 | 89 | def make_new_sent(self, sent, name1, name2, pos1, pos2): 90 | assert pos1 <= pos2 91 | new_sent = '{0} {1} {2} {3} {4}'.format(sent[:pos1[0]], name1, sent[pos1[1]:pos2[0]], name2, 92 | sent[pos2[1]:]) 93 | return new_sent 94 | 95 | def convert(self, vocab, ins): 96 | head, tail, sent, rel = ins['head'], ins['tail'], ins['sentence'].strip(), ins['relation'] 97 | 98 | head_word = head['word'] 99 | tail_word = tail['word'] 100 | 101 | head_pos = sent.index(head_word) 102 | head_pos = [head_pos, head_pos + len(head_word)] 103 | tail_pos = sent.index(tail_word) 104 | tail_pos = [tail_pos, tail_pos + len(tail_word)] 105 | 106 | if head_pos <= tail_pos: 107 | sent = self.make_new_sent(sent, '', '', head_pos, tail_pos) 108 | else: 109 
| sent = self.make_new_sent(sent, '', '', tail_pos, head_pos) 110 | 111 | sent = re.sub(r"\s+", r" ", sent).strip().split() 112 | tags = [item[1] for item in pos_tag(sent)] 113 | 114 | head_pos = sent.index('') 115 | tail_pos = sent.index('') 116 | 117 | tags[head_pos] = '' 118 | tags[tail_pos] = '' 119 | 120 | head_pos = min(self.max_length - 1, head_pos) 121 | tail_pos = min(self.max_length - 1, tail_pos) 122 | 123 | sent = sent[:self.max_length] 124 | tags = tags[:self.max_length] 125 | length = len(sent) 126 | 127 | if not vocab.freeze: 128 | vocab.add_words(sent) 129 | vocab.add_tags(tags) 130 | 131 | sent = [vocab.word2id.get(w, vocab.unk_id) for w in sent] 132 | tags = [vocab.tag2id.get(w, vocab.unk_id) for w in tags] 133 | 134 | label = vocab.label2id.get(rel, vocab.NA_id) 135 | 136 | pos1 = [i - head_pos + self.max_length for i in range(length)] 137 | pos2 = [i - tail_pos + self.max_length for i in range(length)] 138 | 139 | return sent, tags, pos1, pos2, length, label 140 | 141 | def pad_seq(self, batch_seq, pad_value): 142 | max_length = max(len(seq) for seq in batch_seq) 143 | 144 | padded = [seq + [pad_value] * (max_length - len(seq)) 145 | for seq in batch_seq] 146 | 147 | return padded 148 | 149 | def get_labels(self): 150 | labels = [item[5] for item in self._data] 151 | return labels 152 | 153 | def __getitem__(self, index): 154 | return self._data[index] 155 | 156 | def __len__(self): 157 | return len(self._data) 158 | 159 | def collate(self, batch): 160 | sent_batch, tag_batch, pos1_batch, pos2_batch, length_batch, label_batch, id_batch, size_batch = list( 161 | zip(*batch)) 162 | 163 | sent_batch = list(chain(*sent_batch)) 164 | tag_batch = list(chain(*tag_batch)) 165 | pos1_batch = list(chain(*pos1_batch)) 166 | pos2_batch = list(chain(*pos2_batch)) 167 | length_batch = list(chain(*length_batch)) 168 | 169 | sent_batch = self.pad_seq(sent_batch, self.pad_id) 170 | tag_batch = self.pad_seq(tag_batch, self.pad_id) 171 | pos1_batch = 
self.pad_seq(pos1_batch, self.pad_id) 172 | pos2_batch = self.pad_seq(pos2_batch, self.pad_id) 173 | 174 | sent_batch = torch.LongTensor(sent_batch) 175 | tag_batch = torch.LongTensor(tag_batch) 176 | pos1_batch = torch.LongTensor(pos1_batch) 177 | pos2_batch = torch.LongTensor(pos2_batch) 178 | length_batch = torch.LongTensor(length_batch) 179 | 180 | label_batch = torch.LongTensor(label_batch) 181 | 182 | scope_batch = np.cumsum(size_batch) 183 | 184 | return {'sent': sent_batch, 'tag': tag_batch, 185 | 'pos1': pos1_batch, 'pos2': pos2_batch, 'length': length_batch, 186 | 'label': label_batch, 'id': id_batch, 'scope': scope_batch} 187 | 188 | 189 | class REDataset_INS(REDataset): 190 | def __init__(self, vocab, data_dir, data_name, max_length, sort=True): 191 | super(REDataset_INS, self).__init__(vocab, data_dir, data_name, max_length, sort) 192 | 193 | def process(self, vocab, data): 194 | for ins in tqdm(data): 195 | sent, tag, pos1, pos2, length, label = self.convert(vocab, ins) 196 | ins_id = ins['head']['id'] + '#' + ins['tail']['id'] 197 | self._data.append([[sent], [tag], [pos1], [pos2], [length], label, ins_id, 1]) 198 | 199 | 200 | class REDataset_BAG(REDataset): 201 | def __init__(self, vocab, data_dir, data_name, max_length, sort=True): 202 | super(REDataset_BAG, self).__init__(vocab, data_dir, data_name, max_length, sort) 203 | 204 | def process(self, vocab, data): 205 | data.sort(key=lambda a: a['head']['id'] + '#' + a['tail']['id']) 206 | last_ins_id = 'None#None' 207 | for ins in tqdm(data): 208 | sent, tag, pos1, pos2, length, label = self.convert(vocab, ins) 209 | ins_id = ins['head']['id'] + '#' + ins['tail']['id'] 210 | if ins_id != last_ins_id: 211 | self._data.append([[sent], [tag], [pos1], [pos2], [length], label, ins_id, 1]) 212 | else: 213 | self._data[-1][0].append(sent) 214 | self._data[-1][1].append(tag) 215 | self._data[-1][2].append(pos1) 216 | self._data[-1][3].append(pos2) 217 | self._data[-1][4].append(length) 218 | 
self._data[-1][7] += 1 219 | 220 | last_ins_id = ins_id 221 | 222 | 223 | class DataLoader_BAG(object): 224 | def __init__(self, dataset, batch_size, collate_fn, shuffle=False): 225 | self.dataset = dataset 226 | self.collate_fn = collate_fn 227 | self.batch_size = batch_size 228 | 229 | self.order = list(range(0, len(dataset))) 230 | 231 | self.shuffle = shuffle 232 | 233 | def __iter__(self): 234 | if self.shuffle: 235 | random.shuffle(self.order) 236 | i = 0 237 | while i < len(self.dataset): 238 | j = 0 239 | batch = [] 240 | while j < self.batch_size and i < len(self.dataset): 241 | batch.append(self.dataset[self.order[i]]) 242 | 243 | j += self.dataset[self.order[i]][7] 244 | i += 1 245 | 246 | yield self.collate_fn(batch) 247 | 248 | def __len__(self): 249 | return (len(self.dataset) + self.batch_size - 1) // self.batch_size 250 | 251 | 252 | class DataLoader_INS(object): 253 | def __init__(self, dataset, batch_size, collate_fn, shuffle=False): 254 | self.dataset = dataset 255 | self.collate_fn = collate_fn 256 | self.batch_size = batch_size 257 | 258 | self.order = list(range(0, len(dataset))) 259 | self.order = [self.order[i:i + batch_size] for i in range(0, len(dataset), batch_size)] 260 | 261 | self.shuffle = shuffle 262 | 263 | def __iter__(self): 264 | if self.shuffle: 265 | random.shuffle(self.order) 266 | 267 | for indices in self.order: 268 | batch = self.collate_fn([self.dataset[i] for i in indices]) 269 | 270 | yield batch 271 | 272 | def __len__(self): 273 | return (len(self.dataset) + self.batch_size - 1) // self.batch_size 274 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------