├── dnee
│   ├── __init__.py
│   ├── evals
│   │   ├── __init__.py
│   │   ├── confusion_matrix.py
│   │   └── intrinsic.py
│   ├── events
│   │   ├── __init__.py
│   │   ├── predicate_gr.py
│   │   ├── animacy.py
│   │   ├── indices.py
│   │   └── events.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── sampler.py
│   │   ├── skipthoughts.py
│   │   ├── event_trans.py
│   │   └── predicate_gr.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── predicate_gr.py
│   │   ├── event_chains.py
│   │   └── event_relation.py
│   ├── discourse_annotator.py
│   └── utils.py
├── relation_9disc.json
├── train_config_predicategr.json
├── relation_pdtb.json
├── train_config_ds_transe.json
├── train_config_transe_v0.2.10_long9.json
├── train_config_transr_v0.2.10_long9.json
├── requirements.txt
├── LICENSE
├── bin
│   ├── evaluations
│   │   ├── eval_mcnc_predicategr.py
│   │   ├── eval_mcnc.py
│   │   ├── eval_disc_elmo.py
│   │   ├── test_combined_features.py
│   │   ├── train_combined_features.py
│   │   ├── eval_mcns.py
│   │   ├── eval_disc.py
│   │   └── eval_disc_binary.py
│   └── train.py
├── requirements3.txt
└── README.md
/dnee/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dnee/evals/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dnee/events/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .events import *
3 | 
--------------------------------------------------------------------------------
/dnee/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .event_trans import *
3 | from .sampler import *
4 | 
--------------------------------------------------------------------------------
/dnee/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .event_relation import EventRelationDataset, EventRelationConcatDataset
3 | from .predicate_gr import EventCompDataset
4 | 
--------------------------------------------------------------------------------
/relation_9disc.json:
--------------------------------------------------------------------------------
1 | {
2 |     "rel2idx": {
3 |         "Comparison.Contrast": 0,
4 |         "Contingency.Cause.Reason": 1,
5 |         "Contingency.Cause.Result": 2,
6 |         "Contingency.Condition": 3,
7 |         "Expansion.Restatement": 4,
8 |         "Expansion.Conjunction": 5,
9 |         "Expansion.Instantiation": 6,
10 |         "Temporal.Synchrony": 7,
11 |         "Temporal.Asynchronous": 8,
12 |         "Context": 9,
13 |         "Coref": 10
14 |     },
15 |     "disc_begin": 0,
16 |     "disc_end": 9
17 | }
18 | 
--------------------------------------------------------------------------------
/train_config_predicategr.json:
--------------------------------------------------------------------------------
1 | {
2 |     "Word2Vec": {
3 |         "embedding_file": "/homes/lee2226/scratch3/downloaded_embeddings/GoogleNews-vectors-negative300.txt",
4 |         "emb_file_skip_first_line": 1
5 |     },
6 |     "Word2VecEvent": {
7 |         "embedding_file": "data/out_word2vec_events.txt"
8 |     },
9 |     "batch_size": 500,
10 |     "n_dataloader_workers": 1,
11 |     "skipthought_dir": "data/skipthought_models",
12 |     "event_dim": 500,
13 |     "margin": 1,
14 |     "n_epochs": 2,
15 |     "optimizer": "adagrad",
16 |     "n_batches_per_record": 100
17 | }
18 | 
--------------------------------------------------------------------------------
/relation_pdtb.json:
--------------------------------------------------------------------------------
1 | {
2 |     "rel2idx": 
{"Temporal.Asynchronous.Precedence": 0, 3 | "Temporal.Asynchronous.Succession": 1, 4 | "Temporal.Synchrony": 2, 5 | "Contingency.Cause.Reason": 3, 6 | "Contingency.Cause.Result": 4, 7 | "Contingency.Condition": 5, 8 | "Comparison.Contrast": 6, 9 | "Comparison.Concession": 7, 10 | "Expansion.Conjunction": 8, 11 | "Expansion.Instantiation": 9, 12 | "Expansion.Restatement": 10, 13 | "Expansion.Alternative": 11, 14 | "Expansion.Alternative.Chosen alternative": 12, 15 | "Expansion.Exception": 13, 16 | "EntRel": 14 17 | }, 18 | "disc_begin": 0, 19 | "disc_end": 15 20 | } 21 | -------------------------------------------------------------------------------- /train_config_ds_transe.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_data": "data/nyt_indexed_triplets_train_v0.2.0", 3 | "predicate_indices": "data/pred_index_v0.2.0_filtered.txt", 4 | "argw_indices": "data/argw_index_v0.2.0_filtered.txt", 5 | "batch_size": 1000, 6 | "n_dataloader_workers": 2, 7 | "skipthought_dir": "data/skipthought_models", 8 | "n_old_rel_types": 11, 9 | "n_rel_types": 15, 10 | "pred_dim": 500, 11 | "event_hidden_dim": 1000, 12 | "event_dim": 500, 13 | "rel_dim": 500, 14 | "arg0_max_len": 15, 15 | "arg1_max_len": 15, 16 | "arg2_max_len": 15, 17 | "margin": 1, 18 | "n_epochs": 2, 19 | "optimizer": "adagrad", 20 | "n_batches_per_record": 100, 21 | "argw_encoder_opt": "customized_biskip", 22 | "model_type": "EventTransE", 23 | "norm": 1 24 | } 25 | -------------------------------------------------------------------------------- /train_config_transe_v0.2.10_long9.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_data": "data/training_data_v0.2.10_long9/nyt_indexed_triplets_train_v0.2.10_long9", 3 | "predicate_indices": "data/pred_index_filtered_v0.2.10_long9.txt", 4 | "argw_indices": "data/argw_index_filtered_v0.2.10_long9.txt", 5 | "batch_size": 1000, 6 | "n_dataloader_workers": 2, 7 | "skipthought_dir": "data/skipthought_models", 8 | "n_rel_types": 11, 9 | "pred_dim": 500, 10 | "event_hidden_dim": 1000, 11 | "event_dim": 500, 12 | "rel_dim": 500, 13 | "arg0_max_len": 15, 14 | "arg1_max_len": 15, 15 | "arg2_max_len": 15, 16 | "margin": 1, 17 | "n_epochs": 3, 18 | "optimizer": "adagrad", 19 | "n_batches_per_record": 100, 20 | "argw_encoder_opt": "customized_biskip", 21 | "model_type": "EventTransE", 22 | "norm": 1 23 | } 24 | -------------------------------------------------------------------------------- /train_config_transr_v0.2.10_long9.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_data": "data/training_data_v0.2.10_long9/nyt_indexed_triplets_train_v0.2.10_long9", 3 | "predicate_indices": "data/pred_index_filtered_v0.2.10_long9.txt", 4 | "argw_indices": "data/argw_index_filtered_v0.2.10_long9.txt", 5 | "batch_size": 1000, 6 | "n_dataloader_workers": 2, 7 | "skipthought_dir": "data/skipthought_models", 8 | "n_rel_types": 11, 9 | "pred_dim": 500, 10 | "event_hidden_dim": 1000, 11 | "event_dim": 500, 12 | "rel_dim": 500, 13 | "arg0_max_len": 15, 14 | "arg1_max_len": 15, 15 | "arg2_max_len": 15, 16 | "margin": 1, 17 | "n_epochs": 3, 18 | "optimizer": "adagrad", 19 | "n_batches_per_record": 100, 20 | "argw_encoder_opt": "customized_biskip", 21 | "model_type": "EventTransR", 22 | "norm": 1 23 | } 24 | -------------------------------------------------------------------------------- /dnee/datasets/predicate_gr.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import linecache 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | 11 | class EventCompDataset(Dataset): 12 | def __init__(self, fpath, use_torch=True): 13 | super(EventCompDataset, self).__init__() 14 | self.logger = logging.getLogger(self.__class__.__name__) 15 | self.fpath = fpath 16 | self.use_torch = use_torch 17 | with open(self.fpath, 'r') as fr: 18 | self._len = len(fr.readlines()) - 1 19 | 20 | def __getitem__(self, idx): 21 | line = linecache.getline(self.fpath, idx+1) 22 | x = np.array([int(n) for n in line.split(' ')], dtype=np.int64) 23 | if self.use_torch: 24 | x = torch.from_numpy(x) 25 | # e1 [0:4], e2 [4:8] ne [8:12] 26 | return x 27 | 28 | def __len__(self): 29 | return self._len 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | antlr4-python2-runtime==4.7.2 2 | backports.functools-lru-cache==1.5 3 | boto==2.49.0 4 | boto3==1.9.96 5 | botocore==1.12.96 6 | bratreader==1.0.1 7 | bz2file==0.98 8 | certifi==2018.11.29 9 | chardet==3.0.4 10 | cycler==0.10.0 11 | docutils==0.14 12 | futures==3.2.0 13 | gensim==3.7.1 14 | h5py==2.9.0 15 | idna==2.7 16 | jmespath==0.9.3 17 | Keras==2.2.4 18 | Keras-Applications==1.0.6 19 | Keras-Preprocessing==1.0.5 20 | kiwisolver==1.0.1 21 | Lasagne==0.2.dev1 22 | lxml==4.3.2 23 | matplotlib==2.2.3 24 | nltk==3.4 25 | numpy==1.15.4 26 | parse==1.9.0 27 | progressbar==2.5 28 | protobuf==3.7.1 29 | psutil==5.4.8 30 | pyparsing==2.3.0 31 | python-dateutil==2.7.5 32 | pytz==2018.7 33 | PyYAML==3.13 34 | requests==2.20.1 35 | s3transfer==0.2.0 36 | scikit-learn==0.20.2 37 | scipy==1.1.0 38 | simplejson==3.16.0 39 | singledispatch==3.4.0.3 40 | six==1.11.0 41 | sklearn==0.0 42 | smart-open==1.8.0 43 | stanfordcorenlp==3.9.1.1 44 | stanfordnlp==0.1.2 45 | stix2==1.1.2 46 | stix2-patterns==1.1.0 47 | subprocess32==3.5.3 48 | tabulate==0.7.7 49 | torch==0.4.1 50 | tqdm==4.31.1 51 | urllib3==1.24.1 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 I-Ta Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/dnee/datasets/event_chains.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import linecache
5 | 
6 | from torch.utils.data import Dataset
7 | import torch
8 | 
9 | 
10 | class EventChainDataset(Dataset):
11 |     def __init__(self, fpath, n_args=3, n_pos=9, n_neg=4):
12 |         super(EventChainDataset, self).__init__()
13 |         self.logger = logging.getLogger(self.__class__.__name__)
14 |         self.logger.info("loading dataset: {}".format(fpath))
15 |         self.logger.info('n_args={}, n_pos={}, n_neg={}'.format(n_args, n_pos, n_neg))
16 |         self.n_args = n_args
17 |         self.n_pos, self.n_neg = n_pos, n_neg
18 |         self.fpath = fpath
19 |         self._len = 0
20 |         with open(self.fpath, 'r') as fr:
21 |             self._len = len(fr.readlines()) - 1
22 | 
23 |     def __getitem__(self, idx):
24 |         line = linecache.getline(self.fpath, idx+1)
25 |         indexed_events = json.loads(line)
26 | 
27 |         x = torch.zeros((1+self.n_args)*(self.n_pos+self.n_neg), dtype=torch.int64)
28 |         for i, ie in enumerate(indexed_events):
29 |             start = i * (1+self.n_args)
30 |             end = (i+1) * (1+self.n_args)
31 |             # each indexed event is a flat list: predicate index + n_args argument indices
32 |             x[start:end] = torch.LongTensor(ie)
33 |         return x
34 | 
35 |     def __len__(self):
36 |         return self._len
37 | 
--------------------------------------------------------------------------------
/dnee/models/sampler.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | 
4 | 
5 | class NegativeSampler:
6 |     def __init__(self, maxlens, n_rels, sample_rel=False):
7 |         self.arg0_max_len, self.arg1_max_len, self.arg2_max_len = maxlens
8 |         self.sample_rel = sample_rel
9 |         self.n_rels = n_rels
10 | 
11 |     def sampling(self, x):
12 |         # x has to be on CPU, since PyTorch 0.4.0 doesn't support randperm on cuda.
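        # Layout of each row of x (as built by EventRelationDataset.raw2x):
        #   x[:, 0]                        relation type index
        #   x[:, 1 : 2+arg0_len+arg1_len]  event 1 (predicate, arg0, arg1 indices)
        #   x[:, 2+arg0_len+arg1_len : ]   event 2 (predicate, arg0, arg1 indices)
        # Rows are split into two chunks (three when sample_rel is set) and each
        # chunk gets exactly one component corrupted: e1, e2, or the relation type.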
13 | 14 | # truncate events 15 | n_samples = x.shape[0] 16 | x_neg = x.clone() 17 | 18 | idxs = torch.randperm(n_samples) 19 | chunks = torch.chunk(idxs, 3) if self.sample_rel else torch.chunk(idxs, 2) 20 | 21 | # corrupt e1 22 | neg_e = x[:, 1:1+1+self.arg0_max_len+self.arg1_max_len] 23 | neg_e = neg_e[torch.randperm(x.shape[0])] 24 | x_neg[chunks[0], 1:1+1+self.arg0_max_len+self.arg1_max_len] = \ 25 | neg_e[chunks[0]] 26 | 27 | # corrupt e2 28 | neg_e = x[:, 1+1+self.arg0_max_len+self.arg1_max_len:] 29 | neg_e = neg_e[torch.randperm(x.shape[0])] 30 | x_neg[chunks[1], 1+1+self.arg0_max_len+self.arg1_max_len:] = \ 31 | neg_e[chunks[1]] 32 | 33 | if self.sample_rel: 34 | # corrupt relations 35 | neg_rels = torch.randint_like(chunks[2], 0, self.n_rels) 36 | while True: 37 | idxs = (neg_rels == x[chunks[2], 0]).nonzero().view(-1) 38 | if len(idxs) == 0: 39 | break 40 | for i in idxs: 41 | neg_rels[i] = torch.randint(0, self.n_rels, (1,)) 42 | x_neg[chunks[2], 0] = neg_rels 43 | return x_neg 44 | -------------------------------------------------------------------------------- /dnee/datasets/event_relation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import linecache 5 | 6 | from torch.utils.data import ConcatDataset, Dataset 7 | import torch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class EventRelationDataset(Dataset): 13 | def __init__(self, fpath, arg_lens): 14 | self.arg_lens = arg_lens 15 | logger.info("loading dataset: {}".format(fpath)) 16 | logger.info('arg_lens={}'.format(arg_lens)) 17 | self.fpath = fpath 18 | self._len = 0 19 | with open(self.fpath, 'r') as fr: 20 | self._len = len(fr.readlines()) - 1 21 | 22 | def __getitem__(self, idx): 23 | line = linecache.getline(self.fpath, idx+1) 24 | rel = json.loads(line) 25 | 26 | y = torch.zeros(1, dtype=torch.int64) 27 | y[0] = rel[0] if rel[0] == 1 else -1 # positive or negative 28 | 29 | x = self.raw2x(rel, self.arg_lens) 30 | return x, y 31 | 32 | def __len__(self): 33 | return self._len 34 | 35 | @staticmethod 36 | def raw2x(raw, arg_lens): 37 | e_len = 1 + arg_lens[0] + arg_lens[1] 38 | x = torch.zeros(2 * e_len + 1, dtype=torch.int64) 39 | x[0] = raw[1] # rtype 40 | e1_begin = 1 41 | x[e1_begin] = raw[2] # e1 predicate 42 | x[e1_begin+1: e1_begin+1+len(raw[3])] = torch.LongTensor(raw[3]) # e1 arg0 43 | x[e1_begin+1+arg_lens[0]: e1_begin+1+arg_lens[0]+len(raw[4])] = torch.LongTensor(raw[4]) # e1 arg1 44 | # rel[5] is e1 arg2 45 | e2_begin = e1_begin + 1 + arg_lens[0] + arg_lens[1] 46 | x[e2_begin] = raw[6] # e2 predicate 47 | x[e2_begin+1: e2_begin+1+len(raw[7])] = torch.LongTensor(raw[7]) # e2 arg0 48 | x[e2_begin+1+arg_lens[0]: e2_begin+1+arg_lens[0]+len(raw[8])] = torch.LongTensor(raw[8]) # e2 arg1 49 | # rel[9] is e2 arg2 50 | return x 51 | 52 | 53 | class EventRelationConcatDataset(ConcatDataset): 54 | def __init__(self, fld, arg_lens): 55 | self.fpaths = [os.path.join(fld, f) for f in os.listdir(fld)] 56 | datasets = [EventRelationDataset(fp, arg_lens) for fp in self.fpaths] 57 | super(EventRelationConcatDataset, self).__init__(datasets) 58 | -------------------------------------------------------------------------------- /bin/evaluations/eval_mcnc_predicategr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | import cPickle as pkl 9 | 10 | from scipy import sparse 11 | 
import numpy as np
12 | import torch
13 | import progressbar
14 | 
15 | from dnee import utils
16 | from dnee.evals import intrinsic
17 | from dnee.events import indices
18 | from dnee.models.predicate_gr import *
19 | 
20 | 
21 | # for reproducible results
22 | np.random.seed(123)
23 | 
24 | 
25 | def get_arguments(argv):
26 |     parser = argparse.ArgumentParser(description='MCNC evaluation on PredicateGR models')
27 |     parser.add_argument('training_config', metavar='TRAINING_CONFIG',
28 |                         help='config for training')
29 |     parser.add_argument('question_file', metavar='QUESTION_FILE',
30 |                         help='questions.')
31 |     parser.add_argument('model_class', metavar='MODEL_CLASS', choices=['Word2Vec', 'Word2VecEvent'],
32 |                         help='model class name.')
33 | 
34 |     parser.add_argument('-g', '--gpu_id', type=int, default=None,
35 |                         help='gpu id')
36 | 
37 |     parser.add_argument('-v', '--verbose', action='store_true', default=False,
38 |                         help='show info messages')
39 |     parser.add_argument('-d', '--debug', action='store_true', default=False,
40 |                         help='show debug messages')
41 |     args = parser.parse_args(argv)
42 |     return args
43 | 
44 | 
45 | def predict_mcnc(model, q):
46 |     return model.predict_mcnc(q.get_contexts(), q.choices)
47 | 
48 | 
49 | def main():
50 |     config = json.load(open(args.training_config, 'r'))
51 |     questions = pkl.load(open(args.question_file, 'r'))
52 |     logging.info("#questions={}".format(len(questions)))
53 | 
54 |     n_correct, n_incorrect = 0, 0
55 | 
56 |     widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
57 |     bar = progressbar.ProgressBar(widgets=widgets, maxval=len(questions)).start()
58 | 
59 |     ModelClass = eval(args.model_class)
60 |     model = ModelClass(config[args.model_class])
61 |     if 'model_folder' in config[args.model_class]:
62 |         model.load(config[args.model_class]['model_folder'])
63 | 
64 |     # logging.info("batch_size = {}".format(args.batch_size))
65 |     # batch_size = args.batch_size
66 |     # n_batches = len(questions) // batch_size + 1
67 |     # logging.info("n_batches = {}".format(n_batches))
68 |     i_q = 0
69 |     for q in questions:
70 |         pred = predict_mcnc(model, q)
71 |         if pred == q.ans_idx:
72 |             n_correct += 1
73 |         else:
74 |             n_incorrect += 1
75 |         i_q += 1
76 |         bar.update(i_q)
77 |         logging.debug("pred={}, ans={}".format(pred, q.ans_idx))
78 |     bar.finish()
79 |     print("n_correct={}, n_incorrect={}".format(n_correct, n_incorrect))
80 |     print("accuracy={}".format(float(n_correct)/(n_correct+n_incorrect)))
81 | 
82 | 
83 | if __name__ == "__main__":
84 |     args = utils.bin_config(get_arguments)
85 |     if torch.cuda.is_available():
86 |         args.device = torch.device('cuda') if args.gpu_id is None \
87 |             else torch.device('cuda:{}'.format(args.gpu_id))
88 |     else:
89 |         args.device = torch.device('cpu')
90 |     main()
91 | 
--------------------------------------------------------------------------------
/requirements3.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.7.0
2 | ad3==2.2.1
3 | alabaster==0.7.12
4 | allennlp==0.8.1
5 | asn1crypto==0.24.0
6 | astor==0.7.1
7 | atomicwrites==1.3.0
8 | attrs==18.2.0
9 | aws-xray-sdk==0.95
10 | awscli==1.16.106
11 | Babel==2.6.0
12 | beautifulsoup4==4.7.1
13 | boto==2.49.0
14 | boto3==1.9.96
15 | botocore==1.12.96
16 | cchardet==2.1.4
17 | certifi==2019.3.9
18 | cffi==1.11.5
19 | chardet==3.0.4
20 | Click==7.0
21 | colorama==0.3.9
22 | conllu==0.11
23 | cookies==2.2.1
24 | cryptography==2.5
25 | cvxopt==1.2.3
26 | cycler==0.10.0
27 | cymem==2.0.2
28 | Cython==0.26.1
29 | cytoolz==0.9.0.1
30 | decorator==4.1.2
31 | dill==0.2.9 32 | distro-info===0.18ubuntu0.18.04.1 33 | docker==3.7.0 34 | docker-pycreds==0.4.0 35 | docutils==0.14 36 | ecdsa==0.13 37 | editdistance==0.5.2 38 | eml-parser==1.11.1 39 | en-core-web-sm==2.0.0 40 | fail2ban==0.10.2 41 | flaky==3.5.3 42 | Flask==1.0.2 43 | Flask-Cors==3.0.7 44 | ftfy==5.5.1 45 | future==0.17.1 46 | gast==0.2.2 47 | gevent==1.3.6 48 | greenlet==0.4.15 49 | grpcio==1.18.0 50 | h5py==2.9.0 51 | html2text==2018.1.9 52 | httplib2==0.9.2 53 | idna==2.8 54 | imagesize==1.1.0 55 | ipaddress==1.0.22 56 | itsdangerous==1.1.0 57 | Jinja2==2.10 58 | jmespath==0.9.3 59 | joblib==0.13.2 60 | jsondiff==1.1.1 61 | jsonnet==0.10.0 62 | jsonpickle==1.1 63 | Keras==2.2.4 64 | Keras-Applications==1.0.7 65 | Keras-Preprocessing==1.0.9 66 | kiwisolver==1.0.1 67 | lxml==4.3.3 68 | mail-parser==3.9.3 69 | Markdown==3.0.1 70 | MarkupSafe==1.1.0 71 | matplotlib==2.2.3 72 | mock==2.0.0 73 | more-itertools==6.0.0 74 | moto==1.3.4 75 | msgpack==0.5.6 76 | msgpack-numpy==0.4.3.2 77 | murmurhash==1.0.2 78 | networkx==1.11 79 | nltk==3.4 80 | nose==1.3.7 81 | numpy==1.16.2 82 | numpydoc==0.8.0 83 | overrides==1.9 84 | packaging==19.0 85 | pandas==0.22.0 86 | parse==1.11.1 87 | parsimonious==0.8.0 88 | pasaffe==0.51 89 | pbr==5.1.2 90 | Pillow==5.1.0 91 | plac==0.9.6 92 | pluggy==0.8.1 93 | praw==5.3.0 94 | prawcore==0.13.0 95 | preshed==2.0.1 96 | progressbar==2.5 97 | protobuf==3.7.1 98 | psutil==5.5.1 99 | py==1.7.0 100 | pyaml==18.11.0 101 | pyasn1==0.4.5 102 | pycparser==2.19 103 | pycryptodome==3.7.3 104 | pycurl==7.43.0.1 105 | Pygments==2.3.1 106 | pygobject==3.26.1 107 | pyparsing==2.3.1 108 | pystruct==0.3.2 109 | pytest==4.2.1 110 | python-apt==1.6.3+ubuntu1 111 | python-dateutil==2.8.0 112 | python-debian==0.1.32 113 | python-jose==2.0.2 114 | pytorch-pretrained-bert==0.3.0 115 | pytz==2017.3 116 | PyYAML==3.13 117 | regex==2018.1.10 118 | requests==2.21.0 119 | requests-unixsocket==0.1.5 120 | responses==0.10.5 121 | rsa==3.4.2 122 | s3transfer==0.2.0 123 | sacremoses==0.0.19 124 | scikit-learn==0.20.2 125 | scipy==1.2.1 126 | scour==0.36 127 | simplejson==3.16.0 128 | singledispatch==3.4.0.3 129 | six==1.12.0 130 | sklearn==0.0 131 | snowballstemmer==1.2.1 132 | soupsieve==1.9 133 | spacy==2.0.18 134 | Sphinx==1.8.4 135 | sphinxcontrib-websupport==1.1.0 136 | sqlparse==0.2.4 137 | stanfordcorenlp==3.9.1.1 138 | stanfordnlp==0.1.2 139 | systemd-python==234 140 | tabulate==0.7.7 141 | tensorboard==1.12.2 142 | tensorboardX==1.2 143 | tensorflow==1.12.0 144 | termcolor==1.1.0 145 | thinc==6.12.1 146 | toolz==0.9.0 147 | torch==1.0.1.post2 148 | tqdm==4.31.1 149 | typing==3.6.6 150 | ujson==1.35 151 | Unidecode==1.0.23 152 | urllib3==1.24.2 153 | wcwidth==0.1.7 154 | websocket-client==0.54.0 155 | Werkzeug==0.14.1 156 | wrapt==1.11.1 157 | xmltodict==0.12.0 158 | -------------------------------------------------------------------------------- /dnee/events/predicate_gr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Events in Predicate-GRs form, for competitor models like 3 | SGNN 4 | EventComp 5 | PMI 6 | SkipGram 7 | """ 8 | import sys 9 | import parse 10 | from ..events import indices 11 | 12 | 13 | event_parser = parse.compile("{pred}({dep},{arg0},{arg1},{arg2})") 14 | 15 | 16 | 17 | def load_we_index(fpath): 18 | idx = 1 # save 0 for zero vector 19 | e2idx = {} 20 | with open(fpath) as fr: 21 | for line in fr: 22 | sp = line.split(' ') 23 | e2idx[sp[0]] = idx 24 | idx += 1 25 | return e2idx 26 | 27 | 28 | class 
Event(object):
29 |     def __init__(self, predicate, dep, arg0, arg1, arg2):
30 |         if sys.version_info >= (3, 0):
31 |             self.pred = predicate.lower()
32 |             self.a0 = arg0.lower()
33 |             self.a1 = arg1.lower()
34 |             self.a2 = arg2.lower()
35 |         else:
36 |             self.pred = predicate.encode('ascii', 'ignore').lower()
37 |             self.a0 = arg0.encode('ascii', 'ignore').lower()
38 |             self.a1 = arg1.encode('ascii', 'ignore').lower()
39 |             self.a2 = arg2.encode('ascii', 'ignore').lower()
40 |         self.dep = dep
41 | 
42 |     def __repr__(self):
43 |         return "{}({},{},{},{})".format(self.pred, self.dep, self.a0, self.a1, self.a2)
44 | 
45 |     @classmethod
46 |     def from_string(cls, line):
47 |         global event_parser
48 |         res = event_parser.parse(line)
49 |         if res is None:
50 |             raise ValueError("unparsable event string: {}".format(line))
51 |         return cls(res['pred'], res['dep'], res['arg0'], res['arg1'], res['arg2'])
52 | 
53 |     @classmethod
54 |     def from_json(cls, e):
55 |         pred = e['predicate']
56 |         dep = e['dep']
57 |         arg0_head = e['arg0'][0] if 'arg0' in e else indices.NO_ARG
58 |         arg1_head = e['arg1'][0] if 'arg1' in e else indices.NO_ARG
59 |         arg2_head = e['arg2'][0] if 'arg2' in e else indices.NO_ARG
60 |         obj = cls(pred, dep, arg0_head, arg1_head, arg2_head)
61 |         return obj
62 | 
63 |     def index(self, pred2idx, argw2idx):
64 |         if self.pred not in pred2idx:
65 |             return None
66 |         idxs = [pred2idx[self.pred]]
67 |         for a in [self.a0, self.a1, self.a2]:
68 |             if a in argw2idx:
69 |                 idxs.append(argw2idx[a])
70 |             else:
71 |                 idxs.append(argw2idx[indices.UNKNOWN_ARG_WORD])
72 |         # idxs layout: [pred_idx, arg0_idx, arg1_idx, arg2_idx]
73 |         return idxs
74 | 
75 |     @staticmethod
76 |     def cj08_format(pred, dep):
77 |         return '({},{})'.format(pred, dep)
78 | 
79 |     @staticmethod
80 |     def predicategr_format(pred, dep, a0, a1, a2, protagonist_str='_PROTAGONIST_'):
81 |         if dep == 'nsubj':
82 |             estr = '{}({},{},{})'.format(pred, protagonist_str, a1, a2)
83 |         elif dep == 'dobj' or dep == 'nsubjpass':
84 |             estr = '{}({},{},{})'.format(pred, a0, protagonist_str, a2)
85 |         else:
86 |             estr = '{}({},{},{})'.format(pred, a0, a1, protagonist_str)
87 |         return estr
88 | 
89 |     def to_cj08_format(self):
90 |         return self.cj08_format(self.pred, self.dep)
91 | 
92 |     def to_predicategr_format(self):
93 |         return self.predicategr_format(self.pred, self.dep, self.a0, self.a1, self.a2)
94 | 
95 | 
96 | class EventChain(object):
97 |     def __init__(self, events):
98 |         for e in events:
99 |             assert isinstance(e, Event)
100 |         self.events = events
101 | 
102 |     def __repr__(self):
103 |         es = [repr(e) for e in self.events]
104 |         return ' '.join(es)
105 | 
106 |     def gen(self):
107 |         for e in self.events:
108 |             yield e
109 | 
110 |     def __len__(self):
111 |         return len(self.events)
112 | 
113 |     def __getitem__(self, idx):
114 |         return self.events[idx]
115 | 
116 |     @classmethod
117 |     def from_string(cls, line):
118 |         line = line.rstrip('\n')
119 |         sp = line.split(' ')
120 |         events = [Event.from_string(e) for e in sp]
121 |         return cls(events)
122 | 
--------------------------------------------------------------------------------
/dnee/events/animacy.py:
--------------------------------------------------------------------------------
1 | """Module for determining animacy
2 | 
3 | Last Update: July 15th 2016
4 | Author: I-Ta Lee @Purdue
5 | """
6 | import os
7 | import sys
8 | from io import open
9 | 
10 | 
11 | # ToDo: create module configurations
12 | module_base = os.path.dirname(__file__)
13 | animate_words_fpath = os.path.join(module_base, 'animacy_data/animate.unigrams.txt')
14 | inanimate_words_fpath = os.path.join(module_base, 'animacy_data/inanimate.unigrams.txt')
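# The word lists below are loaded lazily: they stay None until the first
# get_animacy() call that needs them (see load_animacy_file further down).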
15 | animate_words = None
16 | inanimate_words = None
17 | 
18 | ANIMATE_STR = 'ANIMATE'
19 | INANIMATE_STR = 'INANIMATE'
20 | UNKNOWN_ANIMACY_STR = 'UNKNOWN'
21 | 
22 | animate_pronouns = ["i", "me", "myself", "mine", "my", "we", "us", "ourself", "ourselves", "ours", "our", "you", "yourself", "yours", "your", "yourselves", "he", "him", "himself", "his", "she", "her", "herself", "hers", "her", "one", "oneself", "one's", "they", "them", "themself", "themselves", "theirs", "their", "they", "them", "'em", "themselves", "who", "whom", "whose"]
23 | inanimate_pronouns = ["it", "itself", "its", "where", "when"]
24 | 
25 | 
26 | def get_animacy_by_index(entities, sent_idx, tok_idx):
27 |     for eid, ent in entities.items():
28 |         if (sent_idx == ent['sentNum'] - 1
29 |                 and tok_idx >= ent['startIndex'] - 1
30 |                 and tok_idx < ent['endIndex'] - 1):
31 |             return ent['animacy']
32 |     return None
33 | 
34 | 
35 | def get_animacy_from_corefs(corefs, sent_idx, tok_idx):
36 |     for cid, cchain in corefs.items():
37 |         for ent in cchain:
38 |             if (sent_idx == ent['sentNum'] - 1
39 |                     and tok_idx >= ent['startIndex'] - 1
40 |                     and tok_idx < ent['endIndex'] - 1):
41 |                 return ent['animacy']
42 |     return None
43 | 
44 | 
45 | def get_animacy_from_parsed(parsed, doc_id, sent_idx, tok_idx):
46 |     corefs = parsed[doc_id]['corefs']
47 |     return get_animacy_from_corefs(corefs, sent_idx, tok_idx)
48 | 
49 | 
50 | def load_animacy_file(fpath):
51 |     """load animacy list file
52 | 
53 |     Args:
54 |         fpath: File path
55 | 
56 |     Returns:
57 |         A list of words.
58 | 
59 |     """
60 |     ret_list = []
61 |     with open(fpath, encoding='utf-8') as fr:
62 |         for line in fr:
63 |             ret_list.append(line.rstrip('\n'))
64 |     return ret_list
65 | 
66 | 
67 | def get_animacy_auto(word, ner, entities, sent_idx, tok_idx):
68 |     tmp_ani = get_animacy_by_index(entities, sent_idx, tok_idx)
69 |     if tmp_ani is None:
70 |         tmp_ani = get_animacy(word, ner)
71 |     return tmp_ani
72 | 
73 | 
74 | def get_animacy(word, ner):
75 |     """Get animacy (animate or inanimate) for a word
76 | 
77 |     Args:
78 |         word: The word to be checked. It should be a word without any space.
79 |         ner: Named entity recognition label of the word.
80 | 
81 |     Returns:
82 |         A string 'ANIMATE', 'INANIMATE', or 'UNKNOWN'.
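        Example (illustrative, assuming the default word lists):
            get_animacy('she', 'O')           # 'ANIMATE', via the pronoun list
            get_animacy('paris', 'LOCATION')  # 'INANIMATE', via the NER label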
83 | 84 | """ 85 | 86 | ret_animacy = UNKNOWN_ANIMACY_STR 87 | # determine using NER, pronoun list, and word list 88 | if ner == "PERSON" or ner.startswith('PER'): 89 | ret_animacy = ANIMATE_STR 90 | elif (ner == 'LOCATION' 91 | or ner.startswith('LOC') 92 | or ner == 'MONEY' 93 | or ner == 'NUMBER' 94 | or ner == 'PERCENT' 95 | or ner == 'DATE' 96 | or ner == 'TIME' 97 | or ner.startswith('FAC') 98 | or ner.startswith('GPE') 99 | or ner.startswith('WEA') 100 | or ner.startswith('ORG')): 101 | ret_animacy = INANIMATE_STR 102 | elif word in animate_pronouns: # check with pronoun list 103 | ret_animacy = ANIMATE_STR 104 | elif word in inanimate_pronouns: # check with pronoun list 105 | ret_animacy = INANIMATE_STR 106 | else: 107 | # check with animate/inanimate words 108 | global animate_words 109 | global animate_words_fpath 110 | global inanimate_words 111 | global inanimate_words_fpath 112 | if animate_words is None: 113 | animate_words = load_animacy_file(animate_words_fpath) 114 | if inanimate_words is None: 115 | inanimate_words = load_animacy_file(inanimate_words_fpath) 116 | if word in animate_words: 117 | ret_animacy = ANIMATE_STR 118 | elif word in inanimate_words: 119 | ret_animacy = INANIMATE_STR 120 | return ret_animacy 121 | -------------------------------------------------------------------------------- /dnee/events/indices.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | import six 5 | 6 | 7 | # animacy 8 | UNKNOWN_ANIMACY = "unknown" 9 | ANI2IDX= {UNKNOWN_ANIMACY: 0, "animate": 1, "inanimate": 2} 10 | IDX2ANI = {0: UNKNOWN_ANIMACY, 1: "animate", 2: "inanimate"} 11 | 12 | # sentiment 13 | SENT2IDX = {"verynegative":0, 14 | "negative": 1, 15 | "neutral": 2, 16 | "positive": 3, 17 | "verypositive": 4} 18 | IDX2SENT = {0: "verynegative", 19 | 1: "negative", 20 | 2: "neutral", 21 | 3: "positive", 22 | 4: "verypositive"} 23 | 24 | # relations 25 | REL_COREF = "Coref" 26 | REL_CONTEXT = "Context" 27 | REL2IDX = {} 28 | # REL2IDX = {"Comparison.Contrast": 0, 29 | # "Contingency.Cause.Reason": 1, 30 | # "Contingency.Cause.Result": 2, 31 | # "Contingency.Condition": 3, 32 | # "Expansion.Restatement": 4, 33 | # "Expansion.Conjunction": 5, 34 | # "Expansion.Instantiation": 6, 35 | # "Temporal.Synchrony": 7, 36 | # "Temporal.Asynchronous": 8, 37 | # REL_CONTEXT: 9, 38 | # REL_COREF: 10 39 | # } 40 | 41 | DISCOURSE_REL2IDX = {} 42 | DISCOURSE_IDX2REL = {} 43 | # DISCOURSE_REL2IDX = {"Comparison.Contrast": 0, 44 | # "Contingency.Cause.Reason": 1, 45 | # "Contingency.Cause.Result": 2, 46 | # "Contingency.Condition": 3, 47 | # "Expansion.Restatement": 4, 48 | # "Expansion.Conjunction": 5, 49 | # "Expansion.Instantiation": 6, 50 | # "Temporal.Synchrony": 7, 51 | # "Temporal.Asynchronous": 8 52 | # } 53 | 54 | IDX2REL = {} 55 | # IDX2REL = {0: "Comparison.Contrast", 56 | # 1: "Contingency.Cause.Reason", 57 | # 2: "Contingency.Cause.Result", 58 | # 3: "Contingency.Condition", 59 | # 4: "Expansion.Restatement", 60 | # 5: "Expansion.Conjunction", 61 | # 6: "Expansion.Instantiation", 62 | # 7: "Temporal.Synchrony", 63 | # 8: "Temporal.Asynchronous", 64 | # 9: REL_CONTEXT, 65 | # 10: REL_COREF 66 | # } 67 | 68 | # constants 69 | PRED_OOV = '__PRED_OOV__' 70 | NO_ARG = '__NO_ARG__' 71 | UNKNOWN_ARG_WORD = 'UNK' 72 | EOS_ARG_WORD = '' 73 | 74 | 75 | def set_relation_classes(fpath): 76 | global REL2IDX, IDX2REL, DISCOURSE_REL2IDX 77 | rel_config = json.load(open(fpath, 'r')) 78 | REL2IDX = 
rel_config['rel2idx'] 79 | IDX2REL = {v: k for k, v in six.iteritems(REL2IDX)} 80 | 81 | db, de = int(rel_config['disc_begin']), int(rel_config['disc_end']) 82 | for i in range(db, de): 83 | DISCOURSE_REL2IDX[IDX2REL[i]] = i 84 | DISCOURSE_IDX2REL[i] = IDX2REL[i] 85 | 86 | 87 | def load_freqs(fpath): 88 | ret = {} 89 | with open(fpath) as fr: 90 | for line in fr: 91 | line = line.rstrip('\n') 92 | sp = line.split('\t') 93 | ret[sp[0]] = int(sp[1]) 94 | return ret 95 | 96 | 97 | def dump_freqs(output_file, freq): 98 | with open(output_file, 'w') as fw: 99 | if isinstance(freq, dict): 100 | for k, v in six.iteritems(freq): 101 | fw.write('{}\t{}\n'.format(k, v)) 102 | else: 103 | for p in freq: 104 | fw.write('{}\t{}\n'.format(p[0], p[1])) 105 | 106 | 107 | def load_predicates(fpath): 108 | return _load_freqs(fpath, oov_key=PRED_OOV) 109 | 110 | 111 | def load_argw(fpath, eos_key=EOS_ARG_WORD): 112 | return _load_freqs(fpath, oov_key=UNKNOWN_ARG_WORD, 113 | eos_key=eos_key, begin_idx=1) 114 | 115 | 116 | def _load_freqs(fpath, oov_key=None, eos_key=None, begin_idx=0): 117 | key2idx, idx2key = OrderedDict(), OrderedDict() 118 | key_freq = {} 119 | idx = begin_idx 120 | if oov_key: 121 | key2idx[oov_key] = idx 122 | idx2key[idx] = oov_key 123 | idx += 1 124 | if eos_key: 125 | key2idx[eos_key] = idx 126 | idx2key[idx] = eos_key 127 | idx += 1 128 | with open(fpath, 'r') as fr: 129 | for line in fr: 130 | line = line.rstrip('\n') 131 | sp = line.split('\t') 132 | if sp[0] not in key2idx: 133 | key2idx[sp[0]] = idx 134 | idx2key[idx] = sp[0] 135 | idx += 1 136 | key_freq[sp[0]] = int(sp[1]) 137 | return key2idx, idx2key, key_freq 138 | 139 | 140 | def load_dictionary(fpath): 141 | with open(fpath, 'r') as fr: 142 | all_lines = fr.readlines() 143 | argw2idx = {word.strip(): i for i, word in enumerate(all_lines)} 144 | return argw2idx 145 | -------------------------------------------------------------------------------- /dnee/discourse_annotator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | 9 | import progressbar 10 | 11 | from dnee import utils 12 | 13 | 14 | def extract_explicit_connectives(doc, dmarkers): 15 | sents = doc["sentences"] 16 | dm_idxs = [] 17 | for i, sent in enumerate(sents): 18 | toks = [tok["word"].lower() for tok in sent["tokens"]] 19 | cidx2tidx = {0: 0} 20 | acc_clen = 0 21 | for j, tok in enumerate(sent["tokens"]): 22 | if j == 0: 23 | acc_clen += (len(tok["word"]) + 1) 24 | continue 25 | cidx2tidx[acc_clen] = j 26 | acc_clen += (len(tok["word"]) + 1) 27 | 28 | text = ' '.join(toks) 29 | for dtype, dm in dmarkers: 30 | dlen = len(dm.split(" ")) 31 | tidxs = [] 32 | for m in re.finditer(dm, text): 33 | if m.start() not in cidx2tidx: 34 | continue 35 | begin_cidx = cidx2tidx[m.start()] 36 | end_cidx = cidx2tidx[m.start()] + dlen 37 | tmp_toks = [sents[i]["tokens"][k]["word"].lower() for k in range(begin_cidx, end_cidx)] 38 | tmp_text = ' '.join(tmp_toks) 39 | if tmp_text == dm: 40 | tidxs.append((i, begin_cidx, end_cidx, dtype, dm)) 41 | 42 | dm_idxs += tidxs 43 | return dm_idxs 44 | 45 | 46 | def find_min_dist(idxs, conn): 47 | # get the one close to the connective 48 | min_dist = 10000 49 | target_idx = -1 50 | for idx in idxs: 51 | dist = abs(idx - conn[1]) 52 | if dist < min_dist: 53 | min_dist = dist 54 | target_idx = idx 55 | return target_idx, min_dist 56 | 57 | 58 | def find_clause_args(conn, doc, delimiter=';'): 
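    """Split the connective's sentence into two clause arguments at the
    delimiter occurrence nearest to the connective.

    Returns a ((sent_idx, tok_begin, tok_end), (sent_idx, tok_begin, tok_end))
    pair for the left and right clauses, or None if the sentence contains no
    delimiter.
    """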
59 |     sent_idx = conn[0]
60 |     sent_toks = [tok['word'] for tok in doc['sentences'][sent_idx]['tokens']]
61 |     idxs = [i for i, tok in enumerate(sent_toks) if tok == delimiter]
62 |     cargs = None
63 |     if len(idxs) > 0:
64 |         idx, dist = find_min_dist(idxs, conn)
65 |         if idx != -1:
66 |             left = (sent_idx, 0, idx)
67 |             right = (sent_idx, idx+1, len(sent_toks))
68 |             cargs = (left, right)
69 |     return cargs
70 | 
71 | 
72 | def find_sentence_args(conn, doc):
73 |     sent_idx = conn[0]
74 |     sent_toks = [tok['word'] for tok in doc['sentences'][sent_idx]['tokens']]
75 |     tok_begin_idx = conn[1]
76 |     if len(doc['sentences']) <= 1:
77 |         return None
78 | 
79 |     arg1, arg2 = None, None
80 |     if sent_idx == 0:  # pick right
81 |         arg1 = (sent_idx, 0, len(sent_toks))
82 |         arg2 = (sent_idx+1, 0, len(doc['sentences'][sent_idx+1]['tokens']))
83 |     elif sent_idx == len(doc['sentences']) - 1:  # pick left
84 |         arg1 = (sent_idx-1, 0, len(doc['sentences'][sent_idx-1]['tokens']))
85 |         arg2 = (sent_idx, 0, len(sent_toks))
86 |     elif tok_begin_idx <= len(sent_toks) - tok_begin_idx - 1:  # connective in first half: pick left
87 |         arg1 = (sent_idx-1, 0, len(doc['sentences'][sent_idx-1]['tokens']))
88 |         arg2 = (sent_idx, 0, len(sent_toks))
89 |     else:  # pick right
90 |         arg1 = (sent_idx, 0, len(sent_toks))
91 |         arg2 = (sent_idx+1, 0, len(doc['sentences'][sent_idx+1]['tokens']))
92 |     return None if arg1 is None else (arg1, arg2)
93 | 
94 | 
95 | def pack_rel(conn, arg1, arg2):
96 |     def _pack_index(k, v):
97 |         ret = {}
98 |         ret[k] = {'sent_idx': v[0], 'tok_begin_idx': v[1], 'tok_end_idx': v[2]}
99 |         return ret
100 | 
101 |     rel = {}
102 |     rel.update(_pack_index('arg1', arg1))
103 |     rel.update(_pack_index('arg2', arg2))
104 |     rel.update(_pack_index('connective', conn))
105 |     rel['connective']['type'] = conn[3]
106 |     rel['connective']['text'] = conn[4]
107 |     return rel
108 | 
109 | 
110 | def annotate(doc, dmarkers):
111 |     if doc is None or len(doc['sentences']) == 0:
112 |         return None
113 | 
114 |     connectives = extract_explicit_connectives(doc, dmarkers)
115 |     doc_rels = []
116 |     for conn in connectives:
117 |         cargs = find_clause_args(conn, doc, delimiter=';')
118 |         if cargs is not None:
119 |             arg1, arg2 = cargs
120 |             rel = pack_rel(conn, arg1, arg2)
121 |             doc_rels.append(rel)
122 |             continue
123 | 
124 |         cargs = find_clause_args(conn, doc, delimiter=',')
125 |         if cargs is not None:
126 |             arg1, arg2 = cargs
127 |             rel = pack_rel(conn, arg1, arg2)
128 |             doc_rels.append(rel)
129 |             continue
130 | 
131 |         cargs = find_sentence_args(conn, doc)
132 |         if cargs is None:
133 |             continue
134 | 
135 |         arg1, arg2 = cargs
136 |         rel = pack_rel(conn, arg1, arg2)
137 |         doc_rels.append(rel)
138 |     return doc_rels
139 | 
--------------------------------------------------------------------------------
/bin/evaluations/eval_mcnc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import argparse
5 | import json
6 | import time
7 | import re
8 | import cPickle as pkl
9 | 
10 | import numpy as np
11 | import torch
12 | import progressbar
13 | 
14 | from dnee import utils
15 | from dnee.evals import intrinsic
16 | from dnee.events import indices
17 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder, AttnEventTransE, AttnEventTransR
18 | 
19 | 
20 | def get_arguments(argv):
21 |     parser = argparse.ArgumentParser(description='MCNC evaluation')
22 |     parser.add_argument('model_file', metavar='MODEL_FILE',
23 |                         help='model file.')
24 |     parser.add_argument('encoder_file', metavar='ENCODER_FILE',
25 |                         help='encoder file.')
26 | 
parser.add_argument('question_file', metavar='QUESTION_FILE', 27 | help='questions.') 28 | parser.add_argument('training_config', metavar='TRAINING_CONFIG', 29 | help='config for training') 30 | parser.add_argument('relation_config', metavar='RELATION_CONFIG', 31 | help='config for relations') 32 | 33 | parser.add_argument('-b', '--batch_size', type=int, default=100, 34 | help='batch size for evaluation') 35 | parser.add_argument('-g', '--gpu_id', type=int, default=None, 36 | help='gpu id') 37 | parser.add_argument('-c', '--context_rel', action='store_true', default=False, 38 | help='use REL_CONTEXT instead of REL_COREF') 39 | parser.add_argument('-u', '--use_head', action='store_true', default=False, 40 | help='use head word only for arguments') 41 | 42 | parser.add_argument('-v', '--verbose', action='store_true', default=False, 43 | help='show info messages') 44 | parser.add_argument('-d', '--debug', action='store_true', default=False, 45 | help='show debug messages') 46 | args = parser.parse_args(argv) 47 | return args 48 | 49 | 50 | def build_embeddings(model, questions, config, pred2idx, argw2idx, rtype): 51 | e2idx = {} 52 | idx = 0 53 | for q in questions: 54 | for ctx in q.get_contexts(): 55 | key = ctx.__repr__() 56 | if key not in e2idx: 57 | e2idx[key] = idx 58 | idx += 1 59 | 60 | for ch in q.choices: 61 | key = ch.__repr__() 62 | if key not in e2idx: 63 | e2idx[key] = idx 64 | idx += 1 65 | e_len = 1 + config['arg0_max_len'] + config['arg1_max_len'] 66 | inputs = torch.zeros((len(e2idx), e_len), 67 | dtype=torch.int64).to(args.device) 68 | 69 | for q in questions: 70 | for e in q.get_contexts(): 71 | idx = e2idx[e.__repr__()] 72 | inputs[idx] = utils.get_raw_event_repr(e, config, pred2idx, argw2idx, device=args.device, use_head=args.use_head) 73 | for e in q.choices: 74 | idx = e2idx[e.__repr__()] 75 | inputs[idx] = utils.get_raw_event_repr(e, config, pred2idx, argw2idx, device=args.device, use_head=args.use_head) 76 | embeddings = model._transfer(model.embed_event(inputs), rtype) 77 | return e2idx, embeddings 78 | 79 | 80 | def main(): 81 | config = json.load(open(args.training_config, 'r')) 82 | indices.set_relation_classes(args.relation_config) 83 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices']) 84 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices']) 85 | n_preds = len(pred2idx) 86 | argw_vocabs = argw2idx.keys() 87 | argw_encoder = create_argw_encoder(config, args.device) 88 | if args.encoder_file: 89 | argw_encoder.load(args.encoder_file) 90 | 91 | logging.info("model class: " + config['model_type']) 92 | ModelClass = eval(config['model_type']) 93 | model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 94 | model.load_state_dict(torch.load(args.model_file, 95 | map_location=lambda storage, location: storage)) 96 | 97 | questions = pkl.load(open(args.question_file, 'r')) 98 | logging.info("#questions={}".format(len(questions))) 99 | 100 | n_correct, n_incorrect = 0, 0 101 | rtype = indices.REL2IDX[indices.REL_CONTEXT] if args.context_rel else indices.REL2IDX[indices.REL_COREF] 102 | rtype = torch.LongTensor([rtype]).to(args.device) 103 | 104 | widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()] 105 | bar = progressbar.ProgressBar(widgets=widgets, maxval=len(questions)).start() 106 | 107 | logging.info("batch_size = {}".format(args.batch_size)) 108 | batch_size = args.batch_size 109 | n_batches = len(questions) // batch_size + 1 110 | logging.info("n_batches = {}".format(n_batches)) 
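    # Each batch: embed every distinct context/choice event once with
    # build_embeddings(), then pick the choice whose embedding best fits the
    # chain via intrinsic.predict_mcnc() under the chosen relation type.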
111 | i_q = 0 112 | for i_batch in range(n_batches): 113 | batch_questions = questions[i_batch*batch_size: (i_batch+1)*batch_size] 114 | e2idx, embeddings = build_embeddings(model, batch_questions, config, pred2idx, argw2idx, rtype) 115 | for q in batch_questions: 116 | pred = intrinsic.predict_mcnc(model, q, e2idx, embeddings, rtype, args.device) 117 | if pred == q.ans_idx: 118 | n_correct += 1 119 | else: 120 | n_incorrect += 1 121 | i_q += 1 122 | bar.update(i_q) 123 | bar.finish() 124 | print("n_correct={}, n_incorrect={}".format(n_correct, n_incorrect)) 125 | print("accuracy={}".format(float(n_correct)/(n_correct+n_incorrect))) 126 | 127 | 128 | if __name__ == "__main__": 129 | args = utils.bin_config(get_arguments) 130 | if torch.cuda.is_available(): 131 | args.device = torch.device('cuda') if args.gpu_id is None \ 132 | else torch.device('cuda:{}'.format(args.gpu_id)) 133 | else: 134 | args.device = torch.device('cpu') 135 | main() 136 | -------------------------------------------------------------------------------- /bin/evaluations/eval_disc_elmo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | import pickle as pkl 9 | 10 | import six 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import progressbar 15 | from sklearn.metrics import accuracy_score, classification_report 16 | from allennlp.modules.elmo import Elmo, batch_to_ids 17 | from nltk.stem import WordNetLemmatizer 18 | from nltk import word_tokenize 19 | 20 | from dnee import utils 21 | from dnee.evals import intrinsic 22 | from dnee.events import indices 23 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder 24 | 25 | 26 | def get_arguments(argv): 27 | parser = argparse.ArgumentParser(description='intrinsic disc evaluation') 28 | parser.add_argument('question_file', metavar='QUESTION_FILE', 29 | help='questions.') 30 | parser.add_argument('relation_config', metavar='RELATION_CONFIG', 31 | help='relation config') 32 | 33 | parser.add_argument('-w', '--elmo_weight_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5", 34 | help='ELMo weight file') 35 | parser.add_argument('-p', '--elmo_option_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_options.json", 36 | help='ELMo option file') 37 | parser.add_argument('-g', '--gpu_id', type=int, default=None, 38 | help='gpu id') 39 | parser.add_argument('-b', '--batch_size', type=int, default=100, 40 | help='batch size for evaluation') 41 | parser.add_argument('-v', '--verbose', action='store_true', default=False, 42 | help='show info messages') 43 | parser.add_argument('-d', '--debug', action='store_true', default=False, 44 | help='show debug messages') 45 | args = parser.parse_args(argv) 46 | return args 47 | 48 | 49 | def _eval(questions, we, eval_func): 50 | n_batches = len(questions) // args.batch_size 51 | if len(questions) % args.batch_size > 0: 52 | n_batches += 1 53 | 54 | widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()] 55 | bar = progressbar.ProgressBar(widgets=widgets, maxval=n_batches).start() 56 | ys, y_preds = [], [] 57 | for i_batch in range(n_batches): 58 | subquestions = questions[i_batch*args.batch_size: (i_batch+1)*args.batch_size] 59 | logging.debug("#subquestions = {}".format(len(subquestions))) 60 | y, y_pred = eval_func(subquestions, we) 61 | ys += y 62 | y_preds += y_pred 63 | 
bar.update(i_batch+1) 64 | bar.finish() 65 | return ys, y_preds 66 | 67 | 68 | def _eval_by_rel_we(questions, we): 69 | ys, y_preds = [], [] 70 | for q in questions: 71 | x1 = intrinsic.embed_event_word_embeddings(q.rel.e1, we) 72 | y = q.ans_idx 73 | ys.append(y) 74 | 75 | scores = torch.zeros(len(q.choices), dtype=torch.float32) 76 | for i, c in enumerate(q.choices): 77 | ch = intrinsic.embed_event_word_embeddings(c, we) 78 | scores[i] = utils.cosine_similarity(x1, ch) 79 | _, y_pred = torch.max(scores, 0) 80 | y_preds.append(y_pred.item()) 81 | return ys, y_preds 82 | 83 | 84 | def event_toks(e): 85 | preds = e.pred.split('_') 86 | arg0s = [w.lower() for w in word_tokenize(e.arg0)] 87 | arg1s = [w.lower() for w in word_tokenize(e.arg1)] 88 | arg2s = [w.lower() for w in word_tokenize(e.arg2)] 89 | return preds + arg0s + arg1s + arg2s 90 | 91 | 92 | lemmatizer = None 93 | def embed_event_elmo(etoks, tok2idx, embs, dim=512): 94 | global lemmatizer 95 | if lemmatizer is None: 96 | lemmatizer = WordNetLemmatizer() 97 | 98 | lemma2idx = {lemmatizer.lemmatize(k, 'v'): v for k, v in six.iteritems(tok2idx)} 99 | cnt = 0 100 | res = torch.zeros(dim, dtype=torch.float32).to(args.device) 101 | for t in etoks: 102 | if t in tok2idx: 103 | res += embs[tok2idx[t]] 104 | cnt += 1 105 | elif t in lemma2idx: 106 | res += embs[lemma2idx[t]] 107 | cnt += 1 108 | if cnt > 0: 109 | res /= cnt 110 | return res 111 | 112 | 113 | def get_event_sentence(e): 114 | # return e.sentence 115 | preds = ' '.join(e.pred.split('_')) 116 | sent = e.arg0 + ' ' + preds + ' ' + e.arg1 117 | return sent 118 | 119 | 120 | mrr = 0.0 121 | def _eval_by_rel_elmo(questions, we): 122 | n_choices = len(questions[0].choices) 123 | e1_sents, ch_sents = [], [] 124 | for q in questions: 125 | e1_tokens = [w.lower() for w in word_tokenize(get_event_sentence(q.rel.e1))] 126 | e1_sents.append(e1_tokens) 127 | 128 | for i_c, ch in enumerate(q.choices): 129 | ch_tokens = [w.lower() for w in word_tokenize(get_event_sentence(ch))] 130 | ch_sents.append(ch_tokens) 131 | 132 | 133 | e1_ids = batch_to_ids(e1_sents).to(args.device) 134 | ch_ids = batch_to_ids(ch_sents).to(args.device) 135 | e1_embs = we(e1_ids)['elmo_representations'][0] 136 | ch_embs = we(ch_ids)['elmo_representations'][0] 137 | ys, y_preds = [], [] 138 | for i_q, q in enumerate(questions): 139 | e1_sent = e1_sents[i_q] 140 | e1_tok2idx = {w: i for i, w in enumerate(e1_sent)} 141 | e1_toks = event_toks(q.rel.e1) 142 | e1_emb = embed_event_elmo(e1_toks, e1_tok2idx, e1_embs[i_q]) 143 | 144 | scores = torch.zeros(n_choices, dtype=torch.float32).to(args.device) 145 | _ch_sents = ch_sents[i_q * n_choices: (i_q+1) * n_choices] 146 | _ch_embs = ch_embs[i_q * n_choices: (i_q+1) * n_choices] 147 | for i_c, ch in enumerate(q.choices): 148 | ch_sent = _ch_sents[i_c] 149 | tmp_ch_embs = _ch_embs[i_c] 150 | ch_tok2idx = {w: i for i, w in enumerate(ch_sent)} 151 | ch_toks = event_toks(ch) 152 | ch_emb = embed_event_elmo(ch_toks, ch_tok2idx, tmp_ch_embs) 153 | scores[i_c] = F.cosine_similarity(e1_emb, ch_emb, dim=0) 154 | 155 | _, y_pred = torch.max(scores, 0) 156 | y_preds.append(y_pred.item()) 157 | ys.append(q.ans_idx) 158 | return ys, y_preds 159 | 160 | 161 | def _eval_by_rel(questions, we): 162 | return _eval_by_rel_elmo(questions, we) 163 | 164 | 165 | def eval_by_rel(questions, we): 166 | return _eval(questions, we, _eval_by_rel) 167 | 168 | 169 | def main(): 170 | indices.set_relation_classes(args.relation_config) 171 | questions = pkl.load(open(args.question_file, 'rb')) 172 | 
logging.info("#questions={}".format(len(questions))) 173 | 174 | we = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0).to(args.device) 175 | with torch.no_grad(): 176 | logging.info('evaluating by rel') 177 | y, y_pred = eval_by_rel(questions, we) 178 | print("accuracy = {}".format(accuracy_score(y, y_pred))) 179 | logging.info("accuracy = {}".format(accuracy_score(y, y_pred))) 180 | 181 | 182 | if __name__ == "__main__": 183 | args = utils.bin_config(get_arguments) 184 | if args.gpu_id is not None: 185 | args.device = torch.device('cuda:{}'.format(args.gpu_id)) 186 | else: 187 | args.device = torch.device('cpu') 188 | main() 189 | -------------------------------------------------------------------------------- /dnee/events/events.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | import six 5 | 6 | from . import indices 7 | 8 | 9 | class Event(object): 10 | def __init__(self, pred, arg0, arg0_head, arg1, arg1_head, arg2, arg2_head, 11 | sentiment, ani0, ani1, ani2): 12 | #assert indices.pred2idx is not None 13 | rep = {'\n': ' ', '::': ' '} 14 | rep = dict((re.escape(k), v) for k, v in six.iteritems(rep)) 15 | pat = re.compile("|".join(rep.keys())) 16 | 17 | #self.pred = pred if pred in indices.pred2idx else indices.PRED_OOV 18 | self.pred = pred 19 | 20 | # self.pred_idx = indices.pred2idx[self.pred] 21 | self.arg0 = pat.sub(lambda m: rep[re.escape(m.group(0))], arg0) 22 | self.arg0_head = arg0_head 23 | self.arg1 = pat.sub(lambda m: rep[re.escape(m.group(0))], arg1) 24 | self.arg1_head = arg1_head 25 | self.arg2 = pat.sub(lambda m: rep[re.escape(m.group(0))], arg2) 26 | self.arg2_head = arg2_head 27 | 28 | if sys.version_info < (3, 0): 29 | self.pred = self.pred.encode('ascii', 'ignore') 30 | self.arg0 = self.arg0.encode('ascii', 'ignore') 31 | self.arg1 = self.arg1.encode('ascii', 'ignore') 32 | self.arg2 = self.arg2.encode('ascii', 'ignore') 33 | self.arg0_head = self.arg0_head.encode('ascii', 'ignore') 34 | self.arg1_head = self.arg1_head.encode('ascii', 'ignore') 35 | self.arg2_head = self.arg2_head.encode('ascii', 'ignore') 36 | self.sentiment = sentiment 37 | # self.sentiment_idx = indices.SENT2IDX[sentiment] 38 | self.ani0 = ani0 39 | # self.ani0_idx = indices.ANI2IDX[ani0] 40 | self.ani1 = ani1 41 | # self.ani1_idx = indices.ANI2IDX[ani1] 42 | self.ani2 = ani2 43 | # self.ani2_idx = indices.ANI2IDX[ani2] 44 | 45 | def __repr__(self): 46 | return "({}::{}::{}::{}::{}::{}::{}::{}::{}::{}::{})".format( 47 | self.pred, 48 | self.arg0, self.arg0_head, 49 | self.arg1, self.arg1_head, 50 | self.arg2, self.arg2_head, 51 | self.sentiment, self.ani0, 52 | self.ani1, self.ani2) 53 | 54 | def get_pred_index(self, pred2idx): 55 | return pred2idx[self.pred] if self.pred in pred2idx else pred2idx[indices.PRED_OOV] 56 | 57 | def get_arg_indices(self, argn, argw2idx, use_head=False, arg_len=-1): 58 | if use_head: 59 | if argn == 0: 60 | target = self.arg0_head 61 | elif argn == 1: 62 | target = self.arg1_head 63 | elif argn == 2: 64 | target = self.arg2_head 65 | else: 66 | return None 67 | else: 68 | if argn == 0: 69 | target = self.arg0 70 | elif argn == 1: 71 | target = self.arg1 72 | elif argn == 2: 73 | target = self.arg2 74 | else: 75 | return None 76 | sp = target.split(' ') 77 | if arg_len != -1 and len(sp) > arg_len: 78 | sp = sp[:arg_len] 79 | # append EOS 80 | # sp.append(indices.EOS_ARG_WORD) 81 | ret = [argw2idx[tok] if tok in argw2idx 82 | else argw2idx[indices.UNKNOWN_ARG_WORD] for tok in sp] 83 
| # padding 0 84 | if len(ret) < arg_len: 85 | for i in range(arg_len-len(ret)): 86 | ret.append(0) 87 | return ret 88 | 89 | @classmethod 90 | def from_string(cls, line): 91 | line = line.rstrip("\n")[1:-1] 92 | sp = line.split('::') 93 | obj = cls(sp[0], sp[1], sp[2], sp[3], sp[4], sp[5], sp[6], 94 | sp[7], sp[8], sp[9], sp[10]) 95 | return obj 96 | 97 | @classmethod 98 | def from_json(cls, e): 99 | pred = e['predicate'] 100 | # only use the first sub-argument for now 101 | arg0_head = e['arg0'][0] if 'arg0' in e else indices.NO_ARG 102 | arg0 = e['arg0_text'][0] if 'arg0_text' in e else indices.NO_ARG 103 | 104 | arg1_head = e['arg1'][0] if 'arg1' in e else indices.NO_ARG 105 | arg1 = e['arg1_text'][0] if 'arg1_text' in e else indices.NO_ARG 106 | 107 | arg2_head = e['arg2'][0] if 'arg2' in e else indices.NO_ARG 108 | arg2 = e['arg2_text'][0] if 'arg2_text' in e else indices.NO_ARG 109 | sentiment = e['sentiment'] if 'sentiment' in e else None 110 | ani0 = e['ani0'][0] if 'ani0' in e else indices.UNKNOWN_ANIMACY 111 | ani1 = e['ani1'][0] if 'ani1' in e else indices.UNKNOWN_ANIMACY 112 | ani2 = e['ani2'][0] if 'ani2' in e else indices.UNKNOWN_ANIMACY 113 | obj = cls(pred, arg0, arg0_head, arg1, arg1_head, arg2, arg2_head, 114 | sentiment, ani0, ani1, ani2) 115 | return obj 116 | 117 | def __eq__(self, other): 118 | if isinstance(other, self.__class__): 119 | return self.__repr__() == other.__repr__() 120 | return False 121 | 122 | def __ne__(self, other): 123 | return not self.__eq__(other) 124 | 125 | def valid_pred(self, pred2idx): 126 | return (self.pred in pred2idx) 127 | 128 | def valid_arg_len(self, config): 129 | # minus one for appending EOS 130 | if len(self.arg0.split(' ')) > config['arg0_max_len']: 131 | return False 132 | if len(self.arg1.split(' ')) > config['arg1_max_len']: 133 | return False 134 | if len(self.arg2.split(' ')) > config['arg2_max_len']: 135 | return False 136 | return True 137 | 138 | 139 | class EventRelation: 140 | def __init__(self, e1, e2, rtype, label=1): 141 | assert label == 1 or label == 0 142 | self.label = label 143 | self.e1 = e1 144 | self.e2 = e2 145 | self.rtype = rtype 146 | self.rtype_idx = indices.REL2IDX[rtype] 147 | 148 | def __repr__(self): 149 | return "{} ||| {} ||| {}".format(self.rtype_idx, 150 | self.e1, 151 | self.e2) 152 | 153 | @classmethod 154 | def from_string(cls, line): 155 | line = line.rstrip("\n") 156 | sp = line.split(' ||| ') 157 | assert len(sp) == 3 158 | rtype_idx = int(sp[0]) 159 | rtype = indices.IDX2REL[rtype_idx] 160 | e1 = Event.from_string(sp[1]) 161 | e2 = Event.from_string(sp[2]) 162 | obj = cls(e1, e2, rtype) 163 | return obj 164 | 165 | def is_valid(self, pred2idx, config): 166 | return self.valid_pred(pred2idx) and self.valid_arg_len(config) 167 | 168 | def valid_pred(self, pred2idx): 169 | return (self.e1.valid_pred(pred2idx) and self.e2.valid_pred(pred2idx)) 170 | 171 | def valid_arg_len(self, config): 172 | return self.e1.valid_arg_len(config) and self.e2.valid_arg_len(config) 173 | 174 | def to_indices(self, pred2idx, argw2idx, use_head=False, arg_len=-1): 175 | return [self.label, 176 | self.rtype_idx, 177 | self.e1.get_pred_index(pred2idx), 178 | self.e1.get_arg_indices(0, argw2idx, use_head=use_head, arg_len=arg_len), 179 | self.e1.get_arg_indices(1, argw2idx, use_head=use_head, arg_len=arg_len), 180 | self.e1.get_arg_indices(2, argw2idx, use_head=use_head, arg_len=arg_len), 181 | self.e2.get_pred_index(pred2idx), 182 | self.e2.get_arg_indices(0, argw2idx, use_head=use_head, arg_len=arg_len), 183 | 
self.e2.get_arg_indices(1, argw2idx, use_head=use_head, arg_len=arg_len),
184 | self.e2.get_arg_indices(2, argw2idx, use_head=use_head, arg_len=arg_len)]
185 |
--------------------------------------------------------------------------------
/bin/evaluations/test_combined_features.py:
--------------------------------------------------------------------------------
1 | """
2 | Note that this reuses some code from Skip-Thoughts
3 | https://github.com/ryankiros/skip-thoughts
4 | """
5 | import os
6 | import argparse
7 | import logging
8 | import time
9 | import json
10 | import codecs
11 | from collections import OrderedDict
12 | import random
13 | from copy import deepcopy
14 | 
15 | import progressbar
16 | import numpy as np
17 | from nltk import word_tokenize
18 | from sklearn.metrics import f1_score
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from torch.autograd import Variable
22 | import torch.optim as optim
23 | from allennlp.modules.elmo import Elmo, batch_to_ids
24 | 
25 | from dnee import utils
26 | from dnee.events import indices, Event, EventRelation
27 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder
28 | from dnee.models import AttnEventTransE, AttnEventTransR
29 | from dnee.events import extract_events
30 | from dnee.evals import discourse_sense as ds
31 | 
32 | 
33 | def get_arguments(argv):
34 | parser = argparse.ArgumentParser(description='evaluate embeddings on Discourse Sense')
35 | parser.add_argument('ds_dev_file', metavar='DS_DEV_FILE',
36 | help='the file for Discourse Sense developing')
37 | parser.add_argument('ds_test_file', metavar='DS_TEST_FILE',
38 | help='the file for Discourse Sense testing')
39 | parser.add_argument('ds_blind_file', metavar='DS_BLIND_FILE',
40 | help='the file for Discourse Sense blind testing')
41 | 
42 | parser.add_argument('model_file', metavar='MODEL_FILE',
43 | help='model file.')
44 | parser.add_argument('encoder_file', metavar='ENCODER_FILE',
45 | help='argument word encoder.')
46 | parser.add_argument('training_config', metavar='TRAINING_CONFIG',
47 | help='config for training')
48 | parser.add_argument('relation_config', metavar='RELATION_CONFIG',
49 | help='config for relation classes')
50 | 
51 | parser.add_argument('ds_model_file', metavar='DS_MODEL_FILE',
52 | help='DS model file.')
53 | 
54 | parser.add_argument('output_folder', metavar='OUTPUT_FOLDER',
55 | help='the folder for outputs.')
56 | 
57 | parser.add_argument('-w', '--elmo_weight_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5",
58 | help='ELMo weight file')
59 | parser.add_argument('-p', '--elmo_option_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_options.json",
60 | help='ELMo option file')
61 | parser.add_argument('-g', '--gpu_id', type=int, default=None,
62 | help='gpu id')
63 | parser.add_argument('-v', '--verbose', action='store_true', default=False,
64 | help='show info messages')
65 | parser.add_argument('-d', '--debug', action='store_true', default=False,
66 | help='show debug messages')
67 | 
68 | args = parser.parse_args(argv)
69 | return args
70 | 
71 | 
72 | def rel_output(rel, predicted_sense):
73 | new_rel = {}
74 | new_rel['DocID'] = rel['DocID']
75 | new_rel['ID'] = rel['ID']
76 | 
77 | new_rel['Arg1'] = {}
78 | new_rel['Arg1']['TokenList'] = []
79 | for tok in rel['Arg1']['TokenList']:
80 | new_rel['Arg1']['TokenList'].append(tok[2])
81 | 
82 | new_rel['Arg2'] = {}
83 | new_rel['Arg2']['TokenList'] = []
84 | for tok in rel['Arg2']['TokenList']:
85 |
new_rel['Arg2']['TokenList'].append(tok[2]) 86 | 87 | new_rel['Connective'] = {} 88 | new_rel['Connective']['TokenList'] = [] 89 | for tok in rel['Connective']['TokenList']: 90 | new_rel['Connective']['TokenList'].append(tok[2]) 91 | 92 | new_rel['Sense'] = [predicted_sense] 93 | new_rel['Type'] = rel['Type'] 94 | return new_rel 95 | 96 | 97 | def eval_relations(fpath, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, dnee_seq_len): 98 | rels = [json.loads(line) for line in open(fpath)] 99 | rels = [rel for rel in rels if rel['Type'] != 'Explicit' and rel['Sense'][0] in indices.DISCOURSE_REL2IDX] 100 | cm, valid_senses = ds.create_cm(rels, indices.DISCOURSE_REL2IDX) 101 | 102 | x, y = ds.get_features(rels, elmo, seq_len, dnee_model, dnee_seq_len, config, pred2idx, argw2idx, indices.DISCOURSE_REL2IDX, 103 | device=args.device, use_dnee=True) 104 | x0, x1, x0_dnee, x1_dnee, x_dnee = x 105 | x0, x1 = x0.to(args.device), x1.to(args.device) 106 | 107 | model = ds.AttentionNN(len(indices.DISCOURSE_REL2IDX), event_dim=config['event_dim'], use_event=True).to(args.device) 108 | model.load_state_dict(torch.load(args.ds_model_file, map_location=lambda storage, location: storage)) 109 | model.eval() 110 | 111 | n_rels = len(indices.DISCOURSE_REL2IDX) 112 | y_pred = model.predict(x0, x1, x0_dnee, x1_dnee, x_dnee) 113 | 114 | fw = open(fw_path, 'w') 115 | for i, rel in enumerate(rels): 116 | final_pred = y_pred[i].item() 117 | new_rel = rel_output(rel, indices.DISCOURSE_IDX2REL[final_pred]) 118 | fw.write(json.dumps(new_rel) + '\n') 119 | fw.close() 120 | 121 | 122 | def main(): 123 | # DNEE 124 | t1 = time.time() 125 | indices.set_relation_classes(args.relation_config) 126 | config = json.load(open(args.training_config, 'r')) 127 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices']) 128 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices']) 129 | n_preds = len(pred2idx) 130 | argw_vocabs = argw2idx.keys() 131 | argw_encoder = create_argw_encoder(config, args.device) 132 | argw_encoder.load(args.encoder_file) 133 | 134 | logging.info("model class: " + config['model_type']) 135 | ModelClass = eval(config['model_type']) 136 | dnee_model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 137 | dnee_model.load_state_dict(torch.load(args.model_file, 138 | map_location=lambda storage, location: storage)) 139 | logging.info('Loading DNEE: {} s'.format(time.time()-t1)) 140 | 141 | elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0) 142 | 143 | # DS_TRAIN_FLD = "ds_train_elmo" 144 | # DNEE_TRAIN_FLD = "ds_train_transe" if config['model_type'] == 'EventTransE' else 'ds_train_transr_tmp' 145 | # train_data = ds.DsDataset(DS_TRAIN_FLD, DNEE_TRAIN_FLD) 146 | # logging.info("DNEE_TRAIN_FLD={}".format(DNEE_TRAIN_FLD)) 147 | # seq_len = train_data.seq_len 148 | # event_seq_len = train_data.dnee_seq_len 149 | 150 | # These are the max seq length from training data (above code) 151 | # We hardcode them to avoid loading training data 152 | seq_len, event_seq_len = 392, 14 153 | logging.info("seq_len={}, event_seq_len={}".format(seq_len, event_seq_len)) 154 | 155 | t1 = time.time() 156 | logging.info('dev...') 157 | fw_path = os.path.join(args.output_folder, 'dev_res.json') 158 | eval_relations(args.ds_dev_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len) 159 | logging.info('Eval DEV: {} s'.format(time.time()-t1)) 160 | 161 | t1 = time.time() 162 | logging.info('test...') 163 | fw_path = 
os.path.join(args.output_folder, 'test_res.json')
164 | eval_relations(args.ds_test_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len)
165 | logging.info('Eval TEST: {} s'.format(time.time()-t1))
166 | 
167 | t1 = time.time()
168 | logging.info('blind...')
169 | fw_path = os.path.join(args.output_folder, 'blind_res.json')
170 | eval_relations(args.ds_blind_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len)
171 | logging.info('Eval BLIND: {} s'.format(time.time()-t1))
172 | 
173 | 
174 | if __name__ == '__main__':
175 | args = utils.bin_config(get_arguments)
176 | if torch.cuda.is_available():
177 | args.device = torch.device('cuda') if args.gpu_id is None \
178 | else torch.device('cuda:{}'.format(args.gpu_id))
179 | else:
180 | args.device = torch.device('cpu')
181 | main()
182 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Relation Script Learning for Discourse Relations
2 | This repository contains code, data, and pre-trained models for the following paper.
3 | 
4 | ```
5 | I-Ta Lee, and Dan Goldwasser, "Multi-Relational Script Learning for Discourse Relations", ACL 2019
6 | ```
7 | 
8 | bibtex
9 | ```
10 | @inproceedings{lee2019multi,
11 | title={Multi-Relational Script Learning for Discourse Relations},
12 | author={Lee, I-Ta and Goldwasser, Dan},
13 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
14 | year={2019}
15 | }
16 | ```
17 | 
18 | If you use any resources within this repository, please cite the paper.
19 | 
20 | # Dependencies
21 | 
22 | The project was originally written in Python 2.7. However, ELMo embeddings require Python 3, so the affected code was migrated to run with Python 3. Please install the dependencies for both.
23 | 
24 | ```
25 | pip install -r requirements.txt
26 | pip3 install -r requirements3.txt
27 | ```
28 | 
29 | # Data
30 | 
31 | Our core experiments (MCNC, MCNS, MCNE) use the NYT section of English Gigaword, following the data splits given by Granroth-Wilding (https://mark.granroth-wilding.co.uk/papers/what_happens_next/) [1].
32 | 
33 | Our models need full entity mention spans rather than only entity head words, and these spans are not provided in the outputs of Granroth-Wilding's code. Considering the stochastic factors in the pre-processing, we therefore release newly pre-processed train/dev/test data for MCNC, MCNS, and MCNE (see the experiments in the paper). Unlike Granroth-Wilding's code, we use Stanford CoreNLP as the text pipeline tool and follow the heuristics described in the paper for retrieving all the events.
34 | 
35 | ## Download
36 | 
37 | The datasets are in token-indexed triplet format, and the index files are provided.
38 | 
39 | - dev/test data and misc: [data.tar.gz](https://drive.google.com/file/d/1Jv-O69Zd0A-YeHGrYqKHlsu6qPLfD3yh/view?usp=sharing)
40 | - training data: [training_data.tar.gz](https://purdue0-my.sharepoint.com/:u:/g/personal/lee2226_purdue_edu1/EShU8bK4-sNHkSGOn9MtN5wBoHXJD4o_V-n9lFwmnP3FJw?e=ME5y6L)
41 | - pretrained models: [pretrained.tar.gz](https://drive.google.com/file/d/1ogphXeArL4_qZFuN3qQIpGPLCnrbk5Gs/view?usp=sharing)
42 | - skip-thought word embeddings: we use skip-thought's word embeddings [2] for our encoder. You can download pre-trained embeddings from https://github.com/ryankiros/skip-thoughts. Put them in the following locations:
43 | ```
44 | data/skipthought_models/dictionary.txt
45 | data/skipthought_models/utable.npy
46 | data/skipthought_models/btable.npy
47 | data/skipthought_models/uni_skip.npz
48 | data/skipthought_models/uni_skip.npz.pkl
49 | data/skipthought_models/bi_skip.npz
50 | data/skipthought_models/bi_skip.npz.pkl
51 | ```
52 | - GloVe Word Embeddings: you can get glove.6B.300d.txt from https://nlp.stanford.edu/projects/glove/ (a loading sketch is shown after this list).
53 | - ELMo: you can download the medium-sized ELMo model from https://allennlp.org/elmo (elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5 and elmo_2x2048_256_2048cnn_1xhighway_options.json). Put them in the ./data folder.
54 |
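As a quick sanity check that the text-format embedding files (GloVe, word2vec) are in place, you can load them with the repository's own helper `dnee.utils.load_word_embeddings`. A minimal sketch; the file path and probe word are illustrative:
```python
from dnee import utils

# Returns {word: np.ndarray}; pass use_torch=True for torch tensors,
# and skip_first_line=True for word2vec-style files with a header line.
we = utils.load_word_embeddings('glove.6B.300d.txt')
print(len(we), we['the'].shape)  # vocabulary size and embedding dimension
```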
55 | # Pre-trained models
56 | 
57 | You can download the models from the links in the Download section.
58 | 
59 | - EventTransE:
60 | - model: pretrained/out_transe_v0.2.10_long9_tmp
61 | - config: train_config_transe_v0.2.10_long9.json
62 | - EventTransR:
63 | - model: pretrained/out_transr_v0.2.10_long9_tmp
64 | - config: train_config_transr_v0.2.10_long9.json
65 | 
66 | # Run experiments with pre-trained models
67 | 
68 | To begin with, download the test data and pretrained models from the Download section, and decompress them in the repo folder. You need two configuration files: train_config_{xxx}.json and relation_{xxx}.json. Check that all the file paths in the configuration files are correct.
69 | 
70 | Note that for all the executable .py scripts, you can use the **-h** argument to see the help message and **-g <gpu_id>** to run on a specific GPU.
71 | 
72 | ## MCNC
73 | 
74 | For EventTransE
75 | ```
76 | python bin/evaluations/eval_mcnc.py -v pretrained/out_transe_v0.2.10_long9_tmp/model_2_3_2591.pt pretrained/out_transe_v0.2.10_long9_tmp/argw_enc_2_3_2591.pt data/mcnc_test_v0.3.4/mcnc_coref_next.pkl train_config_transe_v0.2.10_long9.json relation_9disc.json
77 | ```
78 | 
79 | For EventTransR, you simply need to replace the model file, argument encoder file, and config file.
80 | 
81 | 
82 | ## MCNS and MCNE
83 | 
84 | Both experiments run in one command.
85 | 
86 | For EventTransE:
87 | ```
88 | python bin/evaluations/eval_mcns.py -v pretrained/out_transe_v0.2.10_long9_tmp/model_2_3_2591.pt pretrained/out_transe_v0.2.10_long9_tmp/argw_enc_2_3_2591.pt glove.6B.300d.txt data/mcnc_test_v0.3.4/mcns_coref_next.pkl train_config_transe_v0.2.10_long9.json relation_9disc.json Viterbi
89 | ```
90 | 
91 | For EventTransR, you simply need to replace the model file, argument encoder file, and config file.
92 | 
93 | 
94 | ## Intrinsic Discourse Relations
95 | 
96 | The evaluations here take a long time to run. GPUs are recommended.
97 | 
98 | ### Predict an event or a relation class
99 | 
100 | There are two setups here:
101 | 
102 | - Predict the next event given an event and a relation
103 | - Predict the relation given two events
104 | 
105 | The following command will output results for both setups (a conceptual sketch of the underlying scoring is shown at the end of this subsection).
106 | 
107 | For EventTransE
108 | ```
109 | python bin/evaluations/eval_disc.py pretrained/out_transe_v0.2.10_long9_tmp/model_2_3_2591.pt pretrained/out_transe_v0.2.10_long9_tmp/argw_enc_2_3_2591.pt data/disc_test_v0.2.0.pkl train_config_transe_v0.2.10_long9.json relation_9disc.json -v
110 | ```
111 | 
112 | For EventTransR, you simply need to replace the model file, argument encoder file, and config file.
113 | 
114 | You can also run this evaluation with pre-trained ELMo embeddings, as described in the paper. To do so, download the ELMo models from the Download section. Python3 is required by ELMo. The following command also outputs results for the two setups.
115 | ```
116 | python3 bin/evaluations/eval_disc_elmo.py -v data/disc_test_v0.2.0.pkl relation_9disc.json
117 | ```
118 | 
119 |
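To make the two setups concrete, here is a conceptual sketch of how a TransE-style event model ranks candidates. This is an illustration of the scoring idea only, not the repository's exact API; all tensors are random stand-ins for learned embeddings, and the dimension is a toy value:
```python
import torch

def transe_score(e1, rel, e2, norm=1):
    # TransE-style plausibility: a triplet (e1, rel, e2) is plausible
    # when e1 + rel lands close to e2 in embedding space (lower = better).
    return torch.norm(e1 + rel - e2, p=norm)

dim = 500  # illustrative dimension
e1, rel = torch.randn(dim), torch.randn(dim)
candidates = torch.randn(5, dim)  # candidate next-event embeddings
scores = torch.stack([transe_score(e1, rel, c) for c in candidates])
print(scores.argmin().item())  # index of the best-scoring candidate
```
Predicting the next event fixes `e1` and the relation and ranks candidate `e2` embeddings; predicting the relation fixes the two events and ranks candidate relation embeddings the same way.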
120 | ### Triplet classification
121 | 
122 | Another setup is to do a binary classification for a given triplet. Download the ELMo models from the Download section. Python3 is required by ELMo.
123 | ```
124 | python3 bin/evaluations/eval_disc_binary.py -v pretrained/out_transe_v0.2.10_long9_tmp/model_2_3_2591.pt pretrained/out_transe_v0.2.10_long9_tmp/argw_enc_2_3_2591.pt data/disc_dev_v0.2.0.pkl data/disc_test_v0.2.0.pkl train_config_transe_v0.2.10_long9.json relation_9disc.json
125 | ```
126 | This command runs the **ELMo+EventTransE** setting. To run with ELMo only, add **-m** to the command. The results are shown in the log file or stdout. The reported results are averaged over 5 runs.
127 | 
128 | 
129 | ## Implicit Discourse Sense Classification
130 | 
131 | Download the ELMo models from the Download section. Python3 is required by ELMo. Download the pre-trained classifier (in pretrained.tar.gz in the Download section), which takes EventTransE+ELMo as input representations. The result reported in the paper is averaged over 5 runs.
132 | ```
133 | python3 bin/evaluations/test_combined_features.py -v data/pdtb_ds/ds_dev_events.json data/pdtb_ds/ds_test_events.json data/pdtb_ds/ds_blind_test_events.json pretrained/out_ds_transe_nonexplicit_t5/model_0_0_33.pt pretrained/out_ds_transe_nonexplicit_t5/argw_enc_0_0_33.pt train_config_ds_transe.json relation_pdtb.json pretrained/out_ds_comb_elmo_transe_e200_4/best_model.pt output_folder
134 | ```
135 | This will output predictions in JSON format, which is supported by the CoNLL 2016 Shared Task's official scorer.
136 | - You can get scorer.py from here: https://github.com/attapol/conll16st
137 | - You can get the gold data from https://www.cs.brandeis.edu/~clp/conll16st/index.html
138 | 
139 | 
140 | # Train from scratch
141 | 
142 | Download the training data from the Download section. Make sure the training file path in the config file is correct. Then run
143 | ```
144 | python bin/train.py -v -r train_config_transe_v0.2.10_long9.json output_model
145 | ```
146 | Again, for EventTransR, simply replace the config file.
147 | 
148 | 
149 | # References
150 | 
151 | [1] Granroth-Wilding, Mark, and Stephen Clark. "What happens next? Event prediction using a compositional neural network model." Thirtieth AAAI Conference on Artificial Intelligence. 2016.
152 | [2] Kiros, Ryan, et al. "Skip-thought vectors." Advances in Neural Information Processing Systems. 2015.
153 | -------------------------------------------------------------------------------- /bin/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | import random 9 | import pickle as pkl 10 | 11 | import progressbar 12 | import torch 13 | import torch.optim as optim 14 | from torch.utils.data import DataLoader 15 | from torch.autograd import Variable 16 | import torch.nn.functional as F 17 | 18 | from dnee import utils 19 | from dnee.events import indices 20 | from dnee.datasets.event_relation import EventRelationDataset 21 | from dnee.models import EventTransR, EventTransE, NegativeSampler, ArgWordEncoder, create_argw_encoder 22 | from dnee.models import AttnEventTransE, AttnEventTransR 23 | 24 | 25 | def get_arguments(argv): 26 | parser = argparse.ArgumentParser(description='DNEE training') 27 | parser.add_argument('training_config', metavar='TRAINING_CONFIG', 28 | help='config for training') 29 | parser.add_argument('output_folder', metavar='OUTPUT_FOLDER', 30 | help='output folder for models etc.') 31 | 32 | parser.add_argument('-l', '--prev_best_loss', type=float, default=-1.0, 33 | help='previous best loss') 34 | parser.add_argument('-p', '--prev_folder', default=None, 35 | help='previous training outputs') 36 | parser.add_argument('-e', '--epoch_batch', default=None, 37 | help='continued epoch batch numbers, e.g., 1_31111') 38 | parser.add_argument('-g', '--gpu_id', type=int, default=None, 39 | help='gpu id') 40 | parser.add_argument('-n', '--normalize_embeddings', action='store_true', default=False, 41 | help='normalize embeddings after each update') 42 | parser.add_argument('-r', '--sample_rel', action='store_true', default=False, 43 | help='sample relation types for negative sampling') 44 | parser.add_argument('-v', '--verbose', action='store_true', default=False, 45 | help='show info messages') 46 | parser.add_argument('-d', '--debug', action='store_true', default=False, 47 | help='show debug messages') 48 | args = parser.parse_args(argv) 49 | return args 50 | 51 | 52 | def save(fld, model, optimizer, i_epoch, i_file, i_batch): 53 | fpath = os.path.join(fld, "model_{}_{}_{}.pt".format(i_epoch, i_file, i_batch)) 54 | torch.save(model.state_dict(), fpath) 55 | 56 | fpath = os.path.join(fld, "optim_{}_{}_{}.pt".format(i_epoch, i_file, i_batch)) 57 | torch.save(optimizer.state_dict(), fpath) 58 | 59 | fpath = os.path.join(fld, "argw_enc_{}_{}_{}.pt".format(i_epoch, i_file, i_batch)) 60 | model.argw_encoder.save(fpath) 61 | 62 | 63 | def save_losses(fld, losses): 64 | fpath = os.path.join(fld, 'losses.pkl') 65 | pkl.dump(losses, open(fpath, 'wb')) 66 | 67 | 68 | def main(): 69 | logging.info('using {} for computation.'.format(args.device)) 70 | config = json.load(open(args.training_config, 'r')) 71 | 72 | n_preds = EventTransE.get_n_preds(config['predicate_indices']) 73 | argw_encoder = create_argw_encoder(config, args.device) 74 | if args.prev_folder: 75 | fpath = os.path.join(args.prev_folder, 'argw_enc_{}.pt'.format(args.epoch_batch)) 76 | logging.info('loading {}...'.format(fpath)) 77 | argw_encoder.load(fpath) 78 | 79 | logging.info("model class: " + config['model_type']) 80 | ModelClass = eval(config['model_type']) 81 | model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 82 | if args.prev_folder: 83 | fpath = os.path.join(args.prev_folder, 'model_{}.pt'.format(args.epoch_batch)) 84 | 
logging.info('loading {}...'.format(fpath)) 85 | model.load_state_dict(torch.load(fpath, map_location=lambda storage, location: storage)) 86 | 87 | sampler = NegativeSampler([config['arg0_max_len'], config['arg1_max_len'], config['arg2_max_len']], 88 | config['n_rel_types'], 89 | sample_rel=args.sample_rel) 90 | 91 | optimizer = utils.get_optimizer(model, config) 92 | if args.prev_folder: 93 | fpath = os.path.join(args.prev_folder, 'optim_{}.pt'.format(args.epoch_batch)) 94 | logging.info('loading {}...'.format(fpath)) 95 | optimizer.load_state_dict(torch.load(fpath, map_location=lambda storage, location: storage)) 96 | 97 | arg_lens = [config['arg0_max_len'], config['arg1_max_len']] 98 | if args.prev_best_loss >= 0: 99 | min_loss = args.prev_best_loss 100 | best_epoch = -1 101 | else: 102 | min_loss, best_epoch = None, None 103 | losses = [] 104 | for i_epoch in range(config['n_epochs']): 105 | files = [os.path.join(config['training_data'], f) 106 | for f in os.listdir(config['training_data']) 107 | if f.endswith('txt')] 108 | random.shuffle(files) 109 | epoch_start = time.time() 110 | for i_file, f in enumerate(files): 111 | # ToDo: pre-create these objects and move these outside the loop 112 | train_data = EventRelationDataset(f, arg_lens) 113 | train_loader = DataLoader(train_data, 114 | batch_size=config["batch_size"], 115 | shuffle=True, 116 | num_workers=config["n_dataloader_workers"]) 117 | 118 | for i_batch, (x, y) in enumerate(train_loader): 119 | if x.shape[0] != config['batch_size']: 120 | logging.debug('batch {} batch_size mismatch, skip.'.format(i_batch)) 121 | continue 122 | 123 | # sampling 124 | x_neg = Variable(sampler.sampling(x).to(args.device), requires_grad=False) 125 | x_pos = Variable(x.to(args.device), requires_grad=False) 126 | 127 | model.train() 128 | p_score = model(x_pos) 129 | n_score = model(x_neg) 130 | 131 | # average over folds of negative samples 132 | neg_ratio = n_score.shape[0] // p_score.shape[0] 133 | n_samples = p_score.shape[0] 134 | for i in range(1, neg_ratio): 135 | n_score[:n_samples] = n_score[:n_samples] + n_score[i*n_samples: (i+1)*n_samples] 136 | n_score = n_score[:n_samples] / neg_ratio 137 | 138 | # step 139 | loss = model.loss_func(p_score, n_score) 140 | optimizer.zero_grad() 141 | loss.backward() 142 | optimizer.step() 143 | 144 | if args.normalize_embeddings: 145 | model.rel_embeddings.weight.data = F.normalize(model.rel_embeddings.weight.data, 146 | p=model.norm, dim=1) 147 | model.pred_embeddings.weight.data = F.normalize(model.pred_embeddings.weight.data, 148 | p=model.norm, dim=1) 149 | 150 | tmp_loss = loss.item() 151 | losses.append(tmp_loss) 152 | 153 | model.eval() 154 | # dev 155 | 156 | if min_loss: 157 | if tmp_loss < min_loss: 158 | logging.info("best: {}, {}, {}: loss={}, time={}".format(i_epoch, i_file, i_batch, tmp_loss, time.time()-epoch_start)) 159 | min_loss = tmp_loss 160 | best_epoch = i_epoch 161 | save(args.output_folder, model, optimizer, i_epoch, i_file, i_batch) 162 | save_losses(args.output_folder, losses) 163 | else: 164 | logging.info("best: {}, {}, {}: loss={}, time={}".format(i_epoch, i_file, i_batch, tmp_loss, time.time()-epoch_start)) 165 | min_loss = tmp_loss 166 | best_epoch = i_epoch 167 | save(args.output_folder, model, optimizer, i_epoch, i_file, i_batch) 168 | save_losses(args.output_folder, losses) 169 | 170 | if i_batch % config['n_batches_per_record'] == 0: 171 | logging.info("{}, {}, {}: loss={}, time={}".format(i_epoch, i_file, i_batch, tmp_loss, time.time()-epoch_start)) 172 | 
save_losses(args.output_folder, losses)
173 | 
174 | logging.info("epoch {}: loss={}, time={}".format(i_epoch, tmp_loss, time.time()-epoch_start))
175 | 
176 | i_epoch = config["n_epochs"]
177 | i_file = 0
178 | i_batch = 0
179 | save(args.output_folder, model, optimizer, i_epoch, i_file, i_batch)
180 | save_losses(args.output_folder, losses)
181 | 
182 | 
183 | if __name__ == "__main__":
184 | args = utils.bin_config(get_arguments)
185 | if torch.cuda.is_available():
186 | args.device = torch.device('cuda') if args.gpu_id is None \
187 | else torch.device('cuda:{}'.format(args.gpu_id))
188 | else:
189 | args.device = torch.device('cpu')
190 | main()
191 |
--------------------------------------------------------------------------------
/bin/evaluations/train_combined_features.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import argparse
5 | import json
6 | import time
7 | import re
8 | import random
9 | import pickle as pkl
10 | 
11 | import h5py
12 | import progressbar
13 | from sklearn.metrics import f1_score
14 | import torch
15 | import torch.optim as optim
16 | from torch.utils.data import DataLoader, Dataset
17 | import torch.nn.functional as F
18 | from allennlp.modules.elmo import Elmo, batch_to_ids
19 | # torch.multiprocessing.set_sharing_strategy('file_system')
20 | 
21 | from dnee import utils
22 | from dnee.events import indices
23 | from dnee.models import EventTransR, EventTransE, NegativeSampler, ArgWordEncoder, create_argw_encoder
24 | from dnee.models import AttnEventTransR, AttnEventTransE
25 | from dnee.evals import discourse_sense as ds
26 | from dnee.evals.confusion_matrix import Alphabet, ConfusionMatrix
27 | from dnee.models import skipthoughts as st
28 | 
29 | 
30 | def get_arguments(argv):
31 | parser = argparse.ArgumentParser(description='training for Discourse Sense')
32 | parser.add_argument('ds_train_fld', metavar='DS_TRAIN_FLD',
33 | help='preprocessed training data')
34 | parser.add_argument('ds_dev_rel_file', metavar='DS_DEV_REL_FILE',
35 | help='gold relation file with events for dev')
36 | 
37 | parser.add_argument('model_file', metavar='MODEL_FILE',
38 | help='model file.')
39 | parser.add_argument('encoder_file', metavar='ENCODER_FILE',
40 | help='argument word encoder.')
41 | parser.add_argument('training_config', metavar='TRAINING_CONFIG',
42 | help='config for training')
43 | parser.add_argument('relation_config', metavar='RELATION_CONFIG',
44 | help='config for relation classes')
45 | parser.add_argument('output_folder', metavar='OUTPUT_FOLDER',
46 | help='output folder for models etc.')
47 | 
48 | parser.add_argument('-w', '--elmo_weight_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5",
49 | help='ELMo weight file')
50 | parser.add_argument('-p', '--elmo_option_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_options.json",
51 | help='ELMo option file')
52 | parser.add_argument('-s', '--no_dnee_scores', action='store_true', default=False,
53 | help='Not using DNEE scores as features')
54 | parser.add_argument('-n', '--dnee_train_fld', default=None,
55 | help='DNEE training data')
56 | parser.add_argument('-g', '--gpu_id', type=int, default=None,
57 | help='gpu id')
58 | parser.add_argument('-l', '--learning_rate', type=float, default=1e-3,
59 | help='initial learning rate')
60 | parser.add_argument('-b', '--batch_size', type=int, default=500,
61 | help='batch size')
62 | parser.add_argument('-e', '--n_epoches', type=int, default=5,
63 |
help='number of epochs')
64 | parser.add_argument('-r', '--dropout', type=float, default=0.5,
65 | help='dropout rate')
66 | parser.add_argument('-v', '--verbose', action='store_true', default=False,
67 | help='show info messages')
68 | parser.add_argument('-d', '--debug', action='store_true', default=False,
69 | help='show debug messages')
70 | args = parser.parse_args(argv)
71 | return args
72 | 
73 | 
74 | def save(fld, model, optimizer, i_epoch, i_file, i_batch):
75 | fpath = os.path.join(fld, "model_{}_{}_{}.pt".format(i_epoch, i_file, i_batch))
76 | torch.save(model.state_dict(), fpath)
77 | 
78 | fpath = os.path.join(fld, "optim_{}_{}_{}.pt".format(i_epoch, i_file, i_batch))
79 | torch.save(optimizer.state_dict(), fpath)
80 | 
81 | fpath = os.path.join(fld, "argw_enc_{}_{}_{}.pt".format(i_epoch, i_file, i_batch))
82 | model.argw_encoder.save(fpath)
83 | 
84 | 
85 | def save_losses(fld, losses):
86 | fpath = os.path.join(fld, 'losses.pkl')
87 | pkl.dump(losses, open(fpath, 'wb'))
88 | 
89 | 
90 | def save_scores(fld, scores, fname):
91 | fpath = os.path.join(fld, fname)
92 | pkl.dump(scores, open(fpath, 'wb'))
93 | 
94 | 
95 | def main():
96 | logging.info('using {} for computation.'.format(args.device))
97 | config = json.load(open(args.training_config, 'r'))
98 | 
99 | indices.set_relation_classes(args.relation_config)
100 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices'])
101 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
102 | 
103 | n_preds = len(pred2idx)
104 | argw_encoder = create_argw_encoder(config, args.device)
105 | argw_encoder.load(args.encoder_file)
106 | 
107 | logging.info("model class: " + config['model_type'])
108 | ModelClass = eval(config['model_type'])
109 | dnee_model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device)
110 | dnee_model.load_state_dict(torch.load(args.model_file,
111 | map_location=lambda storage, location: storage))
112 | dnee_model.eval()
113 | 
114 | elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0)
115 | train_data = ds.DsDataset(args.ds_train_fld, args.dnee_train_fld)
116 | 
117 | dev_rels = [json.loads(line) for line in open(args.ds_dev_rel_file)]
118 | dev_rels = [rel for rel in dev_rels if rel['Type'] != 'Explicit' and rel['Sense'][0] in indices.DISCOURSE_REL2IDX]
119 | dev_cm, dev_valid_senses = ds.create_cm(dev_rels, indices.DISCOURSE_REL2IDX)
120 | 
121 | dnee_seq_len = train_data.dnee_seq_len if args.dnee_train_fld else None
122 | x_dev, y_dev = ds.get_features(dev_rels, elmo, train_data.seq_len, dnee_model, dnee_seq_len, config, pred2idx, argw2idx, indices.DISCOURSE_REL2IDX,
123 | device=args.device, use_dnee=(args.dnee_train_fld is not None))
124 | x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev, x_dnee_dev = x_dev
125 | x0_dev, x1_dev = x0_dev.to(args.device), x1_dev.to(args.device)
126 | 
127 | model = ds.AttentionNN(len(indices.DISCOURSE_REL2IDX), event_dim=config['event_dim'], dropout=args.dropout, use_event=(args.dnee_train_fld is not None), use_dnee_scores=not args.no_dnee_scores).to(args.device)
128 | optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate)
129 | logging.info("initial learning rate = {}".format(args.learning_rate))
130 | logging.info("dropout rate = {}".format(args.dropout))
131 | 
132 | # arg_lens = [config['arg0_max_len'], config['arg1_max_len']]
133 | losses = []
134 | dev_f1s = []
135 | best_dev_f1, best_epoch, best_batch = -1, -1, -1
136 | logging.info("batch_size = 
{}".format(args.batch_size))
137 | for i_epoch in range(args.n_epoches):
138 | train_loader = DataLoader(train_data,
139 | batch_size=args.batch_size,
140 | shuffle=True,
141 | num_workers=1)
142 | 
143 | epoch_start = time.time()
144 | for i_batch, (x, y) in enumerate(train_loader):
145 | if y.shape[0] != args.batch_size:
146 | # skip the last batch
147 | continue
148 | if args.dnee_train_fld:
149 | x0, x1, x0_dnee, x1_dnee, x_dnee = x
150 | x0 = x0.to(args.device)
151 | x1 = x1.to(args.device)
152 | x0_dnee = x0_dnee.to(args.device)
153 | x1_dnee = x1_dnee.to(args.device)
154 | x_dnee = x_dnee.to(args.device)
155 | else:
156 | x0, x1 = x
157 | x0 = x0.to(args.device)
158 | x1 = x1.to(args.device)
159 | x0_dnee, x1_dnee, x_dnee = None, None, None
160 | y = y.squeeze().to(args.device)
161 | 
162 | model.train()
163 | optimizer.zero_grad()
164 | out = model(x0, x1, x0_dnee, x1_dnee, x_dnee)
165 | loss = model.loss_func(out, y)
166 | 
167 | # step
168 | loss.backward()
169 | optimizer.step()
170 | 
171 | losses.append(loss.item())
172 | 
173 | model.eval()
174 | y_pred = model.predict(x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev, x_dnee_dev)
175 | dev_prec, dev_recall, dev_f1 = ds.scoring_cm(y_dev, y_pred.cpu(), dev_cm, dev_valid_senses, indices.DISCOURSE_IDX2REL)
176 | dev_f1s.append(dev_f1)
177 | 
178 | ## if i_batch % config['n_batches_per_record'] == 0:
179 | logging.info("{}, {}: loss={}, time={}".format(i_epoch, i_batch, loss.item(), time.time()-epoch_start))
180 | logging.info("dev: prec={}, recall={}, f1={}".format(dev_prec, dev_recall, dev_f1))
181 | if dev_f1 > best_dev_f1:
182 | logging.info("best dev: prec={}, recall={}, f1={}".format(dev_prec, dev_recall, dev_f1))
183 | best_dev_f1 = dev_f1
184 | best_epoch = i_epoch
185 | best_batch = i_batch
186 | fpath = os.path.join(args.output_folder, 'best_model.pt')
187 | torch.save(model.state_dict(), fpath)
188 | 
189 | logging.info("{}-{}: best dev f1 = {}".format(best_epoch, best_batch, best_dev_f1))
190 | fpath = os.path.join(args.output_folder, "losses.pkl")
191 | pkl.dump(losses, open(fpath, 'wb'))
192 | fpath = os.path.join(args.output_folder, "dev_f1s.pkl")
193 | pkl.dump(dev_f1s, open(fpath, 'wb'))
194 | 
195 | 
196 | if __name__ == "__main__":
197 | args = utils.bin_config(get_arguments)
198 | if args.gpu_id is not None:
199 | args.device = torch.device('cuda:{}'.format(args.gpu_id))
200 | else:
201 | args.device = torch.device('cpu')
202 | main()
203 |
--------------------------------------------------------------------------------
/dnee/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import logging
5 | import json
6 | 
7 | import six
8 | import torch
9 | import torch.optim as optim
10 | import numpy as np
11 | from scipy import spatial
12 | import matplotlib
13 | matplotlib.use('Agg')
14 | import matplotlib.pyplot as plt
15 | 
16 | from .events import indices, Event, extract_events
17 | 
18 | 
19 | DEV_SPLIT = 0
20 | TEST_SPLIT = 1
21 | 
22 | 
23 | def micro_f1(y_true, y_pred, n_classes):
24 | """
25 | multi-class micro F1
26 | """
27 | tps, fps, fns = torch.zeros(n_classes, dtype=torch.int32), \
28 | torch.zeros(n_classes, dtype=torch.int32), \
29 | torch.zeros(n_classes, dtype=torch.int32)
30 | for i in range(n_classes):
31 | prediction = (y_pred == i).float()
32 | truth = (y_true == i).float()
33 | confusion_vector = prediction / truth
34 | # Element-wise division of the 2 tensors returns a new tensor which holds a
35 | # unique value for each case:
36 | # 1 where prediction and truth are 1 (True Positive)
37 | 
38 | # inf where prediction is 1 and truth is 0 (False Positive)
39 | # nan where prediction and truth are 0 (True Negative)
40 | # 0 where prediction is 0 and truth is 1 (False Negative)
41 | 
42 | tps[i] = torch.sum(confusion_vector == 1).item()
43 | fps[i] = torch.sum(confusion_vector == float('inf')).item()
44 | fns[i] = torch.sum(confusion_vector == 0).item()
45 | total_tps = tps.sum().float().item()
46 | total_fps = fps.sum().float().item()
47 | total_fns = fns.sum().float().item()
48 | prec = total_tps / (total_tps + total_fps) if total_tps + total_fps != 0.0 else 0.0
49 | rec = total_tps / (total_tps + total_fns) if total_tps + total_fns != 0.0 else 0.0
50 | if prec + rec == 0.0:
51 | f1 = 0.0
52 | else:
53 | f1 = (2.0 * prec * rec) / (prec + rec)
54 | return f1
55 | 
56 |
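# Illustration (not part of the original module): the confusion-vector
# trick used in micro_f1 above, on a toy example for class i == 1.
#
#   y_pred = torch.tensor([1., 1., 0., 0.])
#   y_true = torch.tensor([1., 0., 0., 1.])
#   y_pred / y_true  ->  tensor([1., inf, nan, 0.])
#   i.e., TP where == 1, FP where == inf, TN where nan, FN where == 0.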
57 | def parse_text(text, nlp, props):
58 | parse = None
59 | server_error = False
60 | try:
61 | ann = nlp.annotate(text, properties=props)
62 | except:
63 | # server error
64 | logging.debug("corenlp server error")
65 | server_error = True
66 | 
67 | if not server_error:
68 | try:
69 | parse = json.loads(ann)
70 | except:
71 | logging.debug("json parse failed; usually timeout")
72 | return parse
73 | 
74 | 
75 | def load_splits(dev_list, test_list):
76 | def _load(fpath, label):
77 | splits = {}
78 | with open(fpath, 'r') as fr:
79 | for line in fr:
80 | line = line.rstrip('\n')
81 | did = '.'.join(line.split('.')[:-1])
82 | splits[did] = label
83 | return splits
84 | 
85 | _dev = _load(dev_list, DEV_SPLIT)
86 | _test = _load(test_list, TEST_SPLIT)
87 | _test.update(_dev)
88 | return _test
89 | 
90 | 
91 | def plot_losses(losses, fpath):
92 | if len(losses) > 0:
93 | fig = plt.figure()
94 | plt.title("Loss vs. Batch")
95 | plt.xlabel("Batch")
96 | plt.ylabel("Loss")
97 | x = list(range(len(losses)))
98 | plt.plot(x, losses, color='red', label='train', linestyle="-")
99 | plt.savefig(fpath)
100 | plt.close(fig)
101 | 
102 | 
103 | def bin_config(get_arg_func, log_fname=None):
104 | # get arguments
105 | args = get_arg_func(sys.argv[1:])
106 | 
107 | # set logger
108 | logger = logging.getLogger()
109 | if args.debug:
110 | logger.setLevel(logging.DEBUG)
111 | elif args.verbose:
112 | logger.setLevel(logging.INFO)
113 | else:
114 | logger.setLevel(logging.ERROR)
115 | 
116 | formatter = logging.Formatter('[%(levelname)s][%(name)s] %(message)s')
117 | try:
118 | if not os.path.isdir(args.output_folder):
119 | os.mkdir(args.output_folder)
120 | fpath = os.path.join(args.output_folder, log_fname) if log_fname \
121 | else os.path.join(args.output_folder, 'log')
122 | except:
123 | fpath = log_fname if log_fname else 'log'
124 | fileHandler = logging.FileHandler(fpath)
125 | fileHandler.setFormatter(formatter)
126 | logger.addHandler(fileHandler)
127 | 
128 | consoleHandler = logging.StreamHandler()
129 | consoleHandler.setFormatter(formatter)
130 | logger.addHandler(consoleHandler)
131 | return args
132 | 
133 | 
134 | def load_word_embeddings(fpath, use_torch=False, skip_first_line=False):
135 | we = {}
136 | with open(fpath, 'r') as fr:
137 | for line in fr:
138 | if skip_first_line:
139 | skip_first_line = False
140 | continue
141 | line = line.rstrip()
142 | sp = line.split(" ")
143 | emb = np.squeeze(np.array([sp[1:]], dtype=np.float32))
144 | we[sp[0]] = torch.from_numpy(emb) if use_torch else emb
145 | return we
146 | 
147 | 
148 | def get_avg_embeddings(words, embeddings):
149 | dim = embeddings[next(iter(embeddings))].shape[0]
150 | total_emb = np.zeros(dim, 
dtype=np.float32)
151 | cnt = 0
152 | for w in words:
153 | if w in embeddings:
154 | total_emb += embeddings[w]
155 | cnt += 1
156 | 
157 | if cnt == 0:
158 | emb = np.random.uniform(low=-1.0/dim, high=1.0/dim, size=dim)
159 | total_emb += emb
160 | cnt += 1
161 | return total_emb / cnt
162 | 
163 | 
164 | def cosine_similarity(v1, v2):
165 | return 1.0 - spatial.distance.cosine(v1, v2)
166 | 
167 | 
168 | def oov_embeddings(dim):
169 | return np.random.uniform(low=-1.0/dim, high=1.0/dim, size=dim)
170 | 
171 | 
172 | def load_jsons(fld_path, func_getid):
173 | files = [f for f in os.listdir(fld_path) if f.endswith(".json")]
174 | docs = {}
175 | for f in files:
176 | fpath = os.path.join(fld_path, f)
177 | doc = json.load(open(fpath, 'r'))
178 | did = func_getid(doc, f)
179 | docs[did] = doc
180 | return docs
181 | 
182 | 
183 | def find_index(lis, target):
184 | try:
185 | idx = lis.index(target)
186 | except:
187 | idx = -1
188 | return idx
189 | 
190 | 
191 | def get_optimizer(model, config, **kwargs):
192 | logging.info("optimizer: {}".format(config['optimizer']))
193 | if config['optimizer'] == 'adam':
194 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
195 | elif config['optimizer'] == 'adagrad':
196 | if 'lr' in kwargs:
197 | optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=kwargs['lr'])
198 | else:
199 | optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()))
200 | elif config['optimizer'] == 'adadelta':
201 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()))
202 | elif config['optimizer'] == 'momentum':
203 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1, momentum=0.9)
204 | else:
205 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)
206 | return optimizer
207 | 
208 | 
209 | def get_raw_event_repr(e, config, pred2idx, argw2idx, device=None, use_head=False):
210 | e_len = 1 + config['arg0_max_len'] + config['arg1_max_len']
211 | raw = torch.zeros(e_len, dtype=torch.int64).to(device) if device else torch.zeros(e_len, dtype=torch.int64)
212 | pred_idx = e.get_pred_index(pred2idx)
213 | arg0_idxs = e.get_arg_indices(0, argw2idx, arg_len=config['arg0_max_len'], use_head=use_head)
214 | arg1_idxs = e.get_arg_indices(1, argw2idx, arg_len=config['arg1_max_len'], use_head=use_head)
215 | 
216 | raw[0] = pred_idx
217 | raw[1: 1+len(arg0_idxs)] = torch.LongTensor(arg0_idxs).to(device) if device else torch.LongTensor(arg0_idxs)
218 | raw[1+config['arg0_max_len']: 1+config['arg0_max_len']+len(arg1_idxs)] = torch.LongTensor(arg1_idxs).to(device) if device else torch.LongTensor(arg1_idxs)
219 | return raw
220 | 
221 | 
222 | def build_unknown_event():
223 | return Event(indices.PRED_OOV, indices.NO_ARG, indices.NO_ARG, indices.NO_ARG, indices.NO_ARG,
224 | indices.NO_ARG, indices.NO_ARG, None, indices.UNKNOWN_ANIMACY,
225 | indices.UNKNOWN_ANIMACY, indices.UNKNOWN_ANIMACY)
226 | 
227 | 
228 | def _extract_unique_events(doc, lemmatizer,
229 | corenlp_dep_key="enhancedPlusPlusDependencies"):
230 | """
231 | corenlp_dep_key: collapsed-ccprocessed-dependencies
232 | """
233 | events = _extract_events(doc, lemmatizer, corenlp_dep_key=corenlp_dep_key)
234 | # unique events
235 | unique_events = {}
236 | for mid, evs in six.iteritems(events):
237 | if evs is None:
238 | continue
239 | for ev in evs:
240 | key = '{}_{}_{}_{}'.format(ev['sentidx'],
241 | ev['predicate_head_idx'],
242 | ev['predicate_head_char_idx_begin'],
243 | ev['predicate_head_char_idx_end'])
244 | if key not in unique_events:
245 | unique_events[key] = ev
246 | return list(unique_events.values())
247 | 
248 | 
249 | def _extract_events(doc, lemmatizer, corenlp_dep_key="enhancedPlusPlusDependencies"):
250 | sentences = doc["sentences"]
251 | doc_corefs = doc["corefs"]
252 | entities = extract_events.get_all_entities(doc_corefs)
253 | 
254 | events = {}
255 | for coref_key, corefs in six.iteritems(doc_corefs):
256 | logging.debug("---------------------------")
257 | logging.debug('coref_key=%s' % coref_key)
258 | 
259 | # for each entity
260 | for entity in corefs:
261 | ent_id = entity['id']
262 | if ent_id in events:
263 | logging.warning("entity {} has been extracted.".format(ent_id))
264 | continue
265 | 
266 | tmp_events = extract_events.extract_one_multi_faceted_event(
267 | sentences, entities,
268 | entity, lemmatizer, add_sentence=True,
269 | no_sentiment=True,
270 | corenlp_dep_key=corenlp_dep_key)
271 | events[ent_id] = tmp_events
272 | return events
273 | 
274 | 
275 | def save_indices(fpath, idx2item):
276 | with open(fpath, 'w') as fw:
277 | for i in range(len(idx2item)):
278 | fw.write(idx2item[i]+'\n')
279 | 
280 | 
281 | def load_indices(fpath):
282 | item2idx, idx2item = {}, {}
283 | with open(fpath, 'r') as fr:
284 | i = 0
285 | for line in fr:
286 | line = line.rstrip('\n')
287 | item2idx[line] = i
288 | idx2item[i] = line
289 | i += 1
290 | return item2idx, idx2item
291 |
--------------------------------------------------------------------------------
/bin/evaluations/eval_mcns.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import argparse
5 | import json
6 | import time
7 | import re
8 | import pickle as pkl
9 | import random
10 | 
11 | import numpy as np
12 | import torch
13 | import progressbar
14 | 
15 | from dnee import utils
16 | from dnee.evals import intrinsic
17 | from dnee.events import indices
18 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder
19 | from dnee.models import AttnEventTransE, AttnEventTransR
20 | 
21 | def get_arguments(argv):
22 | parser = argparse.ArgumentParser(description='MCNS and MCNE evaluation')
23 | parser.add_argument('model_file', metavar='MODEL_FILE',
24 | help='model file.')
25 | parser.add_argument('encoder_file', metavar='ENCODER_FILE',
26 | help='encoder file.')
27 | parser.add_argument('word_embedding_file', metavar='WE_FILE',
28 | help='the file of pre-trained word embeddings')
29 | parser.add_argument('question_file', metavar='QUESTION_FILE',
30 | help='questions.')
31 | parser.add_argument('training_config', metavar='TRAINING_CONFIG',
32 | help='config for training')
33 | parser.add_argument('relation_config', metavar='RELATION_CONFIG',
34 | help='config for relations')
35 | 
36 | parser.add_argument('inference_model', metavar='INFERENCE_MODEL', choices=['Viterbi', 'Baseline', 'Skyline'],
37 | help='Model to conduct inference {"Viterbi", "Baseline", "Skyline"}.')
38 | 
39 | parser.add_argument('-s', '--no_subsample', action='store_true', default=False,
40 | help='do not subsample the questions')
41 | parser.add_argument('-b', '--batch_size', type=int, default=100,
42 | help='batch size for evaluation')
43 | parser.add_argument('-g', '--gpu_id', type=int, default=None,
44 | help='gpu id')
45 | parser.add_argument('-c', '--context_rel', action='store_true', default=False,
46 | help='use REL_CONTEXT instead of REL_COREF')
47 | parser.add_argument('-u', '--use_head',
action='store_true', default=False, 48 | help='use head word only for arguments') 49 | 50 | parser.add_argument('-v', '--verbose', action='store_true', default=False, 51 | help='show info messages') 52 | parser.add_argument('-d', '--debug', action='store_true', default=False, 53 | help='show debug messages') 54 | args = parser.parse_args(argv) 55 | return args 56 | 57 | 58 | def event_to_we(e, we, dim): 59 | toks = e.pred.split('_') 60 | cnt = 0 61 | emb = torch.zeros(dim, dtype=torch.float32) 62 | for tok in toks: 63 | if tok not in we: 64 | continue 65 | else: 66 | emb += we[tok] 67 | cnt += 1 68 | if cnt > 0: 69 | emb = emb / cnt 70 | else: 71 | emb = np.random.uniform(low=-1.0/dim, high=1.0/dim, size=dim) 72 | emb = torch.from_numpy(emb) 73 | return emb 74 | 75 | 76 | def build_embeddings(model, questions, config, pred2idx, argw2idx, rtype, we): 77 | e2idx = {} 78 | idx = 0 79 | for q in questions: 80 | key = q.echain[0].__repr__() 81 | if key not in e2idx: 82 | e2idx[key] = idx 83 | idx += 1 84 | 85 | for clist in q.choice_lists: 86 | for c in clist: 87 | key = c.__repr__() 88 | if key not in e2idx: 89 | e2idx[key] = idx 90 | idx += 1 91 | e_len = 1 + config['arg0_max_len'] + config['arg1_max_len'] 92 | inputs = torch.zeros((len(e2idx), e_len), 93 | dtype=torch.int64).to(args.device) 94 | wdim = we[we.keys()[0]].shape[0] 95 | w_embeddings = torch.zeros((len(e2idx), wdim), 96 | dtype=torch.float32).to(args.device) 97 | 98 | for q in questions: 99 | key = q.echain[0].__repr__() 100 | idx = e2idx[key] 101 | inputs[idx] = utils.get_raw_event_repr(q.echain[0], config, pred2idx, argw2idx, device=args.device, use_head=args.use_head) 102 | w_embeddings[idx] = event_to_we(q.echain[0], we, wdim) 103 | 104 | for clist in q.choice_lists: 105 | for c in clist: 106 | key = c.__repr__() 107 | idx = e2idx[key] 108 | inputs[idx] = utils.get_raw_event_repr(c, config, pred2idx, argw2idx, device=args.device, use_head=args.use_head) 109 | w_embeddings[idx] = event_to_we(c, we, wdim) 110 | ev_embeddings = model._transfer(model.embed_event(inputs), rtype) 111 | return e2idx, ev_embeddings, w_embeddings 112 | 113 | 114 | def main(): 115 | config = json.load(open(args.training_config, 'r')) 116 | indices.set_relation_classes(args.relation_config) 117 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices']) 118 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices']) 119 | n_preds = len(pred2idx) 120 | argw_vocabs = argw2idx.keys() 121 | argw_encoder = create_argw_encoder(config, args.device) 122 | if args.encoder_file: 123 | argw_encoder.load(args.encoder_file) 124 | 125 | logging.info("model class: " + config['model_type']) 126 | ModelClass = eval(config['model_type']) 127 | model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 128 | model.load_state_dict(torch.load(args.model_file, 129 | map_location=lambda storage, location: storage)) 130 | 131 | we = utils.load_word_embeddings(args.word_embedding_file, use_torch=True) 132 | 133 | questions = pkl.load(open(args.question_file, 'r')) 134 | if not args.no_subsample: 135 | # ridxs = list(range(len(questions))) 136 | # random.shuffle(ridxs) 137 | # ridxs = [ridxs[i] for i in range(1000)] 138 | # questions = [questions[i] for i in ridxs] 139 | questions = questions[:10000] 140 | logging.info("#questions={}".format(len(questions))) 141 | 142 | rtype = indices.REL2IDX[indices.REL_CONTEXT] if args.context_rel else indices.REL2IDX[indices.REL_COREF] 143 | rtype = torch.LongTensor([rtype]).to(args.device) 144 
| 145 | widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()] 146 | bar = progressbar.ProgressBar(widgets=widgets, maxval=len(questions)).start() 147 | 148 | logging.info("batch_size = {}".format(args.batch_size)) 149 | batch_size = args.batch_size 150 | n_batches = len(questions) // batch_size + 1 if len(questions) % batch_size != 0 else len(questions) // batch_size 151 | logging.info("#questions = {}".format(len(questions))) 152 | logging.info("n_batches = {}".format(n_batches)) 153 | i_q = 0 154 | 155 | # ((we, ev, mix), (mcns, mcne),(incorrect, correct)) 156 | WE_IDX, EV_IDX, MIX_IDX = 0, 1, 2 157 | MCNS_IDX, MCNE_IDX = 0, 1 158 | INCORRECT_IDX, CORRECT_IDX = 0, 1 159 | results = torch.zeros((3, 2, 2), dtype=torch.int64) 160 | for i_batch in range(n_batches): 161 | batch_questions = questions[i_batch*batch_size: (i_batch+1)*batch_size] 162 | e2idx, ev_embeddings, w_embeddings = build_embeddings(model, batch_questions, config, pred2idx, argw2idx, rtype, we) 163 | for q in batch_questions: 164 | # when calculating the accuracy, we only consider the questions in the middle 165 | # so that MCNS and MCNE can have a fair comparison 166 | n_q = len(q.ans_idxs) - 1 167 | 168 | we_preds, ev_preds = intrinsic.predict_mcns(model, q, e2idx, ev_embeddings, w_embeddings, rtype, args.device, args.inference_model) 169 | for i in range(n_q): 170 | for emb_idx, preds in [(WE_IDX, we_preds), (EV_IDX, ev_preds)]: 171 | if preds[i] == q.ans_idxs[i]: 172 | results[emb_idx][MCNS_IDX][CORRECT_IDX] += 1 173 | else: 174 | results[emb_idx][MCNS_IDX][INCORRECT_IDX] += 1 175 | 176 | we_preds, ev_preds = intrinsic.predict_mcne(model, q, e2idx, ev_embeddings, w_embeddings, rtype, args.device, args.inference_model) 177 | for i in range(n_q): 178 | for emb_idx, preds in [(WE_IDX, we_preds), (EV_IDX, ev_preds)]: 179 | if preds[i] == q.ans_idxs[i]: 180 | results[emb_idx][MCNE_IDX][CORRECT_IDX] += 1 181 | else: 182 | results[emb_idx][MCNE_IDX][INCORRECT_IDX] += 1 183 | i_q += 1 184 | bar.update(i_q) 185 | bar.finish() 186 | 187 | results = results.type(torch.float32) 188 | print ("MCNS:") 189 | print ("\tWE:") 190 | print("\t\taccuracy={}".format(results[WE_IDX][MCNS_IDX][CORRECT_IDX]/(results[WE_IDX][MCNS_IDX][CORRECT_IDX]+results[WE_IDX][MCNS_IDX][INCORRECT_IDX]))) 191 | print ("\tEV:") 192 | print("\t\taccuracy={}".format(results[EV_IDX][MCNS_IDX][CORRECT_IDX]/(results[EV_IDX][MCNS_IDX][CORRECT_IDX]+results[EV_IDX][MCNS_IDX][INCORRECT_IDX]))) 193 | # print ("\tMIX:") 194 | # print("\t\taccuracy={}".format(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]+results[MIX_IDX][MCNS_IDX][INCORRECT_IDX]))) 195 | 196 | print ("MCNE:") 197 | print ("\tWE:") 198 | print("\t\taccuracy={}".format(results[WE_IDX][MCNE_IDX][CORRECT_IDX]/(results[WE_IDX][MCNE_IDX][CORRECT_IDX]+results[WE_IDX][MCNE_IDX][INCORRECT_IDX]))) 199 | print ("\tEV:") 200 | print("\t\taccuracy={}".format(results[EV_IDX][MCNE_IDX][CORRECT_IDX]/(results[EV_IDX][MCNE_IDX][CORRECT_IDX]+results[EV_IDX][MCNE_IDX][INCORRECT_IDX]))) 201 | # print ("\tMIX:") 202 | # print("\t\taccuracy={}".format(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]+results[MIX_IDX][MCNE_IDX][INCORRECT_IDX]))) 203 | 204 | logging.info("MCNS:") 205 | logging.info("\tWE:") 206 | logging.info("\t\taccuracy={}".format(results[WE_IDX][MCNS_IDX][CORRECT_IDX]/(results[WE_IDX][MCNS_IDX][CORRECT_IDX]+results[WE_IDX][MCNS_IDX][INCORRECT_IDX]))) 207 | logging.info("\tEV:") 208 | 
logging.info("\t\taccuracy={}".format(results[EV_IDX][MCNS_IDX][CORRECT_IDX]/(results[EV_IDX][MCNS_IDX][CORRECT_IDX]+results[EV_IDX][MCNS_IDX][INCORRECT_IDX]))) 209 | # logging.info("\tMIX:") 210 | # logging.info("\t\taccuracy={}".format(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]+results[MIX_IDX][MCNS_IDX][INCORRECT_IDX]))) 211 | 212 | logging.info("MCNE:") 213 | logging.info("\tWE:") 214 | logging.info("\t\taccuracy={}".format(results[WE_IDX][MCNE_IDX][CORRECT_IDX]/(results[WE_IDX][MCNE_IDX][CORRECT_IDX]+results[WE_IDX][MCNE_IDX][INCORRECT_IDX]))) 215 | logging.info("\tEV:") 216 | logging.info("\t\taccuracy={}".format(results[EV_IDX][MCNE_IDX][CORRECT_IDX]/(results[EV_IDX][MCNE_IDX][CORRECT_IDX]+results[EV_IDX][MCNE_IDX][INCORRECT_IDX]))) 217 | # logging.info("\tMIX:") 218 | # logging.info("\t\taccuracy={}".format(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]+results[MIX_IDX][MCNE_IDX][INCORRECT_IDX]))) 219 | 220 | 221 | if __name__ == "__main__": 222 | args = utils.bin_config(get_arguments) 223 | if torch.cuda.is_available(): 224 | args.device = torch.device('cuda') if args.gpu_id is None \ 225 | else torch.device('cuda:{}'.format(args.gpu_id)) 226 | else: 227 | args.device = torch.device('cpu') 228 | main() 229 | -------------------------------------------------------------------------------- /dnee/evals/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A collection of data structures that are particularly 3 | useful for developing and improving a classifier 4 | """ 5 | 6 | import numpy 7 | import json 8 | 9 | 10 | class ConfusionMatrix(object): 11 | """Confusion matrix for evaluating a classifier 12 | 13 | For more information on confusion matrix en.wikipedia.org/wiki/Confusion_matrix 14 | """ 15 | 16 | INIT_NUM_CLASSES = 100 17 | NEGATIVE_CLASS = '__NEGATIVE_CLASS__' 18 | def __init__(self, alphabet=None): 19 | if alphabet is None: 20 | self.alphabet = Alphabet() 21 | self.matrix = numpy.zeros((self.INIT_NUM_CLASSES, self.INIT_NUM_CLASSES)) 22 | else: 23 | self.alphabet = alphabet 24 | num_classes = alphabet.size() 25 | self.matrix = numpy.zeros((num_classes,num_classes)) 26 | 27 | def __iadd__(self, other): 28 | self.matrix += other.matrix 29 | return self 30 | 31 | def add(self, prediction, true_answer): 32 | """Add one data point to the confusion matrix 33 | 34 | If prediction is an integer, we assume that it's a legitimate index 35 | on the confusion matrix. 36 | 37 | If prediction is a string, then we will do the look up to 38 | map to the integer index for the confusion matrix. 39 | 40 | """ 41 | if type(prediction) == int and type(true_answer) == int: 42 | self.matrix[prediction, true_answer] += 1 43 | else: 44 | self.alphabet.add(prediction) 45 | self.alphabet.add(true_answer) 46 | prediction_index = self.alphabet.get_index(prediction) 47 | true_answer_index = self.alphabet.get_index(true_answer) 48 | self.matrix[prediction_index, true_answer_index] += 1 49 | #XXX: this will fail if the prediction_index is greater than 50 | # the initial capacity. I should grow the matrix if this crashes 51 | 52 | 53 | def add_list(self, predictions, true_answers): 54 | """Add a list of data point to the confusion matrix 55 | 56 | A list can be a list of integers. 57 | If prediction is an integer, we assume that it's a legitimate index 58 | on the confusion matrix. 59 | 60 | A list can be a list of strings. 
61 | If prediction is a string, then we will do the look up to
62 | map to the integer index for the confusion matrix.
63 | 
64 | """
65 | for p, t in zip(predictions, true_answers):
66 | self.add(p, t)
67 | 
68 | def get_prf_for_i(self, i):
69 | """Compute precision, recall, and f1 score for a given index."""
70 | 
71 | if sum(self.matrix[i,:]) == 0:
72 | precision = 1.0
73 | else:
74 | precision = self.matrix[i,i] / sum(self.matrix[i,:])
75 | if sum(self.matrix[:,i]) == 0:
76 | recall = 1.0
77 | else:
78 | recall = self.matrix[i,i] / sum(self.matrix[:,i])
79 | if precision + recall != 0.0:
80 | f1 = 2.0 * precision * recall / (precision + recall)
81 | else:
82 | f1 = 0.0
83 | return (precision, recall, f1)
84 | 
85 | def get_prf_for_all(self):
86 | """Compute precision, recall, and f1 score for all indexes."""
87 | 
88 | precision = numpy.zeros(self.alphabet.size())
89 | recall = numpy.zeros(self.alphabet.size())
90 | f1 = numpy.zeros(self.alphabet.size())
91 | 
92 | # compute precision, recall, and f1
93 | for i in range(self.alphabet.size()):
94 | precision[i], recall[i], f1[i] = self.get_prf_for_i(i)
95 | 
96 | return (precision, recall, f1)
97 | 
98 | def get_prf(self, class_name):
99 | """Compute precision, recall, and f1 score for a given class. """
100 | i = self.alphabet.get_index(class_name)
101 | return self.get_prf_for_i(i)
102 | 
103 | def compute_micro_average_f1(self):
104 | total_correct = 0.0
105 | for i in range(self.alphabet.size()):
106 | total_correct += self.matrix[i,i]
107 | negative_index = self.alphabet.get_index(self.NEGATIVE_CLASS)
108 | total_predicted = numpy.sum([x for i, x in enumerate(self.matrix.sum(1))\
109 | if negative_index == -1 or i != negative_index])
110 | total_gold = numpy.sum([x for i, x in enumerate(self.matrix.sum(0)) \
111 | if negative_index == -1 or i != negative_index])
112 | 
113 | if total_predicted == 0:
114 | precision = 1.0
115 | else:
116 | precision = total_correct / total_predicted
117 | if total_gold == 0:
118 | recall = 1.0
119 | else:
120 | recall = total_correct / total_gold
121 | if precision + recall != 0.0:
122 | f1_score = 2.0 * (precision * recall) / (precision + recall)
123 | else:
124 | f1_score = 0.0
125 | return (round(precision, 4), round(recall, 4), round(f1_score,4))
126 | 
127 | def compute_average_f1(self):
128 | precision, recall, f1 = self.get_prf_for_all()
129 | return numpy.mean(f1)
130 | 
131 | def compute_average_prf(self):
132 | precision, recall, f1 = self.get_prf_for_all()
133 | return (round(numpy.mean(precision), 4),
134 | round(numpy.mean(recall), 4),
135 | round(numpy.mean(f1), 4))
136 | 
137 | def print_matrix(self):
138 | num_classes = self.alphabet.size()
139 | #header for the confusion matrix
140 | header = [' '] + [self.alphabet.get_label(i) for i in range(num_classes)]
141 | rows = []
142 | #putting labels in the first column of the matrix
143 | for i in range(num_classes):
144 | row = [self.alphabet.get_label(i)] + [str(self.matrix[i,j]) for j in range(num_classes)]
145 | rows.append(row)
146 | print("row = predicted, column = truth")
147 | print(matrix_to_string(rows, header))
148 | 
149 | def print_summary(self):
150 | 
151 | precision = numpy.zeros(self.alphabet.size())
152 | recall = numpy.zeros(self.alphabet.size())
153 | f1 = numpy.zeros(self.alphabet.size())
154 | 
155 | max_len = 0
156 | for i in range(self.alphabet.size()):
157 | label = self.alphabet.get_label(i)
158 | if label != self.NEGATIVE_CLASS and len(label) > max_len:
159 | max_len = len(label)
160 | 
161 | lines = []
162 | correct = 0.0
163 |
163 |         # compute precision, recall, and f1
164 |         for i in range(self.alphabet.size()):
165 |             precision[i], recall[i], f1[i] = self.get_prf_for_i(i)
166 |             correct += self.matrix[i,i]
167 |             label = self.alphabet.get_label(i)
168 |             if label != self.NEGATIVE_CLASS:
169 |                 space = ' ' * (max_len - len(label) + 1)
170 |                 lines.append('%s%s precision %1.4f\trecall %1.4f\tF1 %1.4f' %\
171 |                     (label, space, precision[i], recall[i], f1[i]))
172 |         precision, recall, f1 = self.compute_micro_average_f1()
173 |         space = ' ' * (max_len - 14 + 1)
174 |         lines.append('*Micro-Average%s precision %1.4f\trecall %1.4f\tF1 %1.4f' %\
175 |             (space, precision, recall, f1))
176 |         lines.sort()
177 |         print('\n'.join(lines))
178 | 
179 |     def print_out(self):
180 |         """Print the confusion matrix along with the precision/recall/F1 summary"""
181 |         self.print_matrix()
182 |         self.print_summary()
183 | 
184 | 
185 | def matrix_to_string(matrix, header=None):
186 |     """
187 |     Return a pretty, aligned string representation of an nxm matrix.
188 | 
189 |     This representation can be used to print any tabular data, such as
190 |     database results. It works by scanning the lengths of each element
191 |     in each column, and determining the format string dynamically.
192 | 
193 |     The implementation is adapted from
194 |     mybravenewworld.wordpress.com/2010/09/19/print-tabular-data-nicely-using-python/
195 | 
196 |     Args:
197 |         matrix - Matrix representation (list with n rows of m elements).
198 |         header - Optional tuple or list with header elements to be displayed.
199 | 
200 |     Returns:
201 |         nicely formatted matrix string
202 |     """
203 | 
204 |     if isinstance(header, list):
205 |         header = tuple(header)
206 |     lengths = []
207 |     if header:
208 |         lengths = [len(column) for column in header]
209 | 
210 |     # finding the max length of each column
211 |     for row in matrix:
212 |         for i, column in enumerate(row):
213 |             # enumerate, not row.index(), so duplicate values map to the right column
214 |             column = str(column)
215 |             column_length = len(column)
216 |             try:
217 |                 max_length = lengths[i]
218 |                 if column_length > max_length:
219 |                     lengths[i] = column_length
220 |             except IndexError:
221 |                 lengths.append(column_length)
222 | 
223 |     # use the lengths to derive a formatting string
224 |     lengths = tuple(lengths)
225 |     format_string = ""
226 |     for length in lengths:
227 |         format_string += "%-" + str(length) + "s "
228 |     format_string += "\n"
229 | 
230 |     # applying formatting string to get matrix string
231 |     matrix_str = ""
232 |     if header:
233 |         matrix_str += format_string % header
234 |     for row in matrix:
235 |         matrix_str += format_string % tuple(row)
236 | 
237 |     return matrix_str
238 | 
239 | 
240 | class Alphabet(object):
241 |     """Two-way map between labels and label indices
242 | 
243 |     It is essentially a code book for labels or features.
244 |     This class makes it convenient to use numpy.array
245 |     instead of a dictionary because it allows us to use indices instead of
246 |     label strings. The implementation of classifiers uses the label index space
247 |     instead of the label string space.
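    A minimal usage sketch (the label here is illustrative):

        alphabet = Alphabet()
        alphabet.add('Comparison.Contrast')
        alphabet.get_index('Comparison.Contrast')   # -> 0
        alphabet.get_label(0)                       # -> 'Comparison.Contrast'

    While `growing` is True, get_index() adds unseen labels on the fly;
    after setting growing = False, unseen labels return -1.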
248 | """ 249 | def __init__(self): 250 | self._index_to_label = {} 251 | self._label_to_index = {} 252 | self.num_labels = 0 253 | self.growing = True 254 | 255 | def __len__(self): 256 | return self.size() 257 | 258 | def __eq__(self, other): 259 | return self._index_to_label == other._index_to_label and \ 260 | self._label_to_index == other._label_to_index and \ 261 | self.num_labels == other.num_labels 262 | 263 | def size(self): 264 | return self.num_labels 265 | 266 | def has_label(self, label): 267 | return label in self._label_to_index 268 | 269 | def get_label(self, index): 270 | """Get label from index""" 271 | if index >= self.num_labels: 272 | raise KeyError("There are %d labels but the index is %d" % (self.num_labels, index)) 273 | return self._index_to_label[index] 274 | 275 | def get_index(self, label): 276 | """Get index from label""" 277 | if not self.has_label(label): 278 | if self.growing: 279 | self.add(label) 280 | else: 281 | return -1 282 | return self._label_to_index[label] 283 | 284 | def add(self, label): 285 | """Add an index for the label if it's a new label""" 286 | if label not in self._label_to_index: 287 | if not self.growing: 288 | raise ValueError( 289 | 'Alphabet is not set to grow i.e. accepting new labels') 290 | self._label_to_index[label] = self.num_labels 291 | self._index_to_label[self.num_labels] = label 292 | self.num_labels += 1 293 | 294 | def json_dumps(self): 295 | return json.dumps(self.to_dict()) 296 | 297 | @classmethod 298 | def json_loads(cls, json_string): 299 | json_dict = json.loads(json_string) 300 | return Alphabet.from_dict(json_dict) 301 | 302 | def to_dict(self): 303 | return { 304 | '_label_to_index': self._label_to_index 305 | } 306 | 307 | @classmethod 308 | def from_dict(cls, alphabet_dictionary): 309 | """Create an Alphabet from dictionary 310 | 311 | alphabet_dictionary is a dictionary with only one field 312 | _label_to_index which is a map from label to index 313 | and should be created with to_dict method above. 
314 | """ 315 | alphabet = cls() 316 | alphabet._label_to_index = alphabet_dictionary['_label_to_index'] 317 | alphabet._index_to_label = {} 318 | for label, index in alphabet._label_to_index.items(): 319 | alphabet._index_to_label[index] = label 320 | # making sure that the dimension agrees 321 | assert(len(alphabet._index_to_label) == len(alphabet._label_to_index)) 322 | alphabet.num_labels = len(alphabet._index_to_label) 323 | return alphabet 324 | 325 | 326 | -------------------------------------------------------------------------------- /bin/evaluations/eval_disc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | import cPickle as pkl 9 | 10 | import numpy as np 11 | import torch 12 | import progressbar 13 | from sklearn.metrics import f1_score, accuracy_score 14 | 15 | from dnee import utils 16 | from dnee.evals import intrinsic 17 | from dnee.events import indices 18 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder 19 | from dnee.models import AttnEventTransE, AttnEventTransR 20 | 21 | 22 | def get_arguments(argv): 23 | parser = argparse.ArgumentParser(description='intrinsic disc evaluation') 24 | parser.add_argument('model_file', metavar='MODEL_FILE', 25 | help='model file.') 26 | parser.add_argument('encoder_file', metavar='ENCODER_FILE', 27 | help='encoder file.') 28 | parser.add_argument('question_file', metavar='QUESTION_FILE', 29 | help='questions.') 30 | parser.add_argument('training_config', metavar='TRAINING_CONFIG', 31 | help='config for training') 32 | parser.add_argument('relation_config', metavar='RELATION_CONFIG', 33 | help='relation classes') 34 | 35 | parser.add_argument('-b', '--batch_size', type=int, default=500, 36 | help='batch size for evaluation') 37 | parser.add_argument('-g', '--gpu_id', type=int, default=None, 38 | help='gpu id') 39 | parser.add_argument('-v', '--verbose', action='store_true', default=False, 40 | help='show info messages') 41 | parser.add_argument('-d', '--debug', action='store_true', default=False, 42 | help='show debug messages') 43 | args = parser.parse_args(argv) 44 | return args 45 | 46 | 47 | def build_examples(questions, config, pred2idx, argw2idx): 48 | Xs, ys = [], [] 49 | for q in questions: 50 | x1 = utils.get_raw_event_repr(q.rel.e1, config, pred2idx, argw2idx) 51 | x2 = utils.get_raw_event_repr(q.rel.e2, config, pred2idx, argw2idx) 52 | x = torch.cat((x1, x2), 0) 53 | Xs.append(x) 54 | y = q.rel.rtype_idx 55 | ys.append(y) 56 | Xs = torch.stack(Xs, dim=0).to(args.device) 57 | ys = torch.LongTensor(ys).to(args.device) 58 | return Xs, ys 59 | 60 | 61 | def _eval_by_events(questions, model, config, pred2idx, argw2idx, relation_config): 62 | global tmp_y, tmp_y_pred 63 | X, y = build_examples(questions, config, pred2idx, argw2idx) 64 | ev_dim = X.shape[1] // 2 65 | e1 = X[:, :ev_dim] 66 | e2 = X[:, ev_dim:] 67 | embs1 = model.embed_event(e1) 68 | embs2 = model.embed_event(e2) 69 | 70 | # predict 71 | n_rels = relation_config['disc_end'] - relation_config['disc_begin'] 72 | scores = torch.zeros((n_rels, X.shape[0]), dtype=torch.float32) 73 | for r in range(n_rels): 74 | ridx = torch.LongTensor([r]*e1.shape[0]).to(args.device) 75 | remb = model.rel_embeddings(ridx) 76 | _score = model._calc(model._transfer(embs1, ridx), 77 | model._transfer(embs2, ridx), 78 | remb) 79 | score = torch.sum(_score, 1) 80 | if model.norm > 1: 81 | score = 
torch.pow(score, 1.0 / model.norm) 82 | scores[r] = score 83 | _, y_predict = torch.min(scores, 0) 84 | return y, y_predict 85 | 86 | 87 | def eval_by_events(questions, model, config, pred2idx, argw2idx, relation_config): 88 | return _eval(questions, model, config, 89 | pred2idx, argw2idx, _eval_by_events, relation_config) 90 | 91 | 92 | def _eval(questions, model, config, pred2idx, argw2idx, eval_func, relation_config): 93 | n_batches = len(questions) // args.batch_size 94 | if len(questions) % args.batch_size > 0: 95 | n_batches += 1 96 | 97 | n_correct, n_incorrect = 0, 0 98 | widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()] 99 | bar = progressbar.ProgressBar(widgets=widgets, maxval=n_batches).start() 100 | ys, y_preds = [], [] 101 | for i_batch in range(n_batches): 102 | subquestions = questions[i_batch*args.batch_size: (i_batch+1)*args.batch_size] 103 | logging.debug("#subquestions = {}".format(len(subquestions))) 104 | y, y_pred = eval_func(subquestions, model, config, pred2idx, argw2idx, relation_config) 105 | ys.append(y) 106 | y_preds.append(y_pred) 107 | bar.update(i_batch+1) 108 | ys = torch.cat(ys, dim=0).detach().cpu().tolist() 109 | y_preds = torch.cat(y_preds, dim=0).detach().cpu().tolist() 110 | bar.finish() 111 | return ys, y_preds 112 | 113 | 114 | def _eval_by_rel(questions, model, config, pred2idx, argw2idx, relation_config): 115 | embs1, ridxs, ys = [], [], [] 116 | all_cins = [] 117 | for q in questions: 118 | x1 = utils.get_raw_event_repr(q.rel.e1, config, pred2idx, argw2idx) 119 | embs1.append(x1) 120 | ridxs.append(q.rel.rtype_idx) 121 | y = q.ans_idx 122 | ys.append(y) 123 | 124 | cins = [] 125 | for c in q.choices: 126 | cin = utils.get_raw_event_repr(c, config, pred2idx, argw2idx) 127 | cins.append(cin.tolist()) 128 | all_cins.append(cins) 129 | 130 | embs1 = torch.stack(embs1, dim=0).to(args.device) 131 | ridxs = torch.LongTensor(ridxs).to(args.device) 132 | ys = torch.LongTensor(ys).to(args.device) 133 | embs1 = model._transfer(model.embed_event(embs1), ridxs) 134 | rembs = model.rel_embeddings(ridxs) 135 | 136 | all_cins = torch.LongTensor(all_cins).to(args.device) 137 | n_choices = len(questions[0].choices) 138 | scores = torch.zeros((n_choices, len(questions)), dtype=torch.float32) 139 | for i in range(n_choices): 140 | cembs = model.embed_event(all_cins[:, i, :]) 141 | _score = model._calc(embs1, 142 | model._transfer(cembs, ridxs), 143 | rembs) 144 | score = torch.sum(_score, 1) 145 | if model.norm > 1: 146 | score = torch.pow(score, 1.0 / model.norm) 147 | scores[i] = score 148 | _, y_predict = torch.min(scores, 0) 149 | return ys, y_predict 150 | 151 | 152 | def _eval_by_next_rel(questions, model, config, pred2idx, argw2idx, relation_config): 153 | embs1, ridxs, ys = [], [], [] 154 | all_cins = [] 155 | for q in questions: 156 | x1 = utils.get_raw_event_repr(q.rel.e1, config, pred2idx, argw2idx) 157 | embs1.append(x1) 158 | ridxs.append(relation_config["rel2idx"][indices.REL_CONTEXT]) 159 | y = q.ans_idx 160 | ys.append(y) 161 | 162 | cins = [] 163 | for c in q.choices: 164 | cin = utils.get_raw_event_repr(c, config, pred2idx, argw2idx) 165 | cins.append(cin.tolist()) 166 | all_cins.append(cins) 167 | 168 | embs1 = torch.stack(embs1, dim=0).to(args.device) 169 | ridxs = torch.LongTensor(ridxs).to(args.device) 170 | ys = torch.LongTensor(ys).to(args.device) 171 | embs1 = model._transfer(model.embed_event(embs1), ridxs) 172 | rembs = model.rel_embeddings(ridxs) 173 | 174 | all_cins = torch.LongTensor(all_cins).to(args.device) 175 
| n_choices = len(questions[0].choices) 176 | scores = torch.zeros((n_choices, len(questions)), dtype=torch.float32) 177 | for i in range(n_choices): 178 | cembs = model.embed_event(all_cins[:, i, :]) 179 | _score = model._calc(embs1, 180 | model._transfer(cembs, ridxs), 181 | rembs) 182 | score = torch.sum(_score, 1) 183 | if model.norm > 1: 184 | score = torch.pow(score, 1.0 / model.norm) 185 | scores[i] = score 186 | 187 | _, y_predict = torch.min(scores, 0) 188 | return ys, y_predict 189 | 190 | 191 | def _eval_by_random_rel(questions, model, config, pred2idx, argw2idx, relation_config): 192 | embs1, ridxs, ys = [], [], [] 193 | all_cins = [] 194 | for q in questions: 195 | x1 = utils.get_raw_event_repr(q.rel.e1, config, pred2idx, argw2idx) 196 | embs1.append(x1) 197 | ridxs.append(q.rel.rtype_idx) 198 | y = q.ans_idx 199 | ys.append(y) 200 | 201 | cins = [] 202 | for c in q.choices: 203 | cin = utils.get_raw_event_repr(c, config, pred2idx, argw2idx) 204 | cins.append(cin.tolist()) 205 | all_cins.append(cins) 206 | 207 | embs1 = torch.stack(embs1, dim=0).to(args.device) 208 | ridxs = torch.LongTensor(ridxs).to(args.device) 209 | ys = torch.LongTensor(ys).to(args.device) 210 | embs1 = model._transfer(model.embed_event(embs1), ridxs) 211 | rridxs = torch.randint(relation_config['disc_begin'], relation_config['disc_end'], ridxs.shape, dtype=torch.int64).to(args.device) 212 | rrembs = model.rel_embeddings(rridxs) 213 | 214 | all_cins = torch.LongTensor(all_cins).to(args.device) 215 | n_choices = len(questions[0].choices) 216 | rscores = torch.zeros((n_choices, len(questions)), dtype=torch.float32) 217 | for i in range(n_choices): 218 | cembs = model.embed_event(all_cins[:, i, :]) 219 | _rscore = model._calc(embs1, 220 | model._transfer(cembs, rridxs), 221 | rrembs) 222 | rscore = torch.sum(_rscore, 1) 223 | if model.norm > 1: 224 | rscore = torch.pow(rscore, 1.0 / model.norm) 225 | rscores[i] = rscore 226 | 227 | _, ry_predict = torch.min(rscores, 0) 228 | return ys, ry_predict 229 | 230 | 231 | def eval_by_rel(questions, model, config, pred2idx, argw2idx, relation_config): 232 | return _eval(questions, model, config, 233 | pred2idx, argw2idx, _eval_by_rel, relation_config) 234 | 235 | 236 | def eval_by_random_rel(questions, model, config, pred2idx, argw2idx, relation_config): 237 | return _eval(questions, model, config, 238 | pred2idx, argw2idx, _eval_by_random_rel, relation_config) 239 | 240 | 241 | def eval_by_next_rel(questions, model, config, pred2idx, argw2idx, relation_config): 242 | return _eval(questions, model, config, 243 | pred2idx, argw2idx, _eval_by_next_rel, relation_config) 244 | 245 | 246 | def acc(y, y_pred): 247 | a = np.array(y, dtype=np.int64) 248 | b = np.array(y_pred, dtype=np.int64) 249 | n_correct = (a == b).sum() 250 | return float(n_correct) / a.shape[0] 251 | 252 | 253 | def main(): 254 | config = json.load(open(args.training_config, 'r')) 255 | relation_config = json.load(open(args.relation_config, 'r')) 256 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices']) 257 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices']) 258 | n_preds = len(pred2idx) 259 | argw_vocabs = argw2idx.keys() 260 | argw_encoder = create_argw_encoder(config, args.device) 261 | if args.encoder_file: 262 | argw_encoder.load(args.encoder_file) 263 | 264 | logging.info("model class: " + config['model_type']) 265 | ModelClass = eval(config['model_type']) 266 | model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 267 | 
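# map_location keeps the checkpoint tensors on CPU regardless of where they were saved;
# load_state_dict then copies them into the model already placed on args.device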
model.load_state_dict(torch.load(args.model_file, 268 | map_location=lambda storage, location: storage)) 269 | 270 | questions = pkl.load(open(args.question_file, 'r')) 271 | logging.info("#questions={}".format(len(questions))) 272 | 273 | logging.info('predict relation') 274 | y, y_pred = eval_by_events(questions, model, config, pred2idx, argw2idx, relation_config) 275 | logging.info("predict relation, accuracy={}".format(accuracy_score(y, y_pred))) 276 | logging.info("predict relation, accuracy={}".format(acc(y, y_pred))) 277 | 278 | logging.info('predict next event') 279 | y, y_pred = eval_by_rel(questions, model, config, pred2idx, argw2idx, relation_config) 280 | logging.info("predict next event, accuracy={}".format(accuracy_score(y, y_pred))) 281 | logging.info("predict next event, accuracy={}".format(acc(y, y_pred))) 282 | 283 | logging.info('predict next event by random rel') 284 | y, y_pred = eval_by_random_rel(questions, model, config, pred2idx, argw2idx, relation_config) 285 | logging.info("by random rel, accuracy={}".format(accuracy_score(y, y_pred))) 286 | logging.info("by random rel, accuracy={}".format(acc(y, y_pred))) 287 | 288 | logging.info('predict next event by next rel') 289 | y, y_pred = eval_by_next_rel(questions, model, config, pred2idx, argw2idx, relation_config) 290 | logging.info("by next rel, accuracy={}".format(accuracy_score(y, y_pred))) 291 | logging.info("by next rel, accuracy={}".format(acc(y, y_pred))) 292 | 293 | 294 | if __name__ == "__main__": 295 | args = utils.bin_config(get_arguments) 296 | if args.gpu_id is not None: 297 | args.device = torch.device('cuda:{}'.format(args.gpu_id)) 298 | else: 299 | args.device = torch.device('cpu') 300 | main() 301 | -------------------------------------------------------------------------------- /bin/evaluations/eval_disc_binary.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | import time 7 | import re 8 | import pickle as pkl 9 | from copy import deepcopy 10 | import random 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import progressbar 16 | from sklearn import metrics 17 | from allennlp.modules.elmo import Elmo, batch_to_ids 18 | from nltk import word_tokenize 19 | 20 | from dnee import utils 21 | from dnee.evals import intrinsic 22 | from dnee.events import indices 23 | from dnee.models import EventTransR, EventTransE, ArgWordEncoder, create_argw_encoder 24 | 25 | 26 | def get_arguments(argv): 27 | parser = argparse.ArgumentParser(description='relation specific binary classification') 28 | parser.add_argument('model_file', metavar='MODEL_FILE', 29 | help='model file.') 30 | parser.add_argument('encoder_file', metavar='ENCODER_FILE', 31 | help='encoder file.') 32 | parser.add_argument('dev_question_file', metavar='DEV_QUESTION_FILE', 33 | help='dev questions.') 34 | parser.add_argument('test_question_file', metavar='TEST_QUESTION_FILE', 35 | help='test questions.') 36 | parser.add_argument('training_config', metavar='TRAINING_CONFIG', 37 | help='config for training') 38 | parser.add_argument('relation_config', metavar='RELATION_CONFIG', 39 | help='relation classes') 40 | 41 | parser.add_argument('-w', '--elmo_weight_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5", 42 | help='ELMo weight file') 43 | parser.add_argument('-p', '--elmo_option_file', default="data/elmo_2x2048_256_2048cnn_1xhighway_options.json", 44 | help='ELMo option 
file')
45 |     parser.add_argument('-m', '--use_elmo', action='store_true', default=False,
46 |                         help='use ELMo instead of TransE/TransR')
47 |     parser.add_argument('-r', '--n_rounds', type=int, default=10,
48 |                         help='number of rounds to get the average')
49 |     parser.add_argument('-s', '--step_size', type=float, default=0.001,
50 |                         help='grid search step size for thresholds')
51 |     parser.add_argument('-g', '--gpu_id', type=int, default=None,
52 |                         help='gpu id')
53 |     parser.add_argument('-v', '--verbose', action='store_true', default=False,
54 |                         help='show info messages')
55 |     parser.add_argument('-d', '--debug', action='store_true', default=False,
56 |                         help='show debug messages')
57 |     args = parser.parse_args(argv)
58 |     return args
59 | 
60 | 
61 | def sample_questions(question_file, n_cat_questions=500):
62 |     questions = pkl.load(open(question_file, 'rb'))
63 | 
64 |     all_epairs = {}
65 |     all_events = {}
66 |     cat_questions = {}
67 |     for q in questions:
68 |         if q.rel.e1.__repr__() not in all_events:
69 |             all_events[q.rel.e1.__repr__()] = q.rel.e1
70 |         if q.rel.e2.__repr__() not in all_events:
71 |             all_events[q.rel.e2.__repr__()] = q.rel.e2
72 | 
73 | 
74 |         if q.rel.rtype_idx not in all_epairs:
75 |             all_epairs[q.rel.rtype_idx] = {}
76 |         key = str((q.rel.e1, q.rel.e2))
77 |         all_epairs[q.rel.rtype_idx][key] = 1
78 |         key = str((q.rel.e2, q.rel.e1))
79 |         all_epairs[q.rel.rtype_idx][key] = 1
80 | 
81 |         if q.rel.rtype_idx not in cat_questions:
82 |             cat_questions[q.rel.rtype_idx] = []
83 |         cat_questions[q.rel.rtype_idx].append((1, q))
84 | 
85 |     for i in range(len(cat_questions)):
86 |         cat_questions[i] = cat_questions[i][:n_cat_questions]
87 | 
88 |     # negative examples: pair each positive with a random, unrelated event
89 |     evs = list(all_events.values())
90 |     for i_cat in range(len(cat_questions)):
91 |         nqs = []
92 |         for label, q in cat_questions[i_cat]:
93 |             nq = deepcopy(q)
94 |             nq.choices = None
95 |             nq.ans_idx = None
96 |             while True:
97 |                 tmpe = evs[np.random.randint(0, len(evs))]
98 |                 key = str((nq.rel.e1, tmpe))
99 |                 if key not in all_epairs[q.rel.rtype_idx]:
100 |                     nq.rel.e2 = tmpe
101 |                     break
102 |             nqs.append((0, nq))
103 |         cat_questions[i_cat] += nqs
104 |     return cat_questions
105 | 
106 | 
107 | def build_ev_embeddings(questions, config, pred2idx, argw2idx, model):
108 |     with torch.no_grad():
109 |         idx = 0
110 |         e2idx = {}
111 |         xs = []
112 |         for i_cat in questions.keys():
113 |             for label, q in questions[i_cat]:
114 |                 for e in [q.rel.e1, q.rel.e2]:
115 |                     if e.__repr__() not in e2idx:
116 |                         e2idx[e.__repr__()] = idx
117 |                         idx += 1
118 |                         x = utils.get_raw_event_repr(e, config, pred2idx, argw2idx)
119 |                         xs.append(x)
120 |         xs = torch.stack(xs, dim=0).to(args.device)
121 |         embeddings = model.embed_event(xs)
122 |         return e2idx, embeddings
123 | 
124 | 
125 | def get_event_sentence(e):
126 |     # return e.sentence
127 |     preds = ' '.join(e.pred.split('_'))
128 |     sent = e.arg0 + ' ' + preds + ' ' + e.arg1
129 |     return sent
130 | 
131 | 
132 | def _elmo_batch(ids, elmo, batch_size=500):
133 |     with torch.no_grad():
134 |         embss = []
135 |         n_batches = ids.shape[0] // batch_size
136 |         if ids.shape[0] % batch_size > 0:
137 |             n_batches += 1
138 |         for i_batch in range(n_batches):
139 |             start = i_batch * batch_size
140 |             end = (i_batch + 1) * batch_size
141 | 
142 |             embs = elmo(ids[start:end])['elmo_representations'][0].detach()
143 |             embss.append(embs)
144 |         embss = torch.cat(embss, dim=0).sum(dim=1)
145 |         return embss
146 | 
147 | 
148 | def build_elmo(questions, elmo):
149 |     idx = 0
150 |     e2idx = {}
151 |     xs = []
152 |     for i_cat in questions.keys():
153 |         for label, q in questions[i_cat]:
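            # embed each distinct event only once: events are keyed by their repr
            # in e2idx, and score_questions later looks the vectors up through that index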
154 | for e in [q.rel.e1, q.rel.e2]: 155 | if e.__repr__() not in e2idx: 156 | e2idx[e.__repr__()] = idx 157 | idx += 1 158 | 159 | s = get_event_sentence(e) 160 | e_tokens = [w.lower() for w in word_tokenize(s)] 161 | xs.append(e_tokens) 162 | xs = batch_to_ids(xs).to(args.device) 163 | embeddings = _elmo_batch(xs, elmo) 164 | return e2idx, embeddings 165 | 166 | 167 | def score_questions(questions, model, e2idx, embeddings): 168 | with torch.no_grad(): 169 | cat_scores = {} 170 | for i_cat in questions.keys(): 171 | e1s, e2s, rs = [], [], [] 172 | for label, q in questions[i_cat]: 173 | e1 = embeddings[e2idx[q.rel.e1.__repr__()]] 174 | e2 = embeddings[e2idx[q.rel.e2.__repr__()]] 175 | e1s.append(e1) 176 | e2s.append(e2) 177 | rs.append(q.rel.rtype_idx) 178 | 179 | e1s = torch.stack(e1s, dim=0).to(args.device) 180 | e2s = torch.stack(e2s, dim=0).to(args.device) 181 | if args.use_elmo: 182 | scores = F.cosine_similarity(e1s, e2s, dim=1) 183 | else: 184 | rs = torch.LongTensor(rs).to(args.device) 185 | remb = model.rel_embeddings(rs) 186 | 187 | e1s = model._transfer(e1s, rs) 188 | e2s = model._transfer(e2s, rs) 189 | e2s = model._transfer(e2s, rs) 190 | _scores = model._calc(e1s, e2s, remb) 191 | scores = torch.sum(_scores, 1) 192 | if model.norm > 1: 193 | scores = torch.pow(scores, 1.0 / model.norm) 194 | cat_scores[i_cat] = scores 195 | return cat_scores 196 | 197 | 198 | def pred_acc(y, scores, threshold): 199 | if args.use_elmo: 200 | preds = (scores > threshold).type(torch.int64) 201 | else: 202 | preds = (scores < threshold).type(torch.int64) 203 | return (y == preds).type(torch.float32).sum() / y.shape[0] 204 | 205 | 206 | def dev_thresholds(questions, model, e2idx, embeddings, grid, threshold_min=0.0, threshold_max=100.0): 207 | best_thresholds = {} 208 | 209 | scores = score_questions(questions, model, e2idx, embeddings) 210 | for i_cat in questions.keys(): 211 | y = torch.LongTensor([label for label, q in questions[i_cat]]).to(args.device) 212 | best_acc = 0.0 213 | best_threshold = None 214 | for t in np.arange(threshold_min, threshold_max, grid): 215 | acc = pred_acc(y, scores[i_cat], t) 216 | if acc > best_acc: 217 | best_acc = acc 218 | best_threshold = t 219 | logging.debug("cat={}, dev_acc={}, t={}".format(i_cat, acc, t)) 220 | 221 | logging.info("cat={}, dev_acc={}, best_threshold={}".format(i_cat, best_acc, best_threshold)) 222 | best_thresholds[i_cat] = best_threshold 223 | return best_thresholds 224 | 225 | 226 | def main(): 227 | config = json.load(open(args.training_config, 'r')) 228 | indices.set_relation_classes(args.relation_config) 229 | 230 | if args.use_elmo: 231 | logging.info("using ELMo") 232 | elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0).to(args.device) 233 | else: 234 | pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices']) 235 | argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices']) 236 | n_preds = len(pred2idx) 237 | argw_vocabs = argw2idx.keys() 238 | argw_encoder = create_argw_encoder(config, args.device) 239 | if args.encoder_file: 240 | argw_encoder.load(args.encoder_file) 241 | 242 | logging.info("model class: " + config['model_type']) 243 | ModelClass = eval(config['model_type']) 244 | dnee_model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device) 245 | dnee_model.load_state_dict(torch.load(args.model_file, 246 | map_location=lambda storage, location: storage)) 247 | model = elmo if args.use_elmo else dnee_model 248 | results = torch.zeros((args.n_rounds, 
len(indices.REL2IDX)), dtype=torch.float32) 249 | precisions = torch.zeros((args.n_rounds, len(indices.REL2IDX)), dtype=torch.float32) 250 | recalls = torch.zeros((args.n_rounds, len(indices.REL2IDX)), dtype=torch.float32) 251 | f1s = torch.zeros((args.n_rounds, len(indices.REL2IDX)), dtype=torch.float32) 252 | for i_round in range(args.n_rounds): 253 | logging.info("ROUND {}".format(i_round)) 254 | 255 | # dev 256 | dev_questions = sample_questions(args.dev_question_file, n_cat_questions=500) 257 | if args.use_elmo: 258 | dev_e2idx, dev_ev_embeddings = build_elmo(dev_questions, elmo) 259 | else: 260 | dev_e2idx, dev_ev_embeddings = build_ev_embeddings(dev_questions, config, pred2idx, argw2idx, dnee_model) 261 | thresholds = dev_thresholds(dev_questions, model, dev_e2idx, dev_ev_embeddings, args.step_size) 262 | 263 | # test results 264 | test_questions = sample_questions(args.test_question_file, n_cat_questions=500) 265 | if args.use_elmo: 266 | test_e2idx, test_ev_embeddings = build_elmo(test_questions, elmo) 267 | else: 268 | test_e2idx, test_ev_embeddings = build_ev_embeddings(test_questions, config, pred2idx, argw2idx, dnee_model) 269 | test_scores = score_questions(test_questions, model, test_e2idx, test_ev_embeddings) 270 | for i_cat in test_questions.keys(): 271 | y = torch.LongTensor([label for label, q in test_questions[i_cat]]).to(args.device) 272 | acc = pred_acc(y, test_scores[i_cat], thresholds[i_cat]) 273 | 274 | if args.use_elmo: 275 | y_preds = (test_scores[i_cat] > thresholds[i_cat]).type(torch.int64) 276 | else: 277 | y_preds = (test_scores[i_cat] < thresholds[i_cat]).type(torch.int64) 278 | 279 | y = y.detach().cpu().numpy() 280 | y_preds = y_preds.detach().cpu().numpy() 281 | 282 | prec = metrics.precision_score(y, y_preds) 283 | rec = metrics.recall_score(y, y_preds) 284 | f1 = metrics.f1_score(y, y_preds) 285 | logging.info("i_cat={} ({}), test_acc={}".format(i_cat, indices.IDX2REL[i_cat], acc)) 286 | logging.info("i_cat={} ({}), test_prec={}".format(i_cat, indices.IDX2REL[i_cat], prec)) 287 | logging.info("i_cat={} ({}), test_rec={}".format(i_cat, indices.IDX2REL[i_cat], rec)) 288 | results[i_round][i_cat] = acc 289 | precisions[i_round][i_cat] = prec 290 | recalls[i_round][i_cat] = rec 291 | f1s[i_round][i_cat] = f1 292 | 293 | avg = torch.mean(results, dim=0) 294 | avg_precisions = torch.mean(precisions, dim=0) 295 | avg_recalls = torch.mean(recalls, dim=0) 296 | avg_f1s = torch.mean(f1s, dim=0) 297 | for i_cat in test_questions.keys(): 298 | logging.info("i_cat={} ({}), avg_test_acc={} over {} rounds".format(i_cat, indices.IDX2REL[i_cat], avg[i_cat], args.n_rounds)) 299 | logging.info("i_cat={} ({}), avg_test_prec={} over {} rounds".format(i_cat, indices.IDX2REL[i_cat], avg_precisions[i_cat], args.n_rounds)) 300 | logging.info("i_cat={} ({}), avg_test_rec={} over {} rounds".format(i_cat, indices.IDX2REL[i_cat], avg_recalls[i_cat], args.n_rounds)) 301 | logging.info("i_cat={} ({}), avg_test_f1={} over {} rounds".format(i_cat, indices.IDX2REL[i_cat], avg_f1s[i_cat], args.n_rounds)) 302 | 303 | 304 | if __name__ == "__main__": 305 | args = utils.bin_config(get_arguments) 306 | if args.gpu_id is not None: 307 | args.device = torch.device('cuda:{}'.format(args.gpu_id)) 308 | else: 309 | args.device = torch.device('cpu') 310 | main() 311 | -------------------------------------------------------------------------------- /dnee/models/skipthoughts.py: -------------------------------------------------------------------------------- 1 | """ 2 | a porting from 
https://github.com/Cadene/skip-thoughts.torch 3 | thanks to their effort. 4 | """ 5 | 6 | import os 7 | import sys 8 | import numpy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from torch.autograd import Variable 15 | from collections import OrderedDict 16 | 17 | 18 | ############################################################### 19 | # UniSkip 20 | ############################################################### 21 | class AbstractSkipThoughts(nn.Module): 22 | 23 | def __init__(self, dir_st, vocab, save=True, dropout=0, fixed_emb=True, 24 | fixed_net=True, emb_fpath=None, device=torch.device('cpu')): 25 | super(AbstractSkipThoughts, self).__init__() 26 | self.device = device 27 | self.dir_st = dir_st 28 | self.vocab = vocab 29 | self.save = save 30 | self.dropout = dropout 31 | self.fixed_emb = fixed_emb 32 | # Module 33 | self.embedding = self._load_embedding(emb_fpath) 34 | if fixed_emb: 35 | self.embedding.weight.requires_grad = False 36 | self.rnn = self._load_rnn() 37 | if fixed_net: 38 | for p in self.rnn.parameters(): 39 | p.requires_grad = False 40 | 41 | def _get_table_name(self): 42 | raise NotImplementedError 43 | 44 | def _get_skip_name(self): 45 | raise NotImplementedError 46 | 47 | def _load_dictionary(self): 48 | path_dico = os.path.join(self.dir_st, 'dictionary.txt') 49 | with open(path_dico, 'r') as handle: 50 | dico_list = handle.readlines() 51 | dico = {word.strip():idx for idx,word in enumerate(dico_list)} 52 | return dico 53 | 54 | def _load_emb_params(self): 55 | table_name = self._get_table_name() 56 | path_params = os.path.join(self.dir_st, table_name+'.npy') 57 | params = numpy.load(path_params, encoding='latin1') # to load from python2 58 | return params 59 | 60 | def _load_rnn_params(self): 61 | skip_name = self._get_skip_name() 62 | path_params = os.path.join(self.dir_st, skip_name+'.npz') 63 | params = numpy.load(path_params, encoding='latin1') # to load from python2 64 | return params 65 | 66 | def _load_embedding(self, emb_fpath): 67 | if self.save: 68 | import hashlib 69 | import pickle 70 | # http://stackoverflow.com/questions/20416468/fastest-way-to-get-a-hash-from-a-list-in-python 71 | hash_id = hashlib.sha256(pickle.dumps(self.vocab, -1)).hexdigest() 72 | path = emb_fpath if emb_fpath else 'st_embedding_'+str(hash_id)+'.pth' 73 | if self.save and os.path.exists(path): 74 | self.embedding = torch.load(path) 75 | else: 76 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab) + 1, 77 | embedding_dim=620, 78 | padding_idx=0, # -> first_dim = zeros 79 | sparse=False) 80 | dictionary = self._load_dictionary() 81 | parameters = self._load_emb_params() 82 | state_dict = self._make_emb_state_dict(dictionary, parameters) 83 | self.embedding.load_state_dict(state_dict) 84 | if self.save: 85 | torch.save(self.embedding, path) 86 | return self.embedding 87 | 88 | def _make_emb_state_dict(self, dictionary, parameters): 89 | weight = torch.zeros(len(self.vocab)+1, 620) # first dim = zeros -> +1 90 | unknown_params = parameters[dictionary['UNK']] 91 | nb_unknown = 0 92 | for id_weight, word in enumerate(self.vocab): 93 | if word in dictionary: 94 | id_params = dictionary[word] 95 | params = parameters[id_params] 96 | else: 97 | #print('Warning: word `{}` not in dictionary'.format(word)) 98 | params = unknown_params 99 | nb_unknown += 1 100 | weight[id_weight+1] = torch.from_numpy(params) 101 | state_dict = OrderedDict({'weight':weight}) 102 | if nb_unknown > 0: 103 | print('Warning: {}/{} words are not in 
dictionary, thus set UNK' 104 | .format(nb_unknown, len(dictionary))) 105 | return state_dict 106 | 107 | def _select_last(self, x, lengths): 108 | batch_size = x.size(0) 109 | seq_length = x.size(1) 110 | mask = x.data.new().resize_as_(x.data).fill_(0) 111 | for i in range(batch_size): 112 | mask[i][lengths[i]-1].fill_(1) 113 | mask = Variable(mask) 114 | x = x.mul(mask) 115 | x = x.sum(1).view(batch_size, -1) 116 | return x 117 | 118 | def _select_last_old(self, input, lengths): 119 | batch_size = input.size(0) 120 | x = [] 121 | for i in range(batch_size): 122 | x.append(input[i,lengths[i]-1].view(1, -1)) 123 | output = torch.cat(x, 0) 124 | return output 125 | 126 | def _process_lengths(self, input): 127 | max_length = input.size(1) 128 | if input.shape[0] == 1: 129 | lengths = (max_length - input.data.eq(0).sum(1)).tolist() 130 | else: 131 | lengths = (max_length - input.data.eq(0).sum(1).squeeze()).tolist() 132 | return lengths 133 | 134 | def _load_rnn(self): 135 | raise NotImplementedError 136 | 137 | def _make_rnn_state_dict(self, p): 138 | raise NotImplementedError 139 | 140 | def forward(self, input, lengths=None): 141 | raise NotImplementedError 142 | 143 | 144 | ################################################################################### 145 | # UniSkip 146 | ################################################################################### 147 | 148 | class UniSkip(AbstractSkipThoughts): 149 | 150 | def __init__(self, dir_st, vocab, save=True, dropout=0.0, fixed_emb=True, 151 | fixed_net=True, emb_fpath=None, device=torch.device('cpu')): 152 | super(UniSkip, self).__init__(dir_st, vocab, save, dropout, fixed_emb, fixed_net, emb_fpath, device) 153 | # Remove bias_ih_l0 (== zero all the time) 154 | # del self.gru._parameters['bias_hh_l0'] 155 | # del self.gru._all_weights[0][3] 156 | 157 | def _get_table_name(self): 158 | return 'utable' 159 | 160 | def _get_skip_name(self): 161 | return 'uni_skip' 162 | 163 | def _load_rnn(self): 164 | self.rnn = nn.GRU(input_size=620, 165 | hidden_size=2400, 166 | batch_first=True, 167 | dropout=self.dropout) 168 | parameters = self._load_rnn_params() 169 | state_dict = self._make_rnn_state_dict(parameters) 170 | self.rnn.load_state_dict(state_dict) 171 | return self.rnn 172 | 173 | def _make_rnn_state_dict(self, p): 174 | s = OrderedDict() 175 | s['bias_ih_l0'] = torch.zeros(7200) 176 | s['bias_hh_l0'] = torch.zeros(7200) # must stay equal to 0 177 | s['weight_ih_l0'] = torch.zeros(7200, 620) 178 | s['weight_hh_l0'] = torch.zeros(7200, 2400) 179 | s['weight_ih_l0'][:4800] = torch.from_numpy(p['encoder_W']).t() 180 | s['weight_ih_l0'][4800:] = torch.from_numpy(p['encoder_Wx']).t() 181 | s['bias_ih_l0'][:4800] = torch.from_numpy(p['encoder_b']) 182 | s['bias_ih_l0'][4800:] = torch.from_numpy(p['encoder_bx']) 183 | s['weight_hh_l0'][:4800] = torch.from_numpy(p['encoder_U']).t() 184 | s['weight_hh_l0'][4800:] = torch.from_numpy(p['encoder_Ux']).t() 185 | return s 186 | 187 | def forward(self, input, lengths=None): 188 | if lengths is None: 189 | lengths = self._process_lengths(input) 190 | x = self.embedding(input) 191 | x, hn = self.rnn(x) # seq2seq 192 | if lengths: 193 | x = self._select_last(x, lengths) 194 | return x 195 | 196 | 197 | 198 | ############################################################### 199 | # BiSkip 200 | ############################################################### 201 | 202 | class BiSkip(AbstractSkipThoughts): 203 | 204 | def __init__(self, dir_st, vocab, save=True, dropout=0.0, fixed_emb=True, 205 | 
fixed_net=True, emb_fpath=None, device=torch.device('cpu')): 206 | super(BiSkip, self).__init__(dir_st, vocab, save, dropout, fixed_emb, fixed_net, emb_fpath, device) 207 | # Remove bias_ih_l0 (== zero all the time) 208 | # del self.gru._parameters['bias_hh_l0'] 209 | # del self.gru._all_weights[0][3] 210 | 211 | def _get_table_name(self): 212 | return 'btable' 213 | 214 | def _get_skip_name(self): 215 | return 'bi_skip' 216 | 217 | def _load_rnn(self): 218 | self.rnn = nn.GRU(input_size=620, 219 | hidden_size=1200, 220 | batch_first=True, 221 | dropout=self.dropout, 222 | bidirectional=True) 223 | parameters = self._load_rnn_params() 224 | state_dict = self._make_rnn_state_dict(parameters) 225 | self.rnn.load_state_dict(state_dict) 226 | return self.rnn 227 | 228 | def _make_rnn_state_dict(self, p): 229 | s = OrderedDict() 230 | s['bias_ih_l0'] = torch.zeros(3600) 231 | s['bias_hh_l0'] = torch.zeros(3600) # must stay equal to 0 232 | s['weight_ih_l0'] = torch.zeros(3600, 620) 233 | s['weight_hh_l0'] = torch.zeros(3600, 1200) 234 | 235 | s['bias_ih_l0_reverse'] = torch.zeros(3600) 236 | s['bias_hh_l0_reverse'] = torch.zeros(3600) # must stay equal to 0 237 | s['weight_ih_l0_reverse'] = torch.zeros(3600, 620) 238 | s['weight_hh_l0_reverse'] = torch.zeros(3600, 1200) 239 | 240 | s['weight_ih_l0'][:2400] = torch.from_numpy(p['encoder_W']).t() 241 | s['weight_ih_l0'][2400:] = torch.from_numpy(p['encoder_Wx']).t() 242 | s['bias_ih_l0'][:2400] = torch.from_numpy(p['encoder_b']) 243 | s['bias_ih_l0'][2400:] = torch.from_numpy(p['encoder_bx']) 244 | s['weight_hh_l0'][:2400] = torch.from_numpy(p['encoder_U']).t() 245 | s['weight_hh_l0'][2400:] = torch.from_numpy(p['encoder_Ux']).t() 246 | 247 | s['weight_ih_l0_reverse'][:2400] = torch.from_numpy(p['encoder_r_W']).t() 248 | s['weight_ih_l0_reverse'][2400:] = torch.from_numpy(p['encoder_r_Wx']).t() 249 | s['bias_ih_l0_reverse'][:2400] = torch.from_numpy(p['encoder_r_b']) 250 | s['bias_ih_l0_reverse'][2400:] = torch.from_numpy(p['encoder_r_bx']) 251 | s['weight_hh_l0_reverse'][:2400] = torch.from_numpy(p['encoder_r_U']).t() 252 | s['weight_hh_l0_reverse'][2400:] = torch.from_numpy(p['encoder_r_Ux']).t() 253 | return s 254 | 255 | def _argsort(self, seq): 256 | return sorted(range(len(seq)), key=seq.__getitem__) 257 | 258 | def forward(self, input, lengths=None): 259 | batch_size = input.size(0) 260 | if lengths is None: 261 | lengths = self._process_lengths(input) 262 | sorted_lengths = sorted(lengths) 263 | sorted_lengths = sorted_lengths[::-1] 264 | idx = self._argsort(lengths) 265 | idx = idx[::-1] # decreasing order 266 | inverse_idx = self._argsort(idx) 267 | idx = Variable(torch.LongTensor(idx)) 268 | inverse_idx = Variable(torch.LongTensor(inverse_idx)) 269 | if input.data.is_cuda: 270 | idx = idx.to(self.device) 271 | inverse_idx = inverse_idx.to(self.device) 272 | x = torch.index_select(input, 0, idx) 273 | 274 | x = self.embedding(x) 275 | x = nn.utils.rnn.pack_padded_sequence(x, sorted_lengths, batch_first=True) 276 | x, hn = self.rnn(x) # seq2seq 277 | hn = hn.transpose(0, 1) 278 | hn = hn.contiguous() 279 | hn = hn.view(batch_size, 2 * hn.size(2)) 280 | 281 | hn = torch.index_select(hn, 0, inverse_idx) 282 | return hn 283 | 284 | 285 | ############################################################### 286 | # RawBiSkip 287 | ############################################################### 288 | 289 | class CustomizedBiSkip(AbstractSkipThoughts): 290 | 291 | def __init__(self, dir_st, vocab, save=True, dropout=0.0, fixed_emb=True, 292 | 
fixed_net=False, emb_fpath=None, device=torch.device('cpu'), 293 | hidden_size=250): 294 | self.hidden_size = hidden_size 295 | super(CustomizedBiSkip, self).__init__(dir_st, vocab, save, dropout, fixed_emb, fixed_net, emb_fpath, device) 296 | # Remove bias_ih_l0 (== zero all the time) 297 | # del self.gru._parameters['bias_hh_l0'] 298 | # del self.gru._all_weights[0][3] 299 | 300 | def _get_table_name(self): 301 | return 'btable' 302 | 303 | def _get_skip_name(self): 304 | return 'bi_skip' 305 | 306 | def _load_rnn(self): 307 | self.rnn = nn.GRU(input_size=620, 308 | hidden_size=int(self.hidden_size), 309 | batch_first=True, 310 | dropout=self.dropout, 311 | bidirectional=True) 312 | return self.rnn 313 | 314 | def _argsort(self, seq): 315 | return sorted(range(len(seq)), key=seq.__getitem__) 316 | 317 | def forward(self, input, lengths=None): 318 | batch_size = input.size(0) 319 | if lengths is None: 320 | lengths = self._process_lengths(input) 321 | sorted_lengths = sorted(lengths) 322 | sorted_lengths = sorted_lengths[::-1] 323 | idx = self._argsort(lengths) 324 | idx = idx[::-1] # decreasing order 325 | inverse_idx = self._argsort(idx) 326 | idx = Variable(torch.LongTensor(idx)) 327 | inverse_idx = Variable(torch.LongTensor(inverse_idx)) 328 | if input.data.is_cuda: 329 | idx = idx.to(self.device) 330 | inverse_idx = inverse_idx.to(self.device) 331 | x = torch.index_select(input, 0, idx) 332 | 333 | x = self.embedding(x) 334 | x = nn.utils.rnn.pack_padded_sequence(x, sorted_lengths, batch_first=True) 335 | x, hn = self.rnn(x) # seq2seq 336 | hn = hn.transpose(0, 1) 337 | hn = hn.contiguous() 338 | hn = hn.view(batch_size, 2 * hn.size(2)) 339 | 340 | hn = torch.index_select(hn, 0, inverse_idx) 341 | return hn 342 | 343 | 344 | if __name__ == '__main__': 345 | dir_st = 'data/skipthought_models' 346 | vocab = ['robots', 'are', 'very', 'cool', '', 'BiDiBu'] 347 | 348 | us_model = UniSkip(dir_st, vocab) 349 | bs_model = BiSkip(dir_st, vocab) 350 | 351 | # batch_size x seq_len 352 | input = Variable(torch.LongTensor([ 353 | [6,1,2,3,3,4,0], 354 | [6,1,2,3,3,4,5], 355 | [1,2,3,4,0,0,0] 356 | ])) 357 | print(input.size()) 358 | 359 | # for skipping dropout layers 360 | us_model.eval() 361 | bs_model.eval() 362 | 363 | # batch_size x 2400 364 | us_seq2vec = us_model(input) 365 | print(us_seq2vec) 366 | 367 | bs_seq2vec = bs_model(input) 368 | print(bs_seq2vec) 369 | -------------------------------------------------------------------------------- /dnee/models/event_trans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | from .skipthoughts import UniSkip, BiSkip, CustomizedBiSkip 12 | from ..events import indices 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | # These models are portings from OpenKE 17 | # https://github.com/thunlp/OpenKE 18 | # Thanks to their effort 19 | 20 | 21 | class ArgWordEncoder(object): 22 | def __init__(self, config, argw_vocabs=None, use_uniskip=True, use_biskip=True, use_customized_biskip=False, device=torch.device('cpu')): 23 | self.device = device 24 | self.argw_vocabs = self.load_argw_vocabs(config['argw_indices']) if argw_vocabs is None else argw_vocabs 25 | self.output_dim = 0 26 | if use_uniskip: 27 | # fixed, no parameters 28 | self.uniskip = UniSkip(config['skipthought_dir'], self.argw_vocabs, 29 | 
save=True, dropout=0, fixed_emb=True, 30 | fixed_net=True, device=self.device).to(self.device) 31 | self.output_dim += 2400 32 | else: 33 | self.uniskip = None 34 | 35 | if use_biskip: 36 | # fixed, no parameters 37 | self.biskip = BiSkip(config['skipthought_dir'], self.argw_vocabs, 38 | save=True, dropout=0, fixed_emb=True, 39 | fixed_net=True, device=self.device).to(self.device) 40 | self.output_dim += 2400 41 | else: 42 | self.biskip = None 43 | 44 | # network not fixed 45 | if use_customized_biskip: 46 | hidden_size = config['pred_dim'] / 2 47 | logger.info("cskip runs on {}, hidden size = {}".format(self.device, hidden_size)) 48 | self.c_biskip = CustomizedBiSkip(config['skipthought_dir'], self.argw_vocabs, 49 | save=True, dropout=0.0, fixed_emb=True, 50 | fixed_net=False, device=self.device, 51 | hidden_size=hidden_size).to(self.device) 52 | self.output_dim += config['pred_dim'] 53 | else: 54 | self.c_biskip = None 55 | logger.info("arg_dim: {}".format(self.output_dim)) 56 | 57 | def encode(self, x): 58 | uvec = self.uniskip(x) if self.uniskip else None 59 | bvec = self.biskip(x) if self.biskip else None 60 | cvec = self.c_biskip(x) if self.c_biskip else None 61 | comb_vec = tuple(v for v in (uvec, bvec, cvec) if v is not None) 62 | comb_vec = torch.cat(comb_vec, 1) 63 | return comb_vec 64 | 65 | def save(self, fpath): 66 | if self.c_biskip: 67 | torch.save(self.c_biskip.state_dict(), fpath) 68 | # not saving skipthough, since won't update it 69 | 70 | def load(self, fpath): 71 | if self.c_biskip: 72 | self.c_biskip.load_state_dict(torch.load(fpath, map_location=lambda storage, loc: storage)) 73 | 74 | @staticmethod 75 | def load_argw_vocabs(fpath): 76 | key2idx, _, _ = indices.load_argw(fpath) 77 | return list(key2idx.keys()) 78 | 79 | 80 | class AbstractEventTrans(nn.Module): 81 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 82 | super(AbstractEventTrans, self).__init__() 83 | self.norm = config['norm'] 84 | self.device = device 85 | self.config = config 86 | self.n_preds = self.get_n_preds(config['predicate_indices']) if n_preds is None else n_preds 87 | self.argw_encoder = argw_encoder 88 | logging.info('pred_dim: {}'.format(config['pred_dim'])) 89 | self.rel_zero_shot = rel_zero_shot 90 | if rel_zero_shot: 91 | logging.info('n_old_rel_types: {}'.format(config['n_old_rel_types'])) 92 | self.rel_embeddings = nn.Embedding(config['n_old_rel_types'], config['rel_dim']).to(self.device) 93 | else: 94 | logging.info('n_rel_types: {}'.format(config['n_rel_types'])) 95 | self.rel_embeddings = nn.Embedding(config['n_rel_types'], config['rel_dim']).to(self.device) 96 | 97 | logging.info('rel_dim: {}'.format(config['rel_dim'])) 98 | self.pred_embeddings = nn.Embedding(self.n_preds, config['pred_dim']).to(self.device) 99 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 100 | nn.init.xavier_uniform_(self.pred_embeddings.weight.data) 101 | 102 | def transfer_rel(self, sense_map, new_rel2idx, old_rel2idx): 103 | assert self.rel_zero_shot 104 | new_rel_embeddings = nn.Embedding(self.config['n_rel_types'], self.config['rel_dim']).to(self.device) 105 | nn.init.xavier_uniform_(new_rel_embeddings.weight.data) 106 | for rel, new_idx in new_rel2idx.iteritems(): 107 | if rel in sense_map: 108 | old_idx = old_rel2idx[sense_map[rel]] 109 | new_rel_embeddings.weight.data[new_idx] = self.rel_embeddings.weight.data[old_idx] 110 | self.rel_embeddings = new_rel_embeddings 111 | 112 | def loss_func(self, p_score, n_score): 113 | criterion = 
nn.MarginRankingLoss(self.config['margin'], False).to(self.device) 114 | y = Variable(torch.Tensor([-1]).to(self.device), requires_grad=False) 115 | loss = criterion(p_score, n_score, y) 116 | return loss 117 | 118 | @staticmethod 119 | def get_n_preds(fpath): 120 | pred2idx, idx2pred, _ = indices.load_predicates(fpath) 121 | return len(pred2idx) 122 | 123 | def forward(self, x): 124 | raise NotImplementedError 125 | 126 | def _get_event_raw_features(self, p, arg0, arg1): 127 | p_emb = self.pred_embeddings(p) 128 | arg0_emb = self.argw_encoder.encode(arg0) 129 | arg1_emb = self.argw_encoder.encode(arg1) 130 | ev_raw = torch.cat((p_emb, arg0_emb, arg1_emb), 1) 131 | return ev_raw 132 | 133 | def _event_indices(self, x): 134 | assert x.shape[1] == 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 135 | p = x[:, 0] 136 | arg0 = x[:, 1:1+self.config['arg0_max_len']] 137 | arg1 = x[:, 1+self.config['arg0_max_len']: 1+self.config['arg0_max_len']+self.config['arg1_max_len']] 138 | return p, arg0, arg1 139 | 140 | def _calc(self, h, t, r): 141 | return torch.abs(h + r - t) if self.norm == 1 else torch.pow(torch.abs(h + r - t), self.norm) 142 | 143 | 144 | class AbstractCompEventTrans(AbstractEventTrans): 145 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 146 | super(AbstractCompEventTrans, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 147 | event_raw_dim = config['pred_dim'] + self.argw_encoder.output_dim * 2 148 | logging.info('event_raw_dim: {}'.format(event_raw_dim)) 149 | logging.info('event_hidden_dim: {}'.format(config['event_hidden_dim'])) 150 | self.event_l1 = nn.Linear(event_raw_dim, config['event_hidden_dim']) 151 | self.relu = nn.ReLU() 152 | self.event_l2 = nn.Linear(config['event_hidden_dim'], config['event_dim']) 153 | logging.info('event_dim: {}'.format(config['event_dim'])) 154 | 155 | def embed_event(self, x): 156 | ev_len = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 157 | assert ev_len == x.shape[1] 158 | ev_p, ev_arg0, ev_arg1 = self._event_indices(x) 159 | ev_raw = self._get_event_raw_features(ev_p, ev_arg0, ev_arg1) 160 | ev_h1 = self.relu(self.event_l1(ev_raw)) 161 | ev_emb = self.event_l2(ev_h1) 162 | return ev_emb 163 | 164 | def forward(self, x): 165 | raise NotImplementedError 166 | 167 | 168 | class AbstractAttnEventTrans(AbstractEventTrans): 169 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 170 | super(AbstractAttnEventTrans, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 171 | assert config['pred_dim'] == self.argw_encoder.output_dim 172 | assert config['pred_dim'] == config['event_dim'] 173 | self.w_attn = nn.Parameter(torch.FloatTensor(config['pred_dim'], 1)) 174 | nn.init.xavier_normal_(self.w_attn) 175 | 176 | def embed_event(self, x): 177 | ev_len = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 178 | assert ev_len == x.shape[1] 179 | ev_p, ev_arg0, ev_arg1 = self._event_indices(x) 180 | ev_raw = self._get_event_raw_features(ev_p, ev_arg0, ev_arg1) 181 | batch_size = x.shape[0] 182 | d = self.config['pred_dim'] 183 | # attention: a soft combination of v_pred, v_a0, v_a1 184 | v_pred, v_arg0, v_arg1 = ev_raw[:, :d], \ 185 | ev_raw[:, d:2*d], \ 186 | ev_raw[:, 2*d:] 187 | vs = torch.stack((v_pred, v_arg0, v_arg1), dim=1) 188 | attn_scores = torch.matmul(vs, self.w_attn).squeeze() 189 | attn_scores = F.softmax(attn_scores).view(batch_size, 3, 1) 190 | 
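# weight the predicate/arg0/arg1 vectors by their softmax attention scores
# and sum them into a single condensed event embedding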
scored_vs = vs * attn_scores 191 | condensed_vs = torch.sum(scored_vs, dim=1) 192 | return condensed_vs 193 | 194 | def forward(self, x): 195 | raise NotImplementedError 196 | 197 | 198 | class EventTransE(AbstractCompEventTrans): 199 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 200 | super(EventTransE, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 201 | 202 | def forward(self, x): 203 | rtype = x[:, 0] 204 | rel_emb = self.rel_embeddings(rtype) 205 | 206 | elen = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 207 | e1_emb = self.embed_event(x[:, 1: 1+elen]) 208 | e2_emb = self.embed_event(x[:, 1+elen: 1+elen+elen]) 209 | 210 | _score = self._calc(e1_emb, e2_emb, rel_emb) 211 | score = torch.sum(_score, 1) 212 | if self.norm > 1: 213 | score = torch.pow(score, 1.0 / self.norm) 214 | return score 215 | 216 | def _transfer(self, ev_emb, rtype): 217 | return ev_emb 218 | 219 | 220 | class EventTransR(AbstractCompEventTrans): 221 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 222 | super(EventTransR, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 223 | self.transfer_matrix = nn.Embedding(config['n_rel_types'], config['rel_dim']*config['event_dim']).to(self.device) 224 | nn.init.xavier_uniform_(self.transfer_matrix.weight.data) 225 | 226 | def forward(self, x): 227 | rtype = x[:, 0] 228 | rel_emb = self.rel_embeddings(rtype).view(-1, self.config['rel_dim']) 229 | 230 | rel_m = self.transfer_matrix(rtype).view(-1, self.config['rel_dim'], self.config['event_dim']) 231 | 232 | elen = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 233 | e1_emb = self.embed_event(x[:, 1: 1+elen]).view(-1, self.config['event_dim'], 1) 234 | e2_emb = self.embed_event(x[:, 1+elen: 1+elen+elen]).view(-1, self.config['event_dim'], 1) 235 | 236 | e1_emb_t = torch.matmul(rel_m, e1_emb).view(-1, self.config['rel_dim']) 237 | e2_emb_t = torch.matmul(rel_m, e2_emb).view(-1, self.config['rel_dim']) 238 | 239 | _score = self._calc(e1_emb_t, e2_emb_t, rel_emb) 240 | score = torch.sum(_score, 1) 241 | if self.norm > 1: 242 | score = torch.pow(score, 1.0 / self.norm) 243 | return score 244 | 245 | def _transfer(self, ev_emb, rtype): 246 | if rtype.shape == torch.Size([1]): 247 | rtype = rtype.expand(ev_emb.shape[0], 1) 248 | m = self.transfer_matrix(rtype).view(-1, self.config['rel_dim'], self.config['event_dim']) 249 | ev_emb = ev_emb.view(-1, self.config['event_dim'], 1) 250 | return torch.matmul(m, ev_emb).view(-1, self.config['rel_dim']) 251 | 252 | 253 | class AttnEventTransE(AbstractAttnEventTrans): 254 | def __init__(self, config, argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 255 | super(AttnEventTransE, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 256 | 257 | def forward(self, x): 258 | rtype = x[:, 0] 259 | rel_emb = self.rel_embeddings(rtype) 260 | 261 | elen = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 262 | e1_emb = self.embed_event(x[:, 1: 1+elen]) 263 | e2_emb = self.embed_event(x[:, 1+elen: 1+elen+elen]) 264 | 265 | _score = self._calc(e1_emb, e2_emb, rel_emb) 266 | score = torch.sum(_score, 1) 267 | if self.norm > 1: 268 | score = torch.pow(score, 1.0 / self.norm) 269 | return score 270 | 271 | def _transfer(self, ev_emb, rtype): 272 | return ev_emb 273 | 274 | 275 | class AttnEventTransR(AbstractAttnEventTrans): 276 | def __init__(self, config, 
argw_encoder, n_preds=None, device=torch.device('cpu'), rel_zero_shot=False): 277 | super(AttnEventTransR, self).__init__(config, argw_encoder, n_preds, device, rel_zero_shot) 278 | self.transfer_matrix = nn.Embedding(config['n_rel_types'], config['rel_dim']*config['event_dim']).to(self.device) 279 | nn.init.xavier_uniform_(self.transfer_matrix.weight.data) 280 | 281 | def forward(self, x): 282 | rtype = x[:, 0] 283 | rel_emb = self.rel_embeddings(rtype).view(-1, self.config['rel_dim']) 284 | 285 | rel_m = self.transfer_matrix(rtype).view(-1, self.config['rel_dim'], self.config['event_dim']) 286 | 287 | elen = 1 + self.config['arg0_max_len'] + self.config['arg1_max_len'] 288 | e1_emb = self.embed_event(x[:, 1: 1+elen]).view(-1, self.config['event_dim'], 1) 289 | e2_emb = self.embed_event(x[:, 1+elen: 1+elen+elen]).view(-1, self.config['event_dim'], 1) 290 | 291 | e1_emb_t = torch.matmul(rel_m, e1_emb).view(-1, self.config['rel_dim']) 292 | e2_emb_t = torch.matmul(rel_m, e2_emb).view(-1, self.config['rel_dim']) 293 | 294 | _score = self._calc(e1_emb_t, e2_emb_t, rel_emb) 295 | score = torch.sum(_score, 1) 296 | if self.norm > 1: 297 | score = torch.pow(score, 1.0 / self.norm) 298 | return score 299 | 300 | def _transfer(self, ev_emb, rtype): 301 | if rtype.shape == torch.Size([1]): 302 | rtype = rtype.expand(ev_emb.shape[0], 1) 303 | m = self.transfer_matrix(rtype).view(-1, self.config['rel_dim'], self.config['event_dim']) 304 | ev_emb = ev_emb.view(-1, self.config['event_dim'], 1) 305 | return torch.matmul(m, ev_emb).view(-1, self.config['rel_dim']) 306 | 307 | 308 | def create_argw_encoder(config, device): 309 | logging.info('argw_encoder: {}'.format(config['argw_encoder_opt'])) 310 | argw_vocabs = ArgWordEncoder.load_argw_vocabs(config['argw_indices']) 311 | if config['argw_encoder_opt'] == 'customized_biskip': 312 | argw_encoder = ArgWordEncoder(config, argw_vocabs, 313 | use_uniskip=False, use_biskip=False, use_customized_biskip=True, device=device) 314 | elif config['argw_encoder_opt'] == 'skipthoughts': 315 | argw_encoder = ArgWordEncoder(config, argw_vocabs, 316 | use_uniskip=True, use_biskip=True, use_customized_biskip=False, device=device) 317 | elif config['argw_encoder_opt'] == 'uniskip': 318 | argw_encoder = ArgWordEncoder(config, argw_vocabs, 319 | use_uniskip=True, use_biskip=False, use_customized_biskip=False, device=device) 320 | elif config['argw_encoder_opt'] == 'biskip': 321 | argw_encoder = ArgWordEncoder(config, argw_vocabs, 322 | use_uniskip=False, use_biskip=True, use_customized_biskip=False, device=device) 323 | else: 324 | raise ValueError('unsupported encoder type.') 325 | return argw_encoder 326 | -------------------------------------------------------------------------------- /dnee/evals/intrinsic.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | import logging 5 | from copy import deepcopy 6 | 7 | import six 8 | import numpy as np 9 | import torch 10 | from sklearn.metrics import pairwise 11 | from nltk import word_tokenize 12 | 13 | from ..events import Event, indices 14 | 15 | 16 | def embed_event_word_embeddings(e, we): 17 | # words = [w.lower() for w in word_tokenize(e.sentence)] 18 | tokens = [e.pred] 19 | arg0s = [w.lower() for w in word_tokenize(e.arg0)] 20 | arg1s = [w.lower() for w in word_tokenize(e.arg1)] 21 | arg2s = [w.lower() for w in word_tokenize(e.arg2)] 22 | tokens += arg0s + arg1s + arg2s 23 | dim = next(six.itervalues(we)).shape[0] 24 | avg = 
torch.zeros(dim, dtype=torch.float32)
25 |     cnt = 0
26 |     for t in tokens:
27 |         if t not in we:
28 |             continue
29 |         avg += we[t]
30 |         cnt += 1
31 |     if cnt > 0:
32 |         avg /= cnt
33 |     return avg
34 |
35 |
36 | class CorefQuestion:
37 |     def __init__(self, query_event, ans_idx, choices):
38 |         self.query_event = query_event
39 |         self.choices = choices
40 |         self.ans_idx = ans_idx
41 |
42 |     @classmethod
43 |     def from_doc(cls, echain, event_list, n_choices=5):
44 |         excludes = {}
45 |         for e in echain:
46 |             excludes[e.__repr__()] = e
47 |
48 |         # pick the first event and one coreferenced event
49 |         ridx = random.randint(1, len(echain)-1)
50 |         query_event = echain[0]
51 |         ans_event = echain[ridx]
52 |
53 |         # pick n_choices un-coreferenced events
54 |         choices = []
55 |         while len(choices) < n_choices:
56 |             c = event_list[random.randint(0, len(event_list)-1)]
57 |             if c.__repr__() not in excludes:
58 |                 choices.append(c)
59 |                 excludes[c.__repr__()] = c
60 |         ans_idx = random.randint(0, n_choices-1)
61 |         choices[ans_idx] = ans_event
62 |         return cls(query_event, ans_idx, choices)
63 |
64 |
65 | class CorefBinaryQuestion:
66 |     def __init__(self, e1, e2, label):
67 |         self.e1 = e1
68 |         self.e2 = e2
69 |         self.label = label
70 |
71 |     def __repr__(self):
72 |         return "{}, {}, {}".format(self.e1, self.e2, self.label)
73 |
74 |
75 | class DiscourseQuestion:
76 |     def __init__(self, rel, ans_idx, choices):
77 |         self.rel = rel
78 |         self.choices = choices
79 |         self.ans_idx = ans_idx
80 |
81 |     @classmethod
82 |     def from_relation(cls, drel, epool, n_choices=5):
83 |         ans_idx = random.randint(0, n_choices-1)
84 |         choices = []
85 |         while len(choices) < n_choices:
86 |             c = epool[random.randint(0, len(epool)-1)]
87 |             if c not in choices and c != drel.e1 and c != drel.e2:
88 |                 choices.append(c)
89 |         choices[ans_idx] = drel.e2
90 |         return cls(drel, ans_idx, choices)
91 |
92 |
93 | class MCNCEvent(Event):
94 |     def __init__(self, pred, dep, arg0, arg0_head, arg1, arg1_head, arg2, arg2_head,
95 |                  sentiment, ani0, ani1, ani2, sent):
96 |         super(MCNCEvent, self).__init__(pred, arg0, arg0_head, arg1, arg1_head,
97 |                                         arg2, arg2_head, sentiment, ani0, ani1, ani2)
98 |         self.dep = dep
99 |         rep = {'\n': ' ', '::': ' '}
100 |         rep = dict((re.escape(k), v) for k, v in six.iteritems(rep))
101 |         pat = re.compile("|".join(rep.keys()))
102 |         self.sentence = pat.sub(lambda m: rep[re.escape(m.group(0))], sent)
103 |         if six.PY2:  # Python 2 str needs explicit ASCII encoding
104 |             self.sentence = self.sentence.encode('ascii', 'ignore')
105 |
106 |     def __repr__(self):
107 |         return "({}::{}::{}::{}::{}::{}::{}::{}::{}::{}::{}::{}::{})".format(
108 |             self.pred, self.dep,
109 |             self.arg0, self.arg0_head,
110 |             self.arg1, self.arg1_head,
111 |             self.arg2, self.arg2_head,
112 |             self.sentiment, self.ani0,
113 |             self.ani1, self.ani2, self.sentence)
114 |
115 |     @classmethod
116 |     def from_string(cls, line):
117 |         raise NotImplementedError
118 |         # line = line.rstrip("\n")[1:-1]
119 |         # sp = line.split('::')
120 |         # obj = cls(sp[0], sp[1], sp[2], sp[3], sp[4], sp[5], sp[6],
121 |         #           sp[7], sp[8], sp[9], sp[10], sp[11], sp[12])
122 |         # return obj
123 |
124 |     @classmethod
125 |     def from_json(cls, e):
126 |         pred = e['predicate']
127 |         # only use the first sub-argument for now
128 |         arg0_head = e['arg0'][0] if 'arg0' in e else indices.NO_ARG
129 |         arg0 = e['arg0_text'][0] if 'arg0_text' in e else indices.NO_ARG
130 |
131 |         arg1_head = e['arg1'][0] if 'arg1' in e else indices.NO_ARG
132 |         arg1 = e['arg1_text'][0] if 'arg1_text' in e else indices.NO_ARG
133 |
134 |         arg2_head = e['arg2'][0] if 'arg2' in e else indices.NO_ARG
135 |         arg2 = e['arg2_text'][0] if 'arg2_text' in e else indices.NO_ARG
136 |         sentiment = e['sentiment'] if 'sentiment' in e else None
137 |         ani0 = e['ani0'][0] if 'ani0' in e else indices.UNKNOWN_ANIMACY
138 |         ani1 = e['ani1'][0] if 'ani1' in e else indices.UNKNOWN_ANIMACY
139 |         ani2 = e['ani2'][0] if 'ani2' in e else indices.UNKNOWN_ANIMACY
140 |         dep = e['dep']
141 |         sent = e['sentence']
142 |         obj = cls(pred, dep, arg0, arg0_head, arg1, arg1_head, arg2, arg2_head,
143 |                   sentiment, ani0, ani1, ani2, sent)
144 |         return obj
145 |
146 |
147 | # MCNC Question
148 | class Question:
149 |     def __init__(self, q_idx, ans_idx, choices, echain):
150 |         self.echain = echain
151 |         self.q_idx = q_idx
152 |         self.ans_idx = ans_idx
153 |         self.choices = choices
154 |
155 |     def get_contexts(self):
156 |         contexts = self.echain[:self.q_idx]
157 |         if len(self.echain) > self.q_idx+1:
158 |             contexts += self.echain[self.q_idx+1:]
159 |         return contexts
160 |
161 |     def __repr__(self):
162 |         return '{}, {}, {}, {}'.format(
163 |             self.q_idx,
164 |             self.ans_idx,
165 |             self.choices,
166 |             self.echain)
167 |
168 |     @classmethod
169 |     def from_event_chain(cls, echain, epool, n_choices=5, fixed_len=9):
170 |         echain = echain[:fixed_len]
171 |         ehash = {}
172 |         for e in echain:
173 |             ehash[e.__repr__()] = 1
174 |
175 |         all_args = []
176 |         # randomly pick shared arguments from echain
177 |         for e in echain:
178 |             if e.arg0_head != indices.NO_ARG:
179 |                 all_args.append((e.arg0_head, e.arg0))
180 |             if e.arg1_head != indices.NO_ARG:
181 |                 all_args.append((e.arg1_head, e.arg1))
182 |             if e.arg2_head != indices.NO_ARG:
183 |                 all_args.append((e.arg2_head, e.arg2))
184 |
185 |         q_idx = fixed_len-1
186 |         ans_idx = random.randint(0, n_choices-1)
187 |         choices = []
188 |         while len(choices) < n_choices:
189 |             c = epool[random.randint(0, len(epool)-1)]
190 |             new_c = deepcopy(c)
191 |
192 |             # replace with protagonist
193 |             rpos = random.randint(0, 2)
194 |             rarg = all_args[random.randint(0, len(all_args)-1)]
195 |             if rpos == 0:
196 |                 new_c.arg0_head, new_c.arg0 = rarg
197 |             elif rpos == 1:
198 |                 new_c.arg1_head, new_c.arg1 = rarg
199 |             else:
200 |                 new_c.arg2_head, new_c.arg2 = rarg
201 |
202 |             if new_c.__repr__() not in ehash and all(new_c.__repr__() != c2.__repr__() for c2 in choices):  # compare by repr; Event defines no __eq__
203 |                 choices.append(new_c)
204 |         choices[ans_idx] = echain[q_idx]
205 |         return cls(q_idx, ans_idx, choices, echain)
206 |
207 |
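# Usage sketch (illustrative; the function name is hypothetical): assuming
# `echain` and `epool` are lists of parsed MCNCEvent objects supplied by the
# caller, this is how an MCNC question is built and consumed.
def _mcnc_question_sketch(echain, epool):
    q = Question.from_event_chain(echain, epool, n_choices=5, fixed_len=9)
    contexts = q.get_contexts()  # the 8 context events (q_idx is excluded)
    gold = q.choices[q.ans_idx]  # the held-out final event of the chain
    assert gold.__repr__() == q.echain[q.q_idx].__repr__()
    return contexts, q.choices, q.ans_idx
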
208 | class MCNSQuestion:
209 |     def __init__(self, echain, choice_lists, ans_idxs):
210 |         self.echain = echain
211 |         self.choice_lists = choice_lists
212 |         self.ans_idxs = ans_idxs
213 |         assert len(ans_idxs) == len(echain)-1
214 |         assert len(ans_idxs) == len(choice_lists)
215 |
216 |     def __repr__(self):
217 |         return '{}, {}, {}'.format(
218 |             self.ans_idxs,
219 |             self.choice_lists,
220 |             self.echain)
221 |
222 |     @classmethod
223 |     def from_event_chain(cls, echain, epool, n_choices=5, chain_len=5):
224 |         ehash = {}
225 |         for e in echain:
226 |             ehash[e.__repr__()] = 1
227 |
228 |         assert len(echain) >= chain_len
229 |         subechain = echain[:chain_len]
230 |
231 |         all_args = []
232 |         # randomly pick shared arguments from echain
233 |         for e in subechain:
234 |             if e.arg0_head != indices.NO_ARG:
235 |                 all_args.append((e.arg0_head, e.arg0))
236 |             if e.arg1_head != indices.NO_ARG:
237 |                 all_args.append((e.arg1_head, e.arg1))
238 |             if e.arg2_head != indices.NO_ARG:
239 |                 all_args.append((e.arg2_head, e.arg2))
240 |
241 |         choice_lists = []
242 |         ans_idxs = []
243 |         for i in range(1, chain_len):
244 |             cs = []
245 |             choice_hash = {}
246 |             while len(cs) < n_choices:
247 |                 c = epool[random.randint(0, len(epool)-1)]
248 |                 new_c = deepcopy(c)
249 |
250 |                 # replace with protagonist
251 |                 rpos = random.randint(0, 2)
252 |                 rarg = all_args[random.randint(0, len(all_args)-1)]
253 |                 if rpos == 0:
254 |                     new_c.arg0_head, new_c.arg0 = rarg
255 |                 elif rpos == 1:
256 |                     new_c.arg1_head, new_c.arg1 = rarg
257 |                 else:
258 |                     new_c.arg2_head, new_c.arg2 = rarg
259 |                 if new_c.__repr__() not in ehash and new_c.__repr__() not in choice_hash:
260 |                     cs.append(new_c)
261 |                     choice_hash[new_c.__repr__()] = 1
262 |             ans_idx = random.randint(0, n_choices-1)
263 |             cs[ans_idx] = subechain[i]
264 |             choice_lists.append(cs)
265 |             ans_idxs.append(ans_idx)
266 |         return cls(subechain, choice_lists, ans_idxs)
267 |
268 |
269 | def get_event_emb(model, e, e2idx, embeddings, device):
270 |     idx = e2idx[e.__repr__()]
271 |     emb = embeddings[idx]
272 |     return emb
273 |
274 |
275 | def predict_mcnc(model, q, e2idx, embeddings, rtype, device):
276 |     contexts = q.get_contexts()
277 |     ctx_embs = [get_event_emb(model, e, e2idx, embeddings, device)
278 |                 for e in contexts]
279 |     choice_embs = [get_event_emb(model, e, e2idx, embeddings, device)
280 |                    for e in q.choices]
281 |     rel_emb = model.rel_embeddings(rtype)
282 |     # take avg of scores between event pairs
283 |     energies = []
284 |     for ch in choice_embs:
285 |         sub_es = []
286 |         for ctx in ctx_embs:
287 |             e = model._calc(ch, ctx, rel_emb)
288 |             e = torch.sum(e, 1)  # L1-norm
289 |             if model.norm > 1:
290 |                 e = torch.pow(e, 1.0 / model.norm)
291 |             sub_es.append(e[0])
292 |         avg = sum(sub_es) / len(sub_es)
293 |         energies.append(avg)
294 |     energies = torch.Tensor(energies).to(device)
295 |     v, idx = energies.min(0)
296 |     return idx
297 |
298 |
299 | def scoring_cosine_similarity(emb1s, emb2s):
300 |     cs = pairwise.cosine_similarity(emb1s, emb2s)
301 |     # shift from [-1, 1] to [0, 1]
302 |     ret = (cs + 1) / 2.0
303 |     return ret
304 |
305 |
306 | def _we_transition_prob(ts1, ts2):
307 |     ts1_embs = torch.stack(ts1, dim=0)
308 |     ts2_embs = torch.stack(ts2, dim=0)
309 |
310 |     scores = pairwise.cosine_similarity(ts1_embs.cpu(), ts2_embs.cpu())
311 |     # shift from [-1, 1] to [0, 1] to avoid negative values
312 |     scores = (scores + 1) / 2.0
313 |
314 |     probs = scores / scores.sum(axis=1, keepdims=True)
315 |     assert (probs < 0).sum() == 0
316 |     return torch.from_numpy(probs)
317 |
318 |
319 | def _ev_transition_prob(ts1, ts2, rel_emb, model):
320 |     tm = torch.zeros((len(ts1), len(ts2)), dtype=torch.float32)
321 |     for i in range(len(ts1)):
322 |         for j in range(len(ts2)):
323 |             e = model._calc(ts1[i], ts2[j], rel_emb)
324 |             e = torch.sum(e, 1)  # L1-norm
325 |             tm[i][j] = torch.pow(e, 1.0 / model.norm) if model.norm > 1 else e
326 |     tm = 1.0 / tm
327 |     probs = tm / tm.sum(1, keepdim=True)  # normalize each row, not each column
328 |     assert (probs < 0).sum().tolist() == 0
329 |     return probs
330 |
331 |
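# Illustrative sketch of the helpers above (not part of the original
# evaluation path): with hypothetical (dim,) embedding vectors,
# _we_transition_prob returns a row-stochastic matrix, i.e. each row is a
# probability distribution over the next timestamp's candidates.
def _transition_prob_sketch():
    dim = 8
    src = [torch.randn(dim)]                      # one source event
    cands = [torch.randn(dim) for _ in range(3)]  # three candidate events
    probs = _we_transition_prob(src, cands)       # shape (1, 3)
    assert torch.allclose(probs.sum(1), torch.ones(1, dtype=probs.dtype))
    return probs
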
332 | def transition_prob_matrix(model, q, e2idx, ev_embeddings, w_embeddings, rtype):
333 |     timestamp_we_embs = [[w_embeddings[e2idx[q.echain[0].__repr__()]]]]
334 |     timestamp_ev_embs = [[ev_embeddings[e2idx[q.echain[0].__repr__()]]]]
335 |     for clist in q.choice_lists:
336 |         choice_we_embs = [w_embeddings[e2idx[c.__repr__()]] for c in clist]
337 |         choice_ev_embs = [ev_embeddings[e2idx[c.__repr__()]] for c in clist]
338 |         timestamp_we_embs.append(choice_we_embs)
339 |         timestamp_ev_embs.append(choice_ev_embs)
340 |
341 |     rel_emb = model.rel_embeddings(rtype)
342 |     # first transition
343 |     all_we_probs, all_ev_probs = [], []
344 |     we_probs = _we_transition_prob(timestamp_we_embs[0], timestamp_we_embs[1])
345 |     ev_probs = _ev_transition_prob(timestamp_ev_embs[0], timestamp_ev_embs[1], rel_emb, model)
346 |     all_we_probs.append(we_probs)
347 |     all_ev_probs.append(ev_probs)
348 |
349 |     # the rest
350 |     for i in range(1, len(timestamp_ev_embs)-1):
351 |         we_probs = _we_transition_prob(timestamp_we_embs[i], timestamp_we_embs[i+1])
352 |         ev_probs = _ev_transition_prob(timestamp_ev_embs[i], timestamp_ev_embs[i+1], rel_emb, model)
353 |         all_we_probs.append(we_probs)
354 |         all_ev_probs.append(ev_probs)
355 |     return all_we_probs, all_ev_probs
356 |
357 |
358 | def transition_prob_matrix_all_orders(model, q, e2idx, ev_embeddings, w_embeddings, rtype):
359 |     timestamp_we_embs = [[w_embeddings[e2idx[q.echain[0].__repr__()]]]]
360 |     timestamp_ev_embs = [[ev_embeddings[e2idx[q.echain[0].__repr__()]]]]
361 |     for clist in q.choice_lists:
362 |         choice_we_embs = [w_embeddings[e2idx[c.__repr__()]] for c in clist]
363 |         choice_ev_embs = [ev_embeddings[e2idx[c.__repr__()]] for c in clist]
364 |         timestamp_we_embs.append(choice_we_embs)
365 |         timestamp_ev_embs.append(choice_ev_embs)
366 |
367 |     rel_emb = model.rel_embeddings(rtype)
368 |     # transition matrices between every ordered pair of timestamps (i < j)
369 |     all_we_probs, all_ev_probs = [], []
370 |     for i in range(len(timestamp_we_embs)):
371 |         tmp_we, tmp_ev = [], []
372 |         for j in range(len(timestamp_we_embs)):
373 |             if i >= j:
374 |                 tmp_we.append(None)
375 |                 tmp_ev.append(None)
376 |             else:
377 |                 _we = _we_transition_prob(timestamp_we_embs[i], timestamp_we_embs[j])
378 |                 _ev = _ev_transition_prob(timestamp_ev_embs[i], timestamp_ev_embs[j], rel_emb, model)
379 |                 tmp_we.append(_we)
380 |                 tmp_ev.append(_ev)
381 |         all_we_probs.append(tmp_we)
382 |         all_ev_probs.append(tmp_ev)
383 |     return all_we_probs, all_ev_probs
384 |
385 |
386 | def viterbi(tpms):
387 |     n_choices = tpms[1].shape[0]
388 |     trellis = torch.zeros((n_choices, len(tpms)), dtype=torch.float32)
389 |     backtrace = torch.ones((n_choices, len(tpms)), dtype=torch.int64) * -1
390 |
391 |     t1 = time.time()
392 |     trellis[:, 0] = tpms[0].view(-1)  # the first transition matrix is (1, n_choices)
393 |
394 |     for t in range(1, len(tpms)):
395 |         for j in range(n_choices):
396 |             tmp_probs = torch.zeros(n_choices, dtype=torch.float32)
397 |             for k in range(n_choices):
398 |                 p = trellis[k, t-1] * tpms[t][k][j]
399 |                 tmp_probs[k] = p
400 |             trellis[j, t] = torch.max(tmp_probs)
401 |             backtrace[j, t] = torch.argmax(tmp_probs)
402 |
403 |     # backtrace
404 |     tokens = [trellis[:, -1].argmax()]
405 |     for i in range(len(tpms)-1, 0, -1):
406 |         tokens.append(backtrace[tokens[-1], i])
407 |     preds = tokens[::-1]
408 |     logging.debug('viterbi: {} s'.format(time.time()-t1))
409 |     return preds
410 |
411 |
412 | def markov_baseline(tpms):
413 |     # base the prediction on the previous prediction
414 |     preds = []
415 |     for t in range(len(tpms)):
416 |         if t == 0:
417 |             pred = torch.argmax(tpms[t])
418 |         else:
419 |             previous_state = preds[t-1]
420 |             pred = torch.argmax(tpms[t][previous_state])
421 |         preds.append(pred)
422 |     return preds
423 |
424 |
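# Illustrative sketch of the two decoders above on hand-built matrices, in the
# shapes transition_prob_matrix produces: the first matrix is (1, n_choices),
# later ones (n_choices, n_choices). With these numbers both decoders pick
# state 0 at every step.
def _decoding_sketch():
    tpms = [torch.tensor([[0.9, 0.1]]),
            torch.tensor([[0.6, 0.4],
                          [0.1, 0.9]])]
    vit = viterbi(tpms)             # [0, 0]
    greedy = markov_baseline(tpms)  # [0, 0]
    return vit, greedy
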
425 | def markov_skyline(tpms, ans_idxs, fix_end, n_choices=5):
426 |     # base on the correct previous state
427 |     # preds = []
428 |     # for t in range(len(tpms)):
429 |     #     if t == 0:
430 |     #         pred = torch.argmax(tpms[t])
431 |     #     else:
432 |     #         previous_state = ans_idxs[t-1]
433 |     #         pred = torch.argmax(tpms[t][previous_state])
434 |     #     preds.append(pred)
435 |
436 |     # consider all previous and future gold states
437 |     # when predicting each step
438 |     preds = []
439 |     for target in range(1, len(tpms)):
440 |         scores = torch.zeros(n_choices, dtype=torch.float32)
441 |         for src in range(len(tpms)):
442 |             if src == target:
443 |                 continue
444 |             elif src < target:
445 |                 m = tpms[src][target]
446 |                 if src == 0:
447 |                     scores += m.squeeze()
448 |                 else:
449 |                     ans_idx = ans_idxs[src-1]
450 |                     scores += m[ans_idx, :]
451 |             else:
452 |                 m = tpms[target][src]
453 |                 ans_idx = ans_idxs[src-1]
454 |                 scores += m[:, ans_idx]
455 |         _, pred = torch.max(scores, 0)
456 |         preds.append(pred)
457 |     return preds
458 |
459 |
460 | def _predict_sequence(model, q, e2idx, ev_embeddings, w_embeddings, rtype, device, fix_end=False, inference_model='Viterbi'):
461 |     if inference_model != 'Skyline':
462 |         we_tpms, ev_tpms = transition_prob_matrix(model, q, e2idx, ev_embeddings, w_embeddings, rtype)
463 |         # if fix_end, zero out the last transition probs, except the correct one
464 |         if fix_end:
465 |             last_we_tpm = we_tpms[-1]
466 |             last_ev_tpm = ev_tpms[-1]
467 |             for i in range(last_ev_tpm.shape[0]):
468 |                 for j in range(last_ev_tpm.shape[1]):
469 |                     if j != q.ans_idxs[-1]:
470 |                         last_we_tpm[i, j] = 0.0
471 |                         last_ev_tpm[i, j] = 0.0
472 |         # mix_tpms = [(we_tpms[i] + ev_tpms[i]) / 2.0 for i in range(len(we_tpms))]
473 |     else:
474 |         we_tpms, ev_tpms = transition_prob_matrix_all_orders(model, q, e2idx, ev_embeddings, w_embeddings, rtype)
475 |
476 |     if inference_model == 'Viterbi':
477 |         we_preds = viterbi(we_tpms)
478 |         ev_preds = viterbi(ev_tpms)
479 |         # mix_preds = viterbi(mix_tpms)
480 |     elif inference_model == 'Baseline':
481 |         we_preds = markov_baseline(we_tpms)
482 |         ev_preds = markov_baseline(ev_tpms)
483 |         # mix_preds = markov_baseline(mix_tpms)
484 |     elif inference_model == 'Skyline':
485 |         we_preds = markov_skyline(we_tpms, q.ans_idxs, fix_end)
486 |         ev_preds = markov_skyline(ev_tpms, q.ans_idxs, fix_end)
487 |         # mix_preds = markov_skyline(mix_tpms, q.ans_idxs, fix_end)
488 |     else:
489 |         raise ValueError('unsupported inference model {}'.format(inference_model))
490 |
491 |     return we_preds, ev_preds
492 |
493 |
494 | def predict_mcns(model, q, e2idx, ev_embeddings, w_embeddings, rtype, device, inference_model):
495 |     return _predict_sequence(model, q, e2idx, ev_embeddings, w_embeddings, rtype, device, fix_end=False, inference_model=inference_model)
496 |
497 |
498 | def predict_mcne(model, q, e2idx, ev_embeddings, w_embeddings, rtype, device, inference_model):
499 |     return _predict_sequence(model, q, e2idx, ev_embeddings, w_embeddings, rtype, device, fix_end=True, inference_model=inference_model)
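# Illustrative sketch of driving the three sequence decoders; `model`, `q`,
# the index map, and the two embedding tables are assumed to be prepared by
# the evaluation scripts (see bin/evaluations/eval_mcns.py). Accuracy is
# measured against the gold answer indices stored on the question.
def _mcns_inference_sketch(model, q, e2idx, ev_embs, w_embs, rtype, device):
    results = {}
    for name in ('Viterbi', 'Baseline', 'Skyline'):
        we_preds, ev_preds = predict_mcns(model, q, e2idx, ev_embs, w_embs,
                                          rtype, device, inference_model=name)
        n_correct = sum(int(p == a) for p, a in zip(ev_preds, q.ans_idxs))
        results[name] = n_correct / float(len(q.ans_idxs))
    return results
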
500 |
--------------------------------------------------------------------------------
/dnee/models/predicate_gr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from collections import OrderedDict
4 | from ast import literal_eval
5 |
6 | import six
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from scipy import sparse
12 | import progressbar
13 | from scipy import spatial
14 | import parse
15 |
16 | from ..events import indices
17 | from ..events.predicate_gr import Event
18 | from .. import utils
19 |
20 |
21 | def index_predicategr_from_event(e, e2idx):
22 |     idxs = []
23 |     key = '{}:{}'.format(e.pred, e.dep)
24 |     idx = e2idx[key] if key in e2idx else 0
25 |     idxs.append(idx)
26 |     key = 'arg:{}'.format(e.arg0_head)
27 |     idx = e2idx[key] if key in e2idx else 0
28 |     idxs.append(idx)
29 |     key = 'arg:{}'.format(e.arg1_head)
30 |     idx = e2idx[key] if key in e2idx else 0
31 |     idxs.append(idx)
32 |     key = 'arg:{}'.format(e.arg2_head)
33 |     idx = e2idx[key] if key in e2idx else 0
34 |     idxs.append(idx)
35 |     return idxs
36 |
37 |
38 | event_parser = None
39 | def index_predicategr_from_estr(estr, dep, e2idx):
40 |     global event_parser
41 |     if event_parser is None:
42 |         event_parser = parse.compile("{pred}({arg0},{arg1},{arg2})")
43 |     idxs = []
44 |     e = event_parser.parse(estr)
45 |     key = '{}:{}'.format(e['pred'], dep)
46 |     idx = e2idx[key] if key in e2idx else 0
47 |     idxs.append(idx)
48 |     key = 'arg:{}'.format(e['arg0'])
49 |     idx = e2idx[key] if key in e2idx else 0
50 |     idxs.append(idx)
51 |     key = 'arg:{}'.format(e['arg1'])
52 |     idx = e2idx[key] if key in e2idx else 0
53 |     idxs.append(idx)
54 |     key = 'arg:{}'.format(e['arg2'])
55 |     idx = e2idx[key] if key in e2idx else 0
56 |     idxs.append(idx)
57 |     return idxs
58 |
59 |
60 | def index_predicategr_example(e1, e2, nestr, e2idx):
61 |     global event_parser
62 |     idxs = []
63 |     idxs += index_predicategr_from_event(e1, e2idx)
64 |     idxs += index_predicategr_from_event(e2, e2idx)
65 |     idxs += index_predicategr_from_estr(nestr, e2.dep, e2idx)
66 |     return idxs
67 |
68 |
69 | def load_word_embeddings(fpath, use_torch=False, skip_first_line=True):
70 |     we = {}
71 |     with open(fpath, 'r') as fr:
72 |         for line in fr:
73 |             if skip_first_line:
74 |                 skip_first_line = False
75 |                 continue
76 |             line = line.rstrip()
77 |             sp = line.split(" ")
78 |             emb = np.squeeze(np.array([sp[1:]], dtype=np.float32))
79 |             we[sp[0]] = torch.from_numpy(emb) if use_torch else emb
80 |     return we
81 |
82 |
83 | class PredicateGrBase(object):
84 |     def __init__(self, **kwargs):
85 |         pass
86 |
87 |     def score(self, v1, v2):
88 |         raise NotImplementedError
89 |
90 |     def cosine_similarity(self, v1, v2):
91 |         if v1.sum() == 0 and v2.sum() == 0:
92 |             return 1.0
93 |         return 1.0 - spatial.distance.cosine(v1, v2)
94 |
95 |     def embed_event(self, e):
96 |         raise NotImplementedError
97 |
98 |     def predict_mcnc(self, ctx_events, choices):
99 |         raise NotImplementedError
100 |
101 |
102 | class EventComp(nn.Module, PredicateGrBase):
103 |     def __init__(self, config, verbose=True, dropout=0.3):
104 |         super(EventComp, self).__init__()
105 |         self.event_dim = config['event_dim']
106 |         self.embedding_file = config['embedding_file']
107 |         self.e2idx = self.idx2e = None
108 |         self.load_word2vec_event_emb()
109 |
110 |         # dimensions are from the original paper
111 |         self.arg_comp1_a1 = nn.Linear(self.event_dim*4, 600)
112 |         self.tanh1_a1 = nn.Tanh()
113 |         self.d1_a1 = nn.Dropout(p=dropout)
114 |         self.arg_comp1_a2 = nn.Linear(self.event_dim*4, 600)
115 |         self.tanh1_a2 = nn.Tanh()
116 |         self.d1_a2 = nn.Dropout(p=dropout)
117 |
118 |         self.arg_comp2_a1 = nn.Linear(600, 300)
119 |         self.tanh2_a1 = nn.Tanh()
120 |         self.d2_a1 = nn.Dropout(p=dropout)
121 |         self.arg_comp2_a2 = nn.Linear(600, 300)
122 |         self.tanh2_a2 = nn.Tanh()
123 |         self.d2_a2 = nn.Dropout(p=dropout)
124 |
125 |         self.event_comp1 = nn.Linear(600, 400)
126 |         self.tanh_e1 = nn.Tanh()
127 |         self.d_e1 = nn.Dropout(p=dropout)
128 |         self.event_comp2 = nn.Linear(400, 200)
129 |         self.tanh_e2 = nn.Tanh()
130 |         self.d_e2 = nn.Dropout(p=dropout)
131 |         self.event_comp3 =
nn.Linear(200, 1) 132 | self.sigmoid = nn.Sigmoid() 133 | 134 | self.init_weights() 135 | 136 | def init_weights(self): 137 | nn.init.xavier_uniform_(self.arg_comp1_a1.weight.data) 138 | nn.init.xavier_uniform_(self.arg_comp1_a2.weight.data) 139 | nn.init.xavier_uniform_(self.arg_comp2_a1.weight.data) 140 | nn.init.xavier_uniform_(self.arg_comp2_a2.weight.data) 141 | nn.init.xavier_uniform_(self.event_comp1.weight.data) 142 | nn.init.xavier_uniform_(self.event_comp2.weight.data) 143 | nn.init.xavier_uniform_(self.event_comp3.weight.data) 144 | 145 | def predict_mcnc(self, ctx_events, choices): 146 | # the e2idx is not used 147 | all_idxs = [] 148 | for i_ch, ch in enumerate(choices): 149 | ch_idx = index_predicategr_from_event(ch, self.e2idx) 150 | epairs = [] 151 | for i_ctx, ctx in enumerate(ctx_events): 152 | ctx_idx = index_predicategr_from_event(ctx, self.e2idx) 153 | epairs.append(ch_idx + ctx_idx) 154 | all_idxs.append(epairs) 155 | all_idxs = torch.LongTensor(all_idxs) 156 | n_pair_per_ch = all_idxs.shape[1] 157 | input_dim = 8 158 | assert all_idxs.shape[2] == input_dim 159 | all_idxs = all_idxs.view(-1, input_dim) 160 | 161 | scores = self.forward(all_idxs).view(len(choices), n_pair_per_ch) 162 | scores = torch.sum(scores, dim=1) / n_pair_per_ch 163 | _, pred = torch.max(scores, 0) 164 | return pred.item() 165 | 166 | def save(self, fpath): 167 | torch.save(self.state_dict(), fpath) 168 | 169 | def load(self, fld): 170 | fpath = os.path.join(fld, 'model.pt') 171 | self.load_state_dict(torch.load(fpath, map_location=lambda storage, location: storage)) 172 | 173 | def load_word2vec_event_emb(self): 174 | self.e2idx, self.idx2e = {}, {} 175 | n_embs = len([line for line in open(self.embedding_file)]) 176 | embs = np.zeros((n_embs+1, self.event_dim)) 177 | # leave index 0 as zero vector 178 | with open(self.embedding_file, 'r') as fr: 179 | for i, line in enumerate(fr): 180 | line = line.rstrip('\n') 181 | sp = line.split(" ") 182 | if i == 0: 183 | self.n_embs = int(sp[0]) 184 | assert self.event_dim == int(sp[1]) 185 | continue 186 | ns = filter(None, sp[1:]) 187 | ns = [float(n) for n in ns] 188 | emb = np.array(ns, dtype=np.float32) 189 | embs[i] = emb 190 | self.e2idx[sp[0]] = i 191 | self.idx2e[i] = sp[0] 192 | self.embs = nn.Embedding.from_pretrained(torch.FloatTensor(embs)) 193 | 194 | def forward(self, x): 195 | batch_size = x.shape[0] 196 | x_embs = self.embs(x) 197 | e1 = x_embs[:, :4].view(batch_size, -1) 198 | e2 = x_embs[:, 4:8].view(batch_size, -1) 199 | 200 | out1_e1 = self.d1_a1(self.tanh1_a1(self.arg_comp1_a1(e1))) 201 | out1_e1 = self.d2_a1(self.tanh2_a1(self.arg_comp2_a1(out1_e1))) 202 | out1_e2 = self.d1_a2(self.tanh1_a2(self.arg_comp1_a2(e2))) 203 | out1_e2 = self.d2_a2(self.tanh2_a2(self.arg_comp2_a2(out1_e2))) 204 | 205 | epair = torch.cat((out1_e1, out1_e2), dim=1) 206 | 207 | out2 = self.d_e1(self.tanh_e1(self.event_comp1(epair))) 208 | out2 = self.d_e2(self.tanh_e2(self.event_comp2(out2))) 209 | out2 = self.sigmoid(self.event_comp3(out2)) 210 | return out2.squeeze() 211 | 212 | def loss_func(self, cohs, eps=1e-12, lambda_l2=1e-3): 213 | pos_coh, neg_coh = cohs 214 | m = pos_coh.shape[0] 215 | loss = torch.sum(-torch.log(pos_coh + eps) - torch.log(1 - neg_coh + eps)) / m 216 | # l2 doesn't work well here, let's do dropout 217 | # l2_reg = None 218 | # for w in self.parameters(): 219 | # if l2_reg is None: 220 | # l2_reg = torch.sum(w ** 2) 221 | # else: 222 | # l2_reg += torch.sum(w ** 2) 223 | # loss += (lambda_l2 * l2_reg) 224 | return loss 225 | 226 | 227 
| class Word2Vec(PredicateGrBase):
228 |     def __init__(self, config, verbose=True):
229 |         super(Word2Vec, self).__init__()
230 |         fpath = config['embedding_file']
231 |         self.embs = load_word_embeddings(fpath, skip_first_line=config['emb_file_skip_first_line'])
232 |         self.dim = next(six.itervalues(self.embs)).shape[0]  # dict.keys() is not indexable in Python 3
233 |
234 |     def score(self, v1, v2):
235 |         return self.cosine_similarity(v1, v2)
236 |
237 |     def aggr_emb(self, events):
238 |         emb = np.zeros(self.dim, dtype=np.float32)
239 |         cnt = 0
240 |         for e in events:
241 |             if e.pred in self.embs:
242 |                 emb += self.embs[e.pred]
243 |                 cnt += 1
244 |             if e.arg0_head in self.embs:
245 |                 emb += self.embs[e.arg0_head]
246 |                 cnt += 1
247 |             if e.arg1_head in self.embs:
248 |                 emb += self.embs[e.arg1_head]
249 |                 cnt += 1
250 |             if e.arg2_head in self.embs:
251 |                 emb += self.embs[e.arg2_head]
252 |                 cnt += 1
253 |         return emb if cnt > 0 else np.random.uniform(
254 |             low=-1.0/self.dim, high=1.0/self.dim, size=self.dim)
255 |
256 |     def predict_mcnc(self, ctx_events, choices):
257 |         ctx_emb = self.aggr_emb(ctx_events)
258 |
259 |         ch_embs = []
260 |         for ch in choices:
261 |             ch_emb = self.aggr_emb([ch])
262 |             ch_embs.append(ch_emb)
263 |
264 |         max_score = -1
265 |         pred = -1
266 |         scores = [0.0] * len(choices)
267 |         for i, ch in enumerate(choices):
268 |             scores[i] = self.score(ctx_emb, ch_embs[i])
269 |             if scores[i] > max_score:
270 |                 max_score = scores[i]
271 |                 pred = i
272 |         return pred
273 |
274 |
275 | class Word2VecEvent(PredicateGrBase):
276 |     def __init__(self, config, verbose=True):
277 |         super(Word2VecEvent, self).__init__()
278 |         fpath = config['embedding_file']
279 |         self.embs = load_word_embeddings(fpath)
280 |         self.dim = next(six.itervalues(self.embs)).shape[0]
281 |
282 |     def _make_predicate_key(self, pred, dep):
283 |         return "{}:{}".format(pred, dep)
284 |
285 |     def _make_arg_key(self, arg):
286 |         return "arg:{}".format(arg)
287 |
288 |     def get_event_emb(self, e):
289 |         """summation
290 |         """
291 |         emb = np.zeros(self.dim, dtype=np.float32)
292 |         es = []
293 |         key = self._make_predicate_key(e.pred, e.dep)
294 |         if key in self.embs:
295 |             emb += self.embs[key]
296 |             es.append(key)
297 |         key = self._make_arg_key(e.arg0_head)
298 |         if key in self.embs:
299 |             emb += self.embs[key]
300 |             es.append(key)
301 |         key = self._make_arg_key(e.arg1_head)
302 |         if key in self.embs:
303 |             emb += self.embs[key]
304 |             es.append(key)
305 |         key = self._make_arg_key(e.arg2_head)
306 |         if key in self.embs:
307 |             emb += self.embs[key]
308 |             es.append(key)
309 |         return es, emb
310 |
311 |     def score(self, v1, v2):
312 |         return self.cosine_similarity(v1, v2)
313 |
314 |     def predict_mcnc(self, ctx_events, choices):
315 |         ctx_es, ctx_emb = [], np.zeros(self.dim, dtype=np.float32)
316 |         for ctx in ctx_events:
317 |             es, emb = self.get_event_emb(ctx)
318 |             ctx_es, ctx_emb = ctx_es + es, ctx_emb + emb  # sum over all contexts; previously only the last context was kept
319 |         logging.debug(ctx_es)
320 |
321 |         ch_embs = []
322 |         for i, ch in enumerate(choices):
323 |             ch_es, ch_emb = self.get_event_emb(ch)
324 |             logging.debug("ch {}: {}".format(i, ch_es))
325 |             ch_embs.append(ch_emb)
326 |
327 |         max_score = -1
328 |         pred = -1
329 |         scores = [0.0] * len(choices)
330 |         for i, ch in enumerate(choices):
331 |             scores[i] = self.score(ctx_emb, ch_embs[i])
332 |             if scores[i] > max_score:
333 |                 max_score = scores[i]
334 |                 pred = i
335 |         return pred
336 |
337 |
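# Illustrative sketch of the key scheme used by Word2VecEvent: predicates are
# stored as "pred:dep" and argument heads as "arg:<head>" (matching
# index_predicategr_from_event); the embedding values below are made up.
def _word2vec_event_sketch():
    embs = {'say:nsubj': np.array([1.0, 0.0], dtype=np.float32),
            'arg:reporter': np.array([0.0, 1.0], dtype=np.float32)}
    ev_emb = embs['say:nsubj'] + embs['arg:reporter']  # summation, as in get_event_emb
    sim = 1.0 - spatial.distance.cosine(ev_emb, embs['say:nsubj'])
    return ev_emb, sim  # sim ~= 0.707
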
338 | class BiGram(PredicateGrBase):
339 |     def __init__(self, params, verbose=True):
340 |         super(BiGram, self).__init__()
341 |         # do it later
342 |         raise NotImplementedError
343 |
344 |
345 | class CJ08(PredicateGrBase):
346 |     def __init__(self, config, verbose=True):
347 |         super(CJ08, self).__init__()
348 |         self.verbose = verbose
349 |
350 |         self.adj_m_fld = config['adj_matrix_folder']
351 |         # self.adj_m = self.load_adj_m(os.path.join(adj_fld, 'adj_m.txt'))
352 |         self.e2idx, self.idx2e = utils.load_indices(os.path.join(self.adj_m_fld, 'index_file.txt'))
353 |         self.ppmi_m = None
354 |
355 |     @staticmethod
356 |     def load_adj_m(fpath, use_float=False):
357 |         m = {}
358 |         with open(fpath) as fr:
359 |             for line in fr:
360 |                 line = line.rstrip('\n')
361 |                 sp = line.split('\t')
362 |                 if use_float:
363 |                     m[literal_eval(sp[0])] = float(sp[1])
364 |                 else:
365 |                     m[literal_eval(sp[0])] = int(sp[1])
366 |         return m
367 |
368 |     def train_ppmi(self, efreqs, adj_m, e2idx, idx2e):
369 |         self.ppmi_m = self._ppmi_matrix(efreqs, adj_m, e2idx, idx2e)
370 |
371 |     def _adj_total_freq(self, m):
372 |
373 |         dsum = m.diagonal().sum()
374 |         upper = (m.sum() - dsum) / 2.0
375 |         return upper + dsum
376 |
377 |     def _ppmi_matrix(self, efreqs, adj_m, e2idx, idx2e):
378 |         ppmi_m = {}
379 |         n_combs = sum(adj_m.values())
380 |         freq_sum = sum([f for e, f in six.iteritems(efreqs)])
381 |
382 |         if self.verbose:
383 |             logging.info("learning PPMI...")
384 |             widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
385 |             bar = progressbar.ProgressBar(widgets=widgets, maxval=len(adj_m)).start()
386 |             cnt = 0
387 |         for k, v in six.iteritems(adj_m):
388 |             idx1, idx2 = k
389 |             if (idx2, idx1) in ppmi_m:
390 |                 continue
391 |             v2 = adj_m[(idx2, idx1)] if (idx2, idx1) in adj_m else 0
392 |             p_joint = float(v + v2) / n_combs
393 |             e1, e2 = idx2e[idx1], idx2e[idx2]
394 |             f1, f2 = efreqs[e1], efreqs[e2]
395 |             p1 = float(f1) / freq_sum
396 |             p2 = float(f2) / freq_sum
397 |             ppmi_m[k] = ppmi_m[(idx2, idx1)] = np.log(p_joint / (p1 * p2))
398 |             if self.verbose:
399 |                 cnt += 1
400 |                 bar.update(cnt)
401 |         if self.verbose:
402 |             bar.finish()
403 |         return ppmi_m
404 |
405 |     def _score_epair(self, estr0, estr1):
406 |         assert estr0 in self.e2idx
407 |         assert estr1 in self.e2idx
408 |         idx0, idx1 = self.e2idx[estr0], self.e2idx[estr1]
409 |         return self.ppmi_m[idx0, idx1]
410 |
411 |     def save(self, fld):
412 |         fpath = os.path.join(fld, 'ppmi.txt')
413 |         with open(fpath, 'w') as fw:
414 |             for k, v in six.iteritems(self.ppmi_m):
415 |                 fw.write('{}\t{}\n'.format(k, v))
416 |
417 |     def load(self, fld):
418 |         fpath = os.path.join(fld, 'ppmi.txt')
419 |         logging.info('loading {}...'.format(fpath))
420 |         self.ppmi_m = self.load_adj_m(fpath, use_float=True)
421 |         fpath = os.path.join(self.adj_m_fld, 'index_file.txt')
422 |         logging.info('loading {}...'.format(fpath))
423 |         self.e2idx, self.idx2e = utils.load_indices(fpath)
424 |
425 |     def predict_mcnc(self, ctx_events, choices):
426 |         ctx_idxs = []
427 |         for ctx in ctx_events:
428 |             estr = Event.cj08_format(ctx.pred, ctx.dep)
429 |             if estr in self.e2idx:
430 |                 ctx_idxs.append(self.e2idx[estr])
431 |
432 |         ch_estrs = []
433 |         for ch in choices:
434 |             estr = Event.cj08_format(ch.pred, ch.dep)
435 |             ch_estrs.append(estr)
436 |
437 |         max_score = float('-inf')
438 |         pred = 0
439 |         scores = [0.0] * len(ch_estrs)
440 |         for i, ch in enumerate(ch_estrs):
441 |             if ch in self.e2idx:
442 |                 cidx = self.e2idx[ch]
443 |                 scores[i] = sum([self.ppmi_m[(cidx, idx)] for idx in ctx_idxs if (cidx, idx) in self.ppmi_m])
444 |             if scores[i] > max_score:
445 |                 max_score = scores[i]
446 |                 pred = i
447 |         return pred
448 |
449 |
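# Illustrative sketch of the PPMI value computed in _ppmi_matrix, with made-up
# counts: a pair co-occurring 8 times out of 40 pairs, with marginal event
# probabilities 0.1 and 0.2, scores log(0.2 / (0.1 * 0.2)) = log(10).
def _ppmi_sketch():
    p_joint = 8.0 / 40.0                # joint probability of the event pair
    p1, p2 = 0.1, 0.2                   # marginal event probabilities
    return np.log(p_joint / (p1 * p2))  # ~= 2.3026
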
450 | class SGNN(nn.Module):
451 |     def __init__(self, config):
452 |         super(SGNN, self).__init__()
453 |         n_preds = self.get_n_preds(config['predicate_indices'])
454 |         self.vocab = self.load_argw_vocabs(config['argw_indices'])
455 |         self.save = True
456 |         self.dir_st = config['skipthought_dir']
457 |         self.embeddings = self._load_embeddings()
458 |
459 |
460 |     @staticmethod
461 |     def get_n_preds(fpath):
462 |         pred2idx, idx2pred, _ = indices.load_predicates(fpath)
463 |         return len(pred2idx)
464 |
465 |     def _load_dictionary(self):
466 |         path_dico = os.path.join(self.dir_st, 'dictionary.txt')
467 |         with open(path_dico, 'r') as handle:
468 |             dico_list = handle.readlines()
469 |         dico = {word.strip(): idx for idx, word in enumerate(dico_list)}
470 |         return dico
471 |
472 |     def _get_table_name(self):
473 |         return 'btable'
474 |
475 |     def _load_emb_params(self):
476 |         table_name = self._get_table_name()
477 |         path_params = os.path.join(self.dir_st, table_name+'.npy')
478 |         params = np.load(path_params, encoding='latin1')  # to load from python2
479 |         return params
480 |
481 |     def _make_emb_state_dict(self, dictionary, parameters):
482 |         weight = torch.zeros(len(self.vocab)+1, 620)  # first dim = zeros -> +1
483 |         unknown_params = parameters[dictionary['UNK']]
484 |         nb_unknown = 0
485 |         for id_weight, word in enumerate(self.vocab):
486 |             if word in dictionary:
487 |                 id_params = dictionary[word]
488 |                 params = parameters[id_params]
489 |             else:
490 |                 # print('Warning: word `{}` not in dictionary'.format(word))
491 |                 params = unknown_params
492 |                 nb_unknown += 1
493 |             weight[id_weight+1] = torch.from_numpy(params)
494 |         state_dict = OrderedDict({'weight': weight})
495 |         if nb_unknown > 0:
496 |             print('Warning: {}/{} words are not in dictionary, thus set UNK'
497 |                   .format(nb_unknown, len(dictionary)))
498 |         return state_dict
499 |
500 |     def _load_embeddings(self, emb_fpath=None):
501 |         if self.save:
502 |             import hashlib
503 |             import pickle
504 |             # http://stackoverflow.com/questions/20416468/fastest-way-to-get-a-hash-from-a-list-in-python
505 |             hash_id = hashlib.sha256(pickle.dumps(self.vocab, -1)).hexdigest()
506 |             path = emb_fpath if emb_fpath else 'st_embedding_'+str(hash_id)+'.pth'
507 |         if self.save and os.path.exists(path):
508 |             self.embedding = torch.load(path)
509 |         else:
510 |             self.embedding = nn.Embedding(num_embeddings=len(self.vocab) + 1,
511 |                                           embedding_dim=620,
512 |                                           padding_idx=0,  # -> first_dim = zeros
513 |                                           sparse=False)
514 |             dictionary = self._load_dictionary()
515 |             parameters = self._load_emb_params()
516 |             state_dict = self._make_emb_state_dict(dictionary, parameters)
517 |             self.embedding.load_state_dict(state_dict)
518 |             if self.save:
519 |                 torch.save(self.embedding, path)
520 |         return self.embedding
521 |
522 |     def load_argw_vocabs(self, fpath):
523 |         key2idx, _, _ = indices.load_argw(fpath)
524 |         return list(key2idx.keys())  # a list pickles and indexes cleanly in Python 3
525 |
--------------------------------------------------------------------------------