├── README.md
├── code
│   ├── context_aggregator.py
│   ├── featureFuncs.py
│   ├── featurize_data.py
│   ├── gurobi_inference.py
│   ├── joint_model.py
│   ├── joint_model_global.py
│   └── utils.py
├── data
│   ├── matres
│   │   ├── test.pickle
│   │   ├── test_docs.txt
│   │   ├── train.pickle
│   │   └── train_docs.txt
│   └── tbd
│       ├── dev.pickle
│       ├── dev_docs.txt
│       ├── test.pickle
│       ├── test_docs.txt
│       ├── train.pickle
│       └── train_docs.txt
└── other
    └── pos_tags.txt

/README.md:
--------------------------------------------------------------------------------
1 | **Author**: Rujun Han
2 |
3 | **Date**: Nov 22nd, 2019
4 |
5 | **Title**: Codebase for EMNLP 2019 Paper: [Joint Event and Temporal Relation Extraction with Shared Representations and Structured Prediction](https://www.aclweb.org/anthology/D19-1041.pdf)
6 |
7 | 1. Data processing. We preprocessed the TB-Dense and MATRES raw data with internal NLP tools at the Information Sciences Institute; the resulting .pickle files are saved in the data/ folder. Download glove.6B.50d.txt into the other/ folder.
8 | 2. Featurize data. Run featurize_data.py and then context_aggregator.py. This produces two folders, all_joint/ and all_context/ (featurize_data.py writes to data/<data_type>/<split>/, so pass -split all_joint to match what context_aggregator.py reads); all_context/ contains the final files used by the model.
9 | 3. Local Model: run joint_model.py.
10 | 4. Global Model: save a pipeline_joint model object from step 3 and then run joint_model_global.py.
11 |
12 |
13 | ### Code Structure (joint_model.py)
14 |
15 | Main() --> [NNClassifier].train_epoch()
16 |
17 | [NNClassifier].train_epoch() --> [NNClassifier]._train()
18 |
19 | -------------------------------> [NNClassifier].predict()
20 |
21 |
22 | 1. Singletask Model. Set args.relation_weight = 0 to train an event module; then set args.entity_weight = 0 to train a relation module; use both saved modules to train a pipeline end-to-end model.
23 | > python code/joint_model.py -relation_weight 0 -entity_weight 1.0 -data_type "tbd" -batch 4 -model "singletask/pipeline" -epochs 10
24 |
25 | > python code/joint_model.py -relation_weight 1.0 -entity_weight 0 -data_type "tbd" -batch 4 -model "singletask/pipeline" -epochs 10
26 | 2. Multitask Model. Set args.pipe_epoch = 1000 and args.eval_gold = True to train with gold relations only; set args.eval_gold = False to train with candidate relations generated by the event module.
27 | > python code/joint_model.py -relation_weight 1.0 -entity_weight 1.0 -data_type "tbd" -batch 4 -model "multitask/pipeline" -eval_gold True -pipe_epoch 1000 -epochs 10
28 |
29 | > python code/joint_model.py -relation_weight 1.0 -entity_weight 1.0 -data_type "tbd" -batch 4 -model "multitask/pipeline" -eval_gold False -pipe_epoch 1000 -epochs 10
30 | 3. Pipeline Joint Model. Set args.pipe_epoch < args.epochs and args.eval_gold = False to train with candidate relations generated by the event module. Our paper used the output model of this step as the local model.
31 | > python code/joint_model.py -relation_weight 1.0 -entity_weight 1.0 -data_type "tbd" -batch 4 -model "multitask/pipeline" -eval_gold False -pipe_epoch 5 -epochs 10
32 | 4. Global Model. Install the [Gurobi](https://www.gurobi.com/documentation/) package and run joint_model_global.py:
33 | > python code/joint_model_global.py -relation_weight 1.0 -data_type "tbd" -batch 4 -model "multitask/pipeline" -momentum 0.1 -decay 0.1 -entity_weight 0.1 -lr 0.0005 -ent_thresh 0.49 -epochs 5
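A minimal sketch of inspecting the final featurized output of step 2 (the path assumes -data_type "matres" and the repository layout above):

```python
# Each all_context/*.pickle maps a unique context key
# (doc_id, start_span, end_span) to the shared context plus every
# relation found in that context (see code/context_aggregator.py).
import pickle

with open('data/matres/all_context/train.pickle', 'rb') as handle:
    context_map = pickle.load(handle)

(doc_id, start_span, end_span), entry = next(iter(context_map.items()))
bert_emb, event_labels, pos_tags = entry['context']
print(doc_id, entry['context_id'], len(entry['rels']))
```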
--------------------------------------------------------------------------------
/code/context_aggregator.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import os
4 | from collections import OrderedDict
5 |
6 | def main(args):
7 |
8 |     for split in ["train", "dev", "test"]:
9 |
10 |         with open('%s/all_joint/%s.pickle' % (args.data_dir, split), 'rb') as handle:
11 |             data = pickle.load(handle)
12 |
13 |         context_map = OrderedDict([])
14 |         count = 0
15 |         for ex in data:
16 |
17 |             start = ex[4][1][1][0]
18 |             end = ex[4][1][-2][0]
19 |
20 |             # use doc_id + start and end token spans as a unique context id
21 |             context_id = (ex[0], start, end)
22 |
23 |             # sample id, (left id, right id), label_idx, distance, reverse_ind,
24 |             # (left_start, left_end, right_start, right_end), pred_ind
25 |             rel = (ex[1], ex[2], ex[3], ex[4][3], ex[4][4], ex[4][5:9], ex[4][9])
26 |             if context_id in context_map:
27 |                 context_map[context_id]['rels'].append(rel)
28 |             else:
29 |                 context_map[context_id] = {'context_id': count,
30 |                                            'doc_id': ex[0],
31 |                                            'context': ex[4][0:3],  # bert emb, event_label, pos
32 |                                            'rels': [rel]}
33 |                 count += 1
34 |
35 |         save_dir = args.data_dir + '/all_context/'
36 |
37 |         if not os.path.isdir(save_dir):
38 |             os.mkdir(save_dir)
39 |
40 |         with open('%s/%s.pickle' % (save_dir, split), 'wb') as handle:
41 |             pickle.dump(context_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
42 |
43 |     return
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 |     p = argparse.ArgumentParser()
49 |     p.add_argument('-data_dir', type=str, default='../data')
50 |     p.add_argument('-data_type', type=str, default="matres")
51 |
52 |     args = p.parse_args()
53 |
54 |     args.data_dir = os.path.join(args.data_dir, args.data_type)
55 |     main(args)
56 |
--------------------------------------------------------------------------------
/code/featureFuncs.py:
--------------------------------------------------------------------------------
1 | from nltk.corpus import wordnet as wn
2 | import numpy as np
3 | from collections import OrderedDict
4 | import copy
5 | import itertools
6 |
7 | def read_glove(input_dir):
8 |     glove_emb = open(input_dir, 'r', encoding="utf-8")
9 |     emb_dict = OrderedDict([(x.strip().split(' ')[0], [float(xx) for xx in x.strip().split(' ')[1:]]) for x in glove_emb])
10 |     glove_emb.close()
11 |     return emb_dict
12 |
13 | def create_pos_dict(nlp_ann):
14 |     pos_dict = OrderedDict()
15 |     for key, pos in nlp_ann.items():
16 |         # find the last '['
17 |         key = str(key)
18 |         splt = key.rfind('[')
19 |         tok = key[:splt]
20 |         span = key[splt:]
21 |         ### the span has to be the key because duplicate tokens can occur in a text
22 |         pos_dict[span] = (tok, pos.label)
23 |     return pos_dict
24 |
25 |
26 | def ner_features(nlp_ann, left, right):
27 |
28 |     all_ner = {str(en.span): (str(en.text()), str(en.entity_type)) for en in nlp_ann.mentions()}
29 |     all_ner = {range(int(k.split(':')[0][1:]), int(k.split(':')[1][:-1])): v for k, v in all_ner.items()}
30 |     left_tlen = len(left.text.split(' '))
31 |     right_tlen = len(right.text.split(' '))
32 |
33 |     ner_fts = [0, 0]
34 |     for k, v in all_ner.items():
35 |
36 |         if left_tlen > 1 and left.span[0] in k and (left.span[1] - 1) in k and v[1] in ['DATE', 'TIME']:
37 |             ner_fts[0] = 1
38 |         if right_tlen > 1 and right.span[0] in k and (right.span[1] - 1) in k and v[1] in ['DATE', 'TIME']:
39 |             ner_fts[1] = 1
40 |
41 |     return ner_fts
42 |
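The span keys stored by `create_pos_dict` above look like `[start:end]` (character offsets), and `token_idx` below recovers the integers by string slicing. A standalone illustration of that parsing convention (the sample key is made up):

```python
# pos_dict keys are span strings such as "[12:15]"; splitting on ':' and
# stripping the brackets recovers the integer offsets, exactly as done
# in ner_features above and token_idx below.
span_key = "[12:15]"
start = int(span_key.split(':')[0][1:])   # drop the leading '['  -> 12
end = int(span_key.split(':')[1][:-1])    # drop the trailing ']' -> 15
assert (start, end) == (12, 15)
```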
43 | def token_idx(left, right, pos_dict):
44 |
45 |     all_keys = list(pos_dict.keys())
46 |
47 |     ### handle events with multiple tokens
48 |     lkey_start = str(left[0])
49 |     lkey_end = str(left[1])
50 |
51 |     ### handle starts that are not an exact match -- e.g. "tomtake", which should be "to take"
52 |     lidx_start = 0
53 |     while int(all_keys[lidx_start].split(':')[1][:-1]) <= left[0]:
54 |         lidx_start += 1
55 |
56 |     ### handle cases such as "ACCOUNCED--" or multi-token ends that do not match exactly
57 |     lidx_end = lidx_start
58 |     try:
59 |         while left[1] > int(all_keys[lidx_end].split(':')[1][:-1]):
60 |             lidx_end += 1
61 |     except IndexError:
62 |         lidx_end -= 1
63 |
64 |     rkey_start = str(right[0])
65 |     rkey_end = str(right[1])
66 |
67 |     ridx_start = 0
68 |     while int(all_keys[ridx_start].split(':')[1][:-1]) <= right[0]:
69 |         ridx_start += 1
70 |
71 |     ridx_end = ridx_start
72 |     try:
73 |         while right[1] > int(all_keys[ridx_end].split(':')[1][:-1]):
74 |             ridx_end += 1
75 |     except IndexError:
76 |         ridx_end -= 1
77 |     return all_keys, lidx_start, lidx_end, ridx_start, ridx_end
78 |
79 | def compute_ngbrs(all_keys, lidx_start, lidx_end, ridx_start, ridx_end, pos_dict, pos_ngbrs, pos_fts=True):
80 |
81 |     idx = int(pos_fts)
82 |     if lidx_start < pos_ngbrs:
83 |         lngbrs = ['' for k in range(pos_ngbrs - lidx_start)] + [pos_dict[all_keys[k]][idx] for k in list(range(lidx_start)) + list(range(lidx_end + 1, lidx_end+pos_ngbrs+1))]
84 |     elif lidx_end > (len(all_keys) - 1 - pos_ngbrs):
85 |         lngbrs = [pos_dict[all_keys[k]][idx] for k in list(range(lidx_start - pos_ngbrs, lidx_start)) + list(range(lidx_end + 1, len(all_keys)))] + ['' for k in range(pos_ngbrs - (len(all_keys) - 1 - lidx_end))]
86 |     else:
87 |         lngbrs = [pos_dict[all_keys[k]][idx] for k in list(range(lidx_start-pos_ngbrs, lidx_start)) + list(range(lidx_end + 1, lidx_end+1+pos_ngbrs))]
88 |
89 |     assert len(lngbrs) == 2 * pos_ngbrs
90 |
91 |     if ridx_start < pos_ngbrs:
92 |         rngbrs = ['' for k in range(pos_ngbrs - ridx_start)] + [pos_dict[all_keys[k]][idx] for k in list(range(ridx_start)) + list(range(ridx_end + 1, ridx_end+pos_ngbrs+1))]
93 |
94 |     elif ridx_end > len(all_keys) - pos_ngbrs - 1:
95 |         rngbrs = [pos_dict[all_keys[k]][idx] for k in list(range(ridx_start - pos_ngbrs, ridx_start)) + list(range(ridx_end + 1, len(all_keys)))] + \
96 |                  ['' for k in range(pos_ngbrs - (len(all_keys) - 1 - ridx_end))]
97 |
98 |     else:
99 |         rngbrs = [pos_dict[all_keys[k]][idx] for k in list(range(ridx_start-pos_ngbrs, ridx_start)) + list(range(ridx_end + 1, ridx_end+1+pos_ngbrs))]
100 |
101 |     assert len(rngbrs) == 2 * pos_ngbrs
102 |
103 |     return lngbrs, rngbrs
104 |
105 |
106 | def pos_features(all_keys, lidx_start, lidx_end, ridx_start, ridx_end, pos_dict, pos_ngbrs, pos2idx):
107 |
108 |     lngbrs, rngbrs = compute_ngbrs(all_keys, lidx_start, lidx_end, ridx_start, ridx_end, pos_dict, pos_ngbrs)
109 |     return [pos2idx[x] if x in pos2idx.keys() else len(pos2idx) for x in [pos_dict[all_keys[lidx_start]][1], pos_dict[all_keys[ridx_start]][1]] + lngbrs + rngbrs]
110 |
111 |
112 | def distance_features(lidx_start, lidx_end, ridx_start, ridx_end):
113 |
114 |     ### if an event has multiple tokens, take the mid-point
115 |     return (float(lidx_start) + float(lidx_end)) / 2.0 - (float(ridx_start) + float(ridx_end)) / 2.0
116 |
117 |
118 | def modal_features(lidx_start, lidx_end, ridx_start, ridx_end, pos_dict):
119 |
120 |     modals = ['will', 
'would', 'can', 'could', 'may', 'might'] 121 | all_tokens = [tok.lower() for tok, span in pos_dict.values()] 122 | 123 | return [1 if md in all_tokens[lidx_end + 1 : ridx_start] else 0 for md in modals] 124 | 125 | def temporal_features(lidx_start, lidx_end, ridx_start, ridx_end, pos_dict): 126 | 127 | temporal = ['before', 'after', 'since', 'afterwards', 'first', 'lastly', 'meanwhile', 'next', 'while', 'then'] 128 | all_tokens = [tok.lower() for tok, span in pos_dict.values()] 129 | 130 | return [1 if tp in all_tokens[lidx_end + 1 : ridx_start] else 0 for tp in temporal] 131 | 132 | 133 | def wordNet_features(lidx_start, lidx_end, ridx_start, ridx_end, pos_dict): 134 | 135 | all_tokens = [tok.lower() for tok, span in pos_dict.values()] 136 | 137 | features = [] 138 | try: 139 | sims = set(wn.synsets(all_tokens[lidx_start])).intersection(wn.synsets(all_tokens[ridx_start])) 140 | if len(sims) > 0: 141 | features.append(1) 142 | else: 143 | features.append(0) 144 | except: 145 | features.append(0) 146 | 147 | try: 148 | lderiv = set(itertools.chain.from_iterable([lemma.derivationally_related_forms() for lemma in wn.lemmas(all_tokens[lidx_start])])) 149 | rderiv = set(itertools.chain.from_iterable([lemma.derivationally_related_forms() for lemma in wn.lemmas(all_tokens[ridx_start])])) 150 | if len(lderiv.intersection(rderiv))> 0: 151 | features.append(1) 152 | else: 153 | features.append(0) 154 | except: 155 | features.append(0) 156 | 157 | return features 158 | 159 | def polarity_features(left, right): 160 | 161 | lp = 1.0 if left.polarity == "POS" else 0.0 162 | rp = 1.0 if right.polarity == "POS" else 0.0 163 | return [lp, rp] 164 | 165 | 166 | def tense_features(left, right): 167 | 168 | tense_dict = {'PAST': 0, 169 | 'PRESENT': 1, 170 | 'INFINITIVE': 2, 171 | 'FUTURE': 3, 172 | 'PRESPART': 4, 173 | 'PASTPART': 5} 174 | 175 | li = np.zeros(len(tense_dict)) 176 | ri = np.zeros(len(tense_dict)) 177 | 178 | if left.tense in tense_dict: 179 | li[tense_dict[left.tense]] = 1.0 180 | 181 | if right.tense in tense_dict: 182 | ri[tense_dict[right.tense]] = 1.0 183 | 184 | return list(li) + list(ri) 185 | -------------------------------------------------------------------------------- /code/featurize_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from dataclasses import dataclass 3 | from typing import Tuple 4 | import argparse 5 | from collections import defaultdict, Counter, OrderedDict 6 | import random 7 | import logging as log 8 | from pytorch_pretrained_bert.modeling import BertModel, BertConfig 9 | from pytorch_pretrained_bert.tokenization import BertTokenizer 10 | import os 11 | import torch 12 | from torch.utils import data 13 | import time 14 | from featureFuncs import * 15 | from sklearn.model_selection import train_test_split 16 | 17 | tbd_label_map = OrderedDict([('VAGUE', 'VAGUE'), 18 | ('BEFORE', 'BEFORE'), 19 | ('AFTER', 'AFTER'), 20 | ('SIMULTANEOUS', 'SIMULTANEOUS'), 21 | ('INCLUDES', 'INCLUDES'), 22 | ('IS_INCLUDED', 'IS_INCLUDED'), 23 | ]) 24 | 25 | matres_label_map = OrderedDict([('VAGUE', 'VAGUE'), 26 | ('BEFORE', 'BEFORE'), 27 | ('AFTER', 'AFTER'), 28 | ('SIMULTANEOUS', 'SIMULTANEOUS') 29 | ]) 30 | 31 | @dataclass 32 | class Event(): 33 | id: str 34 | type: str 35 | text: str 36 | tense: str 37 | polarity: str 38 | span: (int, int) 39 | 40 | def create_features(ex, pos2idx, w2i, tokenizer, bert_model): 41 | bert_model.eval() 42 | pos_dict = ex['doc_dictionary'] 43 | ent_labels = ex['event_labels'] 44 | 45 | 
all_keys, lidx_start, lidx_end, ridx_start, ridx_end = \ 46 | token_idx(ex['left_event'].span, ex['right_event'].span, pos_dict) 47 | 48 | # truncate dictionary into three pieces 49 | left_seq = [pos_dict[x][0] for x in all_keys[:lidx_start]] 50 | right_seq = [pos_dict[x][0] for x in all_keys[ridx_end + 1:]] 51 | in_seq = [pos_dict[x][0] for x in all_keys[lidx_start:ridx_end+1]] 52 | 53 | # find context sentence(s) start and end indices 54 | try: 55 | sent_start = max(loc for loc, val in enumerate(left_seq) if val == '.') + 1 56 | except: 57 | sent_start = 0 58 | 59 | try: 60 | sent_end = ridx_end + 1 + min(loc for loc, val in enumerate(right_seq) if val == '.') 61 | except: 62 | sent_end = len(pos_dict) 63 | 64 | assert sent_start < sent_end 65 | assert sent_start <= lidx_start 66 | assert ridx_end <= sent_end 67 | 68 | # if > 2 sentences, not predicting 69 | pred_ind = True 70 | if len([x for x in in_seq if x == '.']) > 1: 71 | pred_ind = False 72 | 73 | sent_key = all_keys[sent_start:sent_end] 74 | orig_sent = [pos_dict[x][0].lower() for x in sent_key] 75 | sent = [args.w2i[t] if t in args.w2i.keys() else 1 for t in orig_sent] 76 | 77 | pos = [pos2idx[k] if k in pos2idx.keys() else len(pos2idx) for k in [pos_dict[x][1] for x in sent_key]] 78 | ent = [(x, ent_labels[x]) for x in sent_key] 79 | 80 | # calculate events' index in context sentences 81 | lidx_start_s = lidx_start - sent_start 82 | lidx_end_s = lidx_end - sent_start 83 | ridx_start_s = ridx_start - sent_start 84 | ridx_end_s = ridx_end - sent_start 85 | 86 | # bert sentence segment ids 87 | segments_ids = [] # [0, ..., 0, 0, 1, 1, ...., 1] 88 | seg = 0 89 | bert_pos = [] 90 | bert_ent = [] 91 | 92 | # append sentence start 93 | bert_tokens = ["[CLS]"] 94 | # original token to bert word-piece token mapping 95 | orig_to_tok_map = [] 96 | 97 | segments_ids.append(seg) 98 | bert_pos.append("[CLS]") 99 | 100 | # sent_start is non-event by default 101 | bert_ent.append(("[CLS]", 0)) 102 | 103 | for i, token in enumerate(orig_sent): 104 | orig_to_tok_map.append(len(bert_tokens)) 105 | if token == '.': 106 | segments_ids.append(seg) 107 | bert_pos.append("[SEP]") 108 | if seg == 0: 109 | seg = 1 110 | bert_tokens.append("[SEP]") 111 | else: 112 | bert_tokens.append(".") 113 | # sentence sep is non-event by default 114 | bert_ent.append(('[SEP]', 0)) 115 | else: 116 | temp_tokens = tokenizer.tokenize(token) 117 | bert_tokens.extend(temp_tokens) 118 | for t in temp_tokens: 119 | segments_ids.append(seg) 120 | bert_pos.append(pos[i]) 121 | bert_ent.append(ent[i]) 122 | 123 | orig_to_tok_map.append(len(bert_tokens)) 124 | 125 | bert_tokens.append("[SEP]") 126 | bert_pos.append("[SEP]") 127 | bert_ent.append(('[SEP]', 0)) 128 | 129 | segments_ids.append(seg) 130 | assert len(segments_ids) == len(bert_tokens) 131 | assert len(bert_pos) == len(bert_tokens) 132 | 133 | # map original token index into bert (word_piece) index 134 | lidx_start_s = orig_to_tok_map[lidx_start_s] 135 | lidx_end_s = orig_to_tok_map[lidx_end_s + 1] - 1 136 | 137 | ridx_start_s = orig_to_tok_map[ridx_start_s] 138 | ridx_end_s = orig_to_tok_map[ridx_end_s + 1] - 1 139 | 140 | bert_sent = tokenizer.convert_tokens_to_ids(bert_tokens) 141 | 142 | bert_sent = torch.tensor([bert_sent]) 143 | segs_sent = torch.tensor([segments_ids]) 144 | 145 | # use the last layer computed by BERT as token vectors 146 | try: 147 | out, _ = bert_model(bert_sent, segs_sent) 148 | sent = out[-1].squeeze(0).data.numpy() 149 | # rare long sentences may fail > max_sent_len in BERT 150 | except: 
151 |         sent_len = len(bert_tokens)
152 |         print(sent_len, pred_ind)
153 |         sent = []
154 |         bert_pos = []
155 |
156 |     # create lexical features for the model
157 |     new_fts = []
158 |     new_fts.append(-distance_features(lidx_start, lidx_end, ridx_start, ridx_end))
159 |     #new_fts.extend(polarity_features(ex.left, ex.right))
160 |     #new_fts.extend(tense_features(ex.left, ex.right))
161 |
162 |     return (sent, bert_ent, bert_pos, new_fts, ex['rev'], lidx_start_s, lidx_end_s, ridx_start_s, ridx_end_s, pred_ind)
163 |
164 | def parallel(ex, ex_id, args, tokenizer, bert_model):
165 |     label_id = args._label_to_id[ex['rel_type']]
166 |     return ex['doc_id'], ex_id, (ex['left_event'].id, ex['right_event'].id), label_id, \
167 |            create_features(ex, args.pos2idx, args.w2i, tokenizer, bert_model)
168 |
169 |
170 | def data_split(train_docs, eval_docs, data, neg_r = 0.0, seed = 7):
171 |     train_set = []
172 |     eval_set = []
173 |     train_set_neg = []
174 |
175 |     for s in data:
176 |         # the dev set doesn't require unlabeled data
177 |         if s[0] in eval_docs:
178 |             # 0:doc_id, 1:ex.id, 2:(ex.left.id, ex.right.id), 3:label_id, 4:features
179 |             eval_set.append(s)
180 |         elif s[1][0] in ['L', 'C']:
181 |             train_set.append(s)
182 |         elif s[1][0] in ['N']:
183 |             train_set_neg.append(s)
184 |
185 |     random.Random(seed).shuffle(train_set_neg)
186 |     n_neg = int(neg_r * len(train_set))
187 |     if n_neg > 0:
188 |         train_set.extend(train_set_neg[:n_neg])
189 |         random.Random(seed).shuffle(train_set)
190 |
191 |     return train_set, eval_set
192 |
193 |
194 | def split_and_save(train_docs, dev_docs, data, seed, save_dir, nr=0.0):
195 |     # first split the labeled data into train and dev
196 |     train_data, dev_data = data_split(train_docs, dev_docs, data, neg_r = nr)
197 |     print(len(train_data), len(dev_data))
198 |
199 |     # shuffle
200 |     #random.Random(seed).shuffle(train_data)
201 |     if not os.path.isdir(save_dir):
202 |         os.mkdir(save_dir)
203 |
204 |     with open(save_dir + '/train.pickle', 'wb') as handle:
205 |         pickle.dump(train_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
206 |
207 |
208 |     with open(save_dir + '/dev.pickle', 'wb') as handle:
209 |         pickle.dump(dev_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
210 |
211 |
212 |     return
213 |
214 |
215 | def reduce_vocab(data, save_dir, w2i, glove):
216 |     # sents in data are indexed by the original GloVe vocabulary
217 |     # 1. output a mapping from GloVe index to reduced index: glove2vocab
218 |     # 2. save a reduced embedding matrix in npy format
219 |
220 |     glove2vocab = {0:0, 1:1}
221 |     count = 2
222 |     emb = []
223 |     i2w = {v:k for k,v in w2i.items()}
224 |
225 |     for x in data:
226 |         for t in x[4][0]:
227 |             if t not in glove2vocab:
228 |                 glove2vocab[t] = count
229 |                 count += 1
230 |                 emb.append(glove[i2w[t]])
231 |
232 |     emb = np.array(emb)
233 |     print(emb.shape)
234 |     assert emb.shape[1] == len(glove['the'])
235 |     assert emb.shape[0] + 2 == len(glove2vocab)
236 |
237 |     np.save(save_dir + '/emb_reduced.npy', emb)
238 |     np.save(save_dir + '/glove2vocab.npy', glove2vocab)
239 |
240 |     return
241 |
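A sketch of how the two arrays saved by reduce_vocab are meant to be consumed downstream (assumed usage; as in `w2i`, indices 0 and 1 are treated as the reserved padding/unknown slots, so embedding rows are offset by 2):

```python
# Illustrative lookup against the reduced embedding matrix: glove2vocab
# maps an original GloVe token id to its reduced id, and row i of
# emb_reduced.npy corresponds to reduced id i + 2.
import numpy as np

emb = np.load('emb_reduced.npy')
glove2vocab = np.load('glove2vocab.npy', allow_pickle=True).item()

orig_idx = 1234                           # a token id under the full GloVe vocab
red_idx = glove2vocab.get(orig_idx, 1)    # unseen tokens fall back to unk (1)
vector = emb[red_idx - 2] if red_idx > 1 else np.zeros(emb.shape[1])
```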
242 | def main(args):
243 |
244 |     # build label <-> id mappings for the selected dataset
245 |     if args.data_type == "matres":
246 |         label_map = matres_label_map
247 |     elif args.data_type == "tbd":
248 |         label_map = tbd_label_map
249 |
250 |     all_labels = list(OrderedDict.fromkeys(label_map.values()))
251 |
252 |     args._label_to_id = OrderedDict([(all_labels[l], l) for l in range(len(all_labels))])
253 |     args._id_to_label = OrderedDict([(l, all_labels[l]) for l in range(len(all_labels))])
254 |     print(args._label_to_id)
255 |
256 |     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
257 |
258 |     if args.load_model_dir:
259 |         output_model_file = os.path.join(args.load_model_dir, "pytorch_model.bin")
260 |         model_state_dict = torch.load(output_model_file)
261 |         bert_model = BertModel.from_pretrained('bert-base-uncased', state_dict=model_state_dict)
262 |     else:
263 |         bert_model = BertModel.from_pretrained('bert-base-uncased')
264 |
265 |     train_data = pickle.load(open(args.data_dir + "/train.pickle", "rb"))
266 |     print("process train...")
267 |     data = [parallel(v, k, args, tokenizer, bert_model) for k, v in train_data.items()]
268 |
269 |     if args.data_type in ['tbd']:
270 |         print("process dev...")
271 |         dev_data = pickle.load(open(args.data_dir + "/dev.pickle", "rb"))
272 |         dev_data = [parallel(v, k, args, tokenizer, bert_model) for k, v in dev_data.items()]
273 |         data += dev_data
274 |
275 |     # doc splits
276 |     if args.data_type in ['matres']:
277 |         train_docs, dev_docs = train_test_split(args.train_docs, test_size=0.2, random_state=args.seed)
278 |     # TBDense data comes with given train/dev/test splits
279 |     else:
280 |         train_docs = args.train_docs
281 |         dev_docs = args.dev_docs
282 |
283 |     if not os.path.isdir(args.save_data_dir):
284 |         os.mkdir(args.save_data_dir)
285 |
286 |     if 'all' in args.split:
287 |         print("process test...")
288 |         test_data = pickle.load(open(args.data_dir + "/test.pickle", "rb"))
289 |         test_data = [parallel(v, k, args, tokenizer, bert_model) for k, v in test_data.items()]
290 |         print(len(test_data))
291 |         print(args.save_data_dir)
292 |
293 |         with open(args.save_data_dir + '/test.pickle', 'wb') as handle:
294 |             pickle.dump(test_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
295 |
296 |
297 |     split_and_save(train_docs, dev_docs, data, args.seed, args.save_data_dir)
298 |
299 |     # quick trick to reduce the number of tokens kept from GloVe
300 |     # reduce_vocab(data + test_data, args.save_data_dir, args.w2i, args.glove)
301 |     return
302 |
303 |
304 | if __name__ == '__main__':
305 |     p = argparse.ArgumentParser()
306 |     p.add_argument('-data_dir', type=str, default='../data')
307 |     p.add_argument('-other_dir', type=str, default='../other')
308 |     p.add_argument('-load_model_dir', type=str, default='')
309 |     p.add_argument('-train_docs', type=list, default=[])
310 |     p.add_argument('-dev_docs', type=list, default=[])
311 |     p.add_argument('-split', type=str, default='bert_all_joint_cosmos')
312 |     p.add_argument('-data_type', type=str, default='matres')
313 |     p.add_argument('-seed', type=int, default=7)
314 |     args = p.parse_args()
315 |
316 |     args.data_dir = os.path.join(args.data_dir, args.data_type)
317 |     if args.data_type == "tbd":
318 |         args.train_docs = [x.strip() for x in open("%s/train_docs.txt" % args.data_dir, 'r')]
319 |         args.dev_docs = [x.strip() for x in open("%s/dev_docs.txt" % args.data_dir, 'r')]
320 |     elif args.data_type == "matres":
321 |         args.train_docs = [x.strip() for x in open("%s/train_docs.txt" % args.data_dir, 'r')]
322 |         print(args.train_docs[:10])
323 |     args.save_data_dir = args.data_dir + '/' + args.split
324 |
325 |     glove = read_glove(args.other_dir + 
"/glove.6B.50d.txt") 326 | vocab = np.array(['', ''] + list(glove.keys())) 327 | args.w2i = OrderedDict((vocab[i], i) for i in range(len(vocab))) 328 | 329 | tags = open(args.other_dir + "/pos_tags.txt") 330 | pos2idx = {} 331 | idx = 0 332 | for tag in tags: 333 | tag = tag.strip() 334 | pos2idx[tag] = idx 335 | idx += 1 336 | args.pos2idx = pos2idx 337 | 338 | main(args) 339 | 340 | -------------------------------------------------------------------------------- /code/gurobi_inference.py: -------------------------------------------------------------------------------- 1 | from gurobipy import * 2 | from pathlib import Path 3 | from collections import defaultdict, Counter, OrderedDict 4 | from typing import Iterator, List, Mapping, Union, Optional, Set 5 | from datetime import datetime 6 | from utils import ClassificationReport 7 | import numpy as np 8 | import pickle 9 | 10 | class Global_Inference(): 11 | 12 | def __init__(self, prob_ents, prob_rels, cand_rels, label2idx, pairs, ew): 13 | 14 | self.model = Model("joint_inference") 15 | 16 | self.prob_ents = prob_ents 17 | self.prob_rels = prob_rels 18 | 19 | self.N, self.Nc = prob_ents.shape 20 | self.M, self.Mc = prob_rels.shape 21 | 22 | self.pred_ent_labels = list(np.argmax(prob_ents, axis=1)) 23 | self.pred_rel_labels = list(np.argmax(prob_rels, axis=1)) 24 | 25 | self.cand_rels = cand_rels 26 | self.ew = ew 27 | 28 | self.label2idx = label2idx 29 | self.idx2label = OrderedDict([(v,k) for k,v in label2idx.items()]) 30 | 31 | self.pairs = pairs 32 | self.idx2pair = {n: self.pairs[n] for n in range(len(pairs))} 33 | self.pair2idx = {v:k for k,v in self.idx2pair.items()} 34 | 35 | def define_vars(self): 36 | var_table_e, var_table_r = [], [] 37 | 38 | # entity variables 39 | for n in range(self.N): 40 | sample = [] 41 | for p in range(self.Nc): 42 | sample.append(self.model.addVar(vtype=GRB.BINARY, name="e_%s_%s"%(n,p))) 43 | var_table_e.append(sample) 44 | 45 | # relation variables 46 | for m in range(self.M): 47 | sample = [] 48 | for p in range(self.Mc): 49 | sample.append(self.model.addVar(vtype=GRB.BINARY, name="r_%s_%s"%(m,p))) 50 | var_table_r.append(sample) 51 | 52 | return var_table_e, var_table_r 53 | 54 | def objective(self, samples_e, samples_r, p_table_e, p_table_r): 55 | 56 | obj = 0.0 57 | 58 | assert len(samples_e) == self.N 59 | assert len(samples_r) == self.M 60 | assert len(samples_e[0]) == self.Nc 61 | assert len(samples_r[0]) == self.Mc 62 | 63 | # entity 64 | for n in range(self.N): 65 | for p in range(self.Nc): 66 | obj += self.ew * samples_e[n][p] * p_table_e[n][p] 67 | 68 | # relation 69 | for m in range(self.M): 70 | for p in range(self.Mc): 71 | obj += samples_r[m][p] * p_table_r[m][p] 72 | 73 | return obj 74 | 75 | def single_label(self, sample): 76 | return sum(sample) == 1 77 | 78 | def rel_ent_sum(self, samples_e, samples_r, e, r, c): 79 | # negative rel constraint 80 | return samples_e[e[0]][0] + samples_e[e[1]][0] - samples_r[r][c] 81 | 82 | def rel_left_ent(self, samples_e, samples_r, e, r, c): 83 | # positive rel left constraint 84 | return samples_e[e[0]][1] - samples_r[r][c] 85 | 86 | def rel_right_ent(self, samples_e, samples_r, e, r, c): 87 | # positive rel right constraint 88 | return samples_e[e[1]][1] - samples_r[r][c] 89 | 90 | def transitivity_list(self): 91 | 92 | transitivity_samples = [] 93 | pair2idx = self.pair2idx 94 | 95 | for k, (e1, e2) in self.idx2pair.items(): 96 | for (re1, re2), i in pair2idx.items(): 97 | if e2 == re1 and (e1, re2) in pair2idx.keys(): 98 | 
transitivity_samples.append((pair2idx[(e1, e2)], pair2idx[(re1, re2)], pair2idx[(e1, re2)]))
99 |         return transitivity_samples
100 |
101 |     def transitivity_criteria(self, samples, triplet):
102 |         # r1 r2 Trans(r1, r2)   [b=BEFORE, a=AFTER, s=SIMULTANEOUS, v=VAGUE; r = any label]
103 |         # _____________________
104 |         # r  r  r
105 |         # r  s  r
106 |         # b  v  b, v
107 |         # a  v  a, v
108 |         # v  b  b, v
109 |         # v  a  a, v
110 |         r1, r2, r3 = triplet
111 |         label_dict = self.label2idx
112 |
113 |         return [
114 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['BEFORE']],
115 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['AFTER']],
116 |             samples[r1][label_dict['SIMULTANEOUS']] + samples[r2][label_dict['SIMULTANEOUS']] - samples[r3][label_dict['SIMULTANEOUS']],
117 |             #samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['VAGUE']],
118 |             #samples[r1][label_dict['NONE']] + samples[r2][label_dict['NONE']] - samples[r3][label_dict['NONE']],
119 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']],
120 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']],
121 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']],
122 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']]
123 |         ]
124 |
125 |         ''' TBD
126 |         return [
127 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['BEFORE']],
128 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['AFTER']],
129 |             samples[r1][label_dict['SIMULTANEOUS']] + samples[r2][label_dict['SIMULTANEOUS']] - samples[r3][label_dict['SIMULTANEOUS']],
130 |             samples[r1][label_dict['INCLUDES']] + samples[r2][label_dict['INCLUDES']] - samples[r3][label_dict['INCLUDES']],
131 |             samples[r1][label_dict['IS_INCLUDED']] + samples[r2][label_dict['IS_INCLUDED']] - samples[r3][label_dict['IS_INCLUDED']],
132 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['VAGUE']],
133 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['IS_INCLUDED']],
134 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['INCLUDES']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']],
135 |             samples[r1][label_dict['BEFORE']] + samples[r2][label_dict['IS_INCLUDED']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['IS_INCLUDED']],
136 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['IS_INCLUDED']],
137 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['INCLUDES']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']],
138 |             samples[r1][label_dict['AFTER']] + samples[r2][label_dict['IS_INCLUDED']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['IS_INCLUDED']],
139 |             samples[r1][label_dict['INCLUDES']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']],
140 |             samples[r1][label_dict['INCLUDES']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']],
141 |             samples[r1][label_dict['INCLUDES']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['AFTER']],
142 |             samples[r1][label_dict['IS_INCLUDED']] + samples[r2][label_dict['VAGUE']] - samples[r3][label_dict['IS_INCLUDED']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['AFTER']],
143 |             samples[r1][label_dict['IS_INCLUDED']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['IS_INCLUDED']],
144 |             samples[r1][label_dict['IS_INCLUDED']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['IS_INCLUDED']],
145 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['BEFORE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['IS_INCLUDED']],
146 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['AFTER']] - samples[r3][label_dict['AFTER']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['IS_INCLUDED']],
147 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['INCLUDES']] - samples[r3][label_dict['INCLUDES']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['AFTER']],
148 |             samples[r1][label_dict['VAGUE']] + samples[r2][label_dict['IS_INCLUDED']] - samples[r3][label_dict['IS_INCLUDED']] - samples[r3][label_dict['VAGUE']] - samples[r3][label_dict['BEFORE']] - samples[r3][label_dict['AFTER']]
149 |         ]
150 |         '''
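Each expression returned above is later constrained to be <= 1 in define_constraints. The linear form encodes the transitivity rules because, with one-hot binary indicators, `x1 + x2 - x3 <= 1` is violated only when both premises hold and the conclusion does not. A standalone check of that claim (illustration only, not solver code):

```python
# Enumerate all binary assignments of (x1_BEFORE, x2_BEFORE, x3_BEFORE):
# the only one breaking x1 + x2 - x3 <= 1 is (1, 1, 0), i.e. r1=BEFORE
# and r2=BEFORE while r3 is not BEFORE -- so the constraint enforces
# BEFORE o BEFORE => BEFORE.
from itertools import product

violations = [xs for xs in product([0, 1], repeat=3) if xs[0] + xs[1] - xs[2] > 1]
assert violations == [(1, 1, 0)]
```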
151 |     def define_constraints(self, var_table_e, var_table_r):
152 |         # Constraint 1: single label assignment
153 |         for n in range(self.N):
154 |             self.model.addConstr(self.single_label(var_table_e[n]), "c1_%s" % n)
155 |         for m in range(self.M):
156 |             self.model.addConstr(self.single_label(var_table_r[m]), "c1_%s" % (self.N + m))
157 |
158 |         # Constraint 2: a positive relation requires positive (event) arguments
159 |         for r, cr in enumerate(self.cand_rels):
160 |             for c in range(self.Mc - 1):
161 |                 self.model.addConstr(self.rel_left_ent(var_table_e, var_table_r, cr, r, c) >= 0, "c2_%s_%s" % (r, c))
162 |                 self.model.addConstr(self.rel_right_ent(var_table_e, var_table_r, cr, r, c) >= 0, "c3_%s_%s" % (r, c))
163 |             # negative relation constraint for the NONE class (last index)
164 |             self.model.addConstr(self.rel_ent_sum(var_table_e, var_table_r, cr, r, self.Mc - 1) >= 0, "c4_%s" % r)
165 |
166 |
167 |         # Constraint 3: transitivity
168 |         trans_triples = self.transitivity_list()
169 |         t = 0
170 |         for triple in trans_triples:
171 |             for ci in self.transitivity_criteria(var_table_r, triple):
172 |                 self.model.addConstr(ci <= 1, "c5_%s" % t)
173 |                 t += 1
174 |         return
175 |
176 |     def run(self):
177 |         try:
178 |             # Define variables
179 |             var_table_e, var_table_r = self.define_vars()
180 |
181 |             # Set objective
182 |             self.model.setObjective(self.objective(var_table_e, var_table_r, self.prob_ents,
183 |                                                    self.prob_rels), GRB.MAXIMIZE)
184 |
185 |             # Define constraints
186 | 
self.define_constraints(var_table_e, var_table_r) 187 | 188 | # run model 189 | self.model.setParam('OutputFlag', False) 190 | self.model.optimize() 191 | 192 | except GurobiError: 193 | print('Error reported') 194 | 195 | 196 | def predict(self): 197 | ent_count, rel_count = 0, 0 198 | 199 | for i, v in enumerate(self.model.getVars()): 200 | 201 | # rel_ent indicator 202 | is_ent = True if v.varName.split('_')[0] == 'e' else False 203 | # sample idx 204 | s_idx = int(v.varName.split('_')[1]) 205 | # sample class index 206 | c_idx = int(v.varName.split('_')[2]) 207 | 208 | if is_ent: 209 | if v.x == 1.0 and self.pred_ent_labels[s_idx] != c_idx: 210 | #print(v.varName, self.pred_ent_labels[s_idx]) 211 | self.pred_ent_labels[s_idx] = c_idx 212 | ent_count += 1 213 | else: 214 | if v.x == 1.0 and self.pred_rel_labels[s_idx] != c_idx: 215 | #print(v.varName, self.pred_rel_labels[s_idx]) 216 | self.pred_rel_labels[s_idx] = c_idx 217 | rel_count += 1 218 | 219 | print('# of global entity correction: %s' % ent_count) 220 | print('# of global relation correction: %s' % rel_count) 221 | print('Objective Function Value:', self.model.objVal) 222 | 223 | return 224 | -------------------------------------------------------------------------------- /code/joint_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pickle 3 | import sys 4 | import argparse 5 | from collections import defaultdict, Counter, OrderedDict 6 | from itertools import combinations 7 | from typing import Iterator, List, Mapping, Union, Optional, Set 8 | import logging as log 9 | import abc 10 | from dataclasses import dataclass 11 | from datetime import datetime 12 | import numpy as np 13 | import random 14 | import torch 15 | import torch.autograd as autograd 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | from torch.autograd import Variable 20 | from torch.nn import Parameter 21 | import math 22 | import time 23 | import copy 24 | from torch.utils import data 25 | from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack 26 | from featurize_data import matres_label_map, tbd_label_map 27 | from functools import partial 28 | from sklearn.model_selection import KFold, ParameterGrid, train_test_split 29 | from utils import ClassificationReport 30 | 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = False 33 | 34 | torch.manual_seed(123) 35 | 36 | def pad_collate(batch): 37 | """Puts data, and lengths into a packed_padded_sequence then returns 38 | the packed_padded_sequence and the labels. Set use_lengths to True 39 | to use this collate function. 40 | Args: 41 | batch: (list of tuples) [(doc_id, sample_id, pair, label, sent, pos, fts, rev, lidx_start_s, lidx_end_s, ridx_start_s, ridx_end_s, pred_ind)]. 42 | 43 | Output: 44 | packed_batch: (PackedSequence for sent and pos), see torch.nn.utils.rnn.pack_padded_sequence 45 | labels: (Tensor) 46 | 47 | other arguments remain the same. 
48 | """ 49 | if len(batch) >= 1: 50 | 51 | bs = list(zip(*[ex for ex in sorted(batch, key=lambda x: x[2].shape[0], reverse=True)])) 52 | 53 | max_len, n_fts = bs[2][0].shape 54 | lengths = [x.shape[0] for x in bs[2]] 55 | 56 | ### gather sents: idx = 2 in batch_sorted 57 | sents = [torch.cat((torch.FloatTensor(s), torch.zeros(max_len - s.shape[0], n_fts)), 0) 58 | if s.shape[0] != max_len else torch.FloatTensor(s) for s in bs[2]] 59 | sents = torch.stack(sents, 0) 60 | 61 | # gather entity labels: idx = 3 in batch_sorted 62 | # we need a unique doc_span key for aggregation later 63 | all_key_ent = [list(zip(*key_ent)) for key_ent in bs[3]] 64 | 65 | keys = [[(bs[0][i], k) for k in v[0]] for i, v in enumerate(all_key_ent)] 66 | 67 | ents = [v[1] for v in all_key_ent] 68 | ents = [torch.cat((torch.LongTensor(s).unsqueeze(1), torch.zeros(max_len - len(s), 1, dtype=torch.long)), 0) 69 | if len(s) != max_len else torch.LongTensor(s).unsqueeze(1) for s in ents] 70 | ents = torch.stack(ents, 0).squeeze(2) 71 | 72 | # gather pos tags: idx = 6 in batch_sorted; treat pad as 0 -- this needs to be fixed !!! 73 | #poss = [torch.cat((s.unsqueeze(1), torch.zeros(max_len - s.size(0), 1, dtype=torch.long)), 0) 74 | # if s.size(0) != max_len else s.unsqueeze(1) for s in bs[4]] 75 | #poss = torch.stack(poss, 0) 76 | 77 | return bs[0], bs[1], sents, keys, ents, bs[4], bs[5], lengths 78 | 79 | 80 | class EventDataset(data.Dataset): 81 | 'Characterizes a dataset for PyTorch' 82 | def __init__(self, data_dir, data_split): 83 | 'Initialization' 84 | # load data 85 | with open(data_dir + data_split + '.pickle', 'rb') as handle: 86 | self.data = pickle.load(handle) 87 | self.data = list(self.data.values()) 88 | handle.close() 89 | 90 | def __len__(self): 91 | 'Denotes the total number of samples' 92 | return len(self.data) 93 | 94 | def __getitem__(self, idx): 95 | 'Generates one sample of data' 96 | 97 | sample = self.data[idx] 98 | doc_id = sample['doc_id'] 99 | context_id = sample['context_id'] 100 | context = sample['context'] 101 | rels = sample['rels'] 102 | 103 | return doc_id, context_id, context[0], context[1], context[2], rels 104 | 105 | class BertClassifier(nn.Module): 106 | 'Neural Network Architecture' 107 | def __init__(self, args): 108 | 109 | super(BertClassifier, self).__init__() 110 | 111 | self.hid_size = args.hid 112 | self.batch_size = args.batch 113 | self.num_layers = args.num_layers 114 | self.num_classes = len(args.label_to_id) 115 | self.num_ent_classes = 2 116 | 117 | self.dropout = nn.Dropout(p=args.dropout) 118 | # lstm is shared for both relation and entity 119 | self.lstm = nn.LSTM(768, self.hid_size, self.num_layers, bias = False, bidirectional=True) 120 | 121 | # MLP classifier for relation 122 | self.linear1 = nn.Linear(self.hid_size*4+args.n_fts, self.hid_size) 123 | self.linear2 = nn.Linear(self.hid_size, self.num_classes) 124 | 125 | # MLP classifier for entity 126 | self.linear1_ent = nn.Linear(self.hid_size*2, int(self.hid_size / 2)) 127 | self.linear2_ent = nn.Linear(int(self.hid_size / 2), self.num_ent_classes) 128 | 129 | self.act = nn.Tanh() 130 | self.softmax = nn.Softmax(dim=1) 131 | self.softmax_ent = nn.Softmax(dim=2) 132 | 133 | def forward(self, sents, lengths, fts = [], rel_idxs=[], lidx_start=[], lidx_end=[], ridx_start=[], 134 | ridx_end=[], pred_ind=True, flip=False, causal=False, token_type_ids=None, task='relation'): 135 | 136 | batch_size = sents.size(0) 137 | # dropout 138 | out = self.dropout(sents) 139 | # pack and lstm layer 140 | out, _ = 
self.lstm(pack(out, lengths, batch_first=True)) 141 | # unpack 142 | out, _ = unpack(out, batch_first = True) 143 | 144 | ### entity prediction - predict each input token 145 | if task == 'entity': 146 | out_ent = self.linear1_ent(self.dropout(out)) 147 | out_ent = self.act(out_ent) 148 | out_ent = self.linear2_ent(out_ent) 149 | prob_ent = self.softmax_ent(out_ent) 150 | return out_ent, prob_ent 151 | 152 | ### relaiton prediction - flatten hidden vars into a long vector 153 | if task == 'relation': 154 | 155 | ltar_f = torch.cat([out[b, lidx_start[b][r], :self.hid_size].unsqueeze(0) for b,r in rel_idxs], dim=0) 156 | ltar_b = torch.cat([out[b, lidx_end[b][r], self.hid_size:].unsqueeze(0) for b,r in rel_idxs], dim=0) 157 | rtar_f = torch.cat([out[b, ridx_start[b][r], :self.hid_size].unsqueeze(0) for b,r in rel_idxs], dim=0) 158 | rtar_b = torch.cat([out[b, ridx_end[b][r], self.hid_size:].unsqueeze(0) for b,r in rel_idxs], dim=0) 159 | 160 | out = self.dropout(torch.cat((ltar_f, ltar_b, rtar_f, rtar_b), dim=1)) 161 | out = torch.cat((out, fts), dim=1) 162 | 163 | # linear prediction 164 | out = self.linear1(out) 165 | out = self.act(out) 166 | out = self.dropout(out) 167 | out = self.linear2(out) 168 | prob = self.softmax(out) 169 | return out, prob 170 | 171 | @dataclass() 172 | class NNClassifier(nn.Module): 173 | def __init__(self): 174 | super(NNClassifier, self).__init__() 175 | #self.label_probs = [] 176 | 177 | def predict(self, model, data, args, test=False, gold=True, model_r=None): 178 | 179 | model.eval() 180 | 181 | criterion = nn.CrossEntropyLoss() 182 | 183 | count = 1 184 | labels, probs, losses_t, losses_e = [], [], [], [] 185 | pred_inds, docs, pairs = [], [], [] 186 | 187 | # stoare non-predicted rels in list 188 | nopred_rels = [] 189 | 190 | ent_pred_map, ent_label_map = {}, {} 191 | rd_pred_map, rd_label_map = {}, {} 192 | 193 | for doc_id, context_id, sents, ent_keys, ents, poss, rels, lengths in data: 194 | 195 | if args.cuda: 196 | sents = sents.cuda() 197 | ents = ents.cuda() 198 | 199 | ## predict entity first 200 | out_e, prob_e = model(sents, lengths, task='entity') 201 | 202 | labels_r, fts, rel_idxs, doc, pair, lidx_start, lidx_end, ridx_start, ridx_end, nopred_rel = self.construct_relations(prob_e, lengths, rels, list(doc_id), poss, gold=gold) 203 | 204 | nopred_rels.extend(nopred_rel) 205 | 206 | ### predict relations 207 | if rel_idxs: # predicted relation could be empty --> skip 208 | docs.extend(doc) 209 | pairs.extend(pair) 210 | 211 | if args.cuda: 212 | labels_r = labels_r.cuda() 213 | fts = fts.cuda() 214 | 215 | if model_r: 216 | model_r.eval() 217 | out_r, prob_r = model_r(sents, lengths, fts=fts, rel_idxs=rel_idxs, lidx_start=lidx_start, 218 | lidx_end=lidx_end, ridx_start=ridx_start, ridx_end=ridx_end) 219 | else: 220 | out_r, prob_r = model(sents, lengths, fts=fts, rel_idxs=rel_idxs, lidx_start=lidx_start, 221 | lidx_end=lidx_end, ridx_start=ridx_start, ridx_end=ridx_end) 222 | loss_r = criterion(out_r, labels_r) 223 | predicted = (prob_r.data.max(1)[1]).long().view(-1) 224 | 225 | if args.cuda: 226 | loss_r = loss_r.cpu() 227 | prob_r = prob_r.cpu() 228 | labels_r = labels_r.cpu() 229 | 230 | losses_t.append(loss_r.data.numpy()) 231 | probs.append(prob_r) 232 | labels.append(labels_r) 233 | 234 | # retrieve and flatten entity prediction for loss calculation 235 | ent_pred, ent_label, ent_prob, ent_key, ent_pos = [], [], [], [], [] 236 | for i,l in enumerate(lengths): 237 | # flatten prediction 238 | ent_pred.append(out_e[i, :l]) 239 | # flatten 
entity prob
240 |                 ent_prob.append(prob_e[i, :l])
241 |                 # flatten entity label
242 |                 ent_label.append(ents[i, :l])
243 |                 # flatten entity key - a list of original keys (extend)
244 |                 assert len(ent_keys[i]) == l
245 |                 ent_key.extend(ent_keys[i])
246 |                 # flatten pos tags
247 |                 ent_pos.extend([p for p in poss[i]])
248 |
249 |             ent_pred = torch.cat(ent_pred, 0)
250 |             ent_label = torch.cat(ent_label, 0)
251 |             ent_probs = torch.cat(ent_prob, 0)
252 |
253 |             assert ent_pred.size(0) == ent_label.size(0)
254 |             assert ent_pred.size(0) == len(ent_key)
255 |
256 |             loss_e = criterion(ent_pred, ent_label)
257 |             losses_e.append(loss_e.cpu().data.numpy())
258 |
259 |             ent_label = ent_label.tolist()
260 |
261 |             for i, v in enumerate(ent_key):
262 |                 label_e = ent_label[i]
263 |                 prob_e = ent_probs[i]
264 |
265 |                 # exclude sent_start and sent_sep
266 |                 if v in ["[SEP]", "[CLS]"]:
267 |                     assert ent_pos[i] in ["[SEP]", "[CLS]"]
268 |
269 |                 if v not in ent_pred_map:
270 |                     # only store the probability of being 1 (is an event)
271 |                     ent_pred_map[v] = [prob_e.tolist()[1]]
272 |                     ent_label_map[v] = (label_e, ent_pos[i])
273 |                 else:
274 |                     # if the key is stored already, append another prediction
275 |                     ent_pred_map[v].append(prob_e.tolist()[1])
276 |                     # and ensure the label is the same
277 |                     assert ent_label_map[v][0] == label_e
278 |                     assert ent_label_map[v][1] == ent_pos[i]
279 |
280 |             count += 1
281 |             if count % 10 == 0:
282 |                 print("finished evaluating %s samples" % (count * args.batch))
283 |
284 |         ## collect relation prediction results
285 |         probs = torch.cat(probs, dim=0)
286 |         labels = torch.cat(labels, dim=0)
287 |
288 |         assert labels.size(0) == probs.size(0)
289 |
290 |         # calculate the entity F1 score here
291 |         # aggregate ent_pred_map with [mean > 0.5 --> 1]
292 |
293 |         ent_pred_map_agg = {k: 1 if np.mean(v) > 0.5 else 0 for k, v in ent_pred_map.items()}
294 |
295 |         n_correct = 0
296 |         n_pred = 0
297 |
298 |         pos_keys = OrderedDict([(k, v) for k, v in ent_label_map.items() if v[0] == 1])
299 |         n_true = len(pos_keys)
300 |
301 |         for k, v in ent_label_map.items():
302 |             if ent_pred_map_agg[k] == 1:
303 |                 n_pred += 1
304 |             if ent_pred_map_agg[k] == 1 and ent_label_map[k][0] == 1:
305 |                 n_correct += 1
306 |
307 |         print(n_pred, n_true, n_correct)
308 |
309 |         def safe_division(numr, denr, on_err=0.0):
310 |             return on_err if denr == 0.0 else float(numr) / float(denr)
311 |
312 |         precision = safe_division(n_correct, n_pred)
313 |         recall = safe_division(n_correct, n_true)
314 |         f1_score = safe_division(2.0 * precision * recall, precision + recall)
315 |
316 |         print("Evaluation temporal relation loss: %.4f" % np.mean(losses_t))
317 |         print("Evaluation temporal entity loss: %.4f; F1: %.4f" % (np.mean(losses_e), f1_score))
318 |
319 |         if test:
320 |             return probs.data, np.mean(losses_t), labels, docs, pairs, f1_score, nopred_rels
321 |         else:
322 |             return probs.data, np.mean(losses_t), labels, docs, pairs, n_pred, n_true, n_correct, nopred_rels
323 |
324 |     def construct_relations(self, ent_probs, lengths, rels, doc, poss, gold=True, train=True):
325 |         # many relation properties, such as rev and pred_ind, are not used for now
326 |
327 |         nopred_rels = []
328 |
329 |         ## Case 1: only use gold relations
330 |         if gold:
331 |             pred_rels = rels
332 |
333 |         ## Case 2: use candidate relations predicted by the entity model
334 |         else:
335 |             def _is_gold(pred_span, gold_rel_span):
336 |                 return gold_rel_span[0] <= pred_span <= gold_rel_span[1]
337 |
338 |             batch_size = ent_probs.size(0)
339 |             ent_probs = ent_probs.cpu()
340 |
341 |             # select events with prob > 0.5, but drop predictions beyond the context length
342 |             ent_locs = [[x for x in (ent_probs[b, :, 1] > 0.5).nonzero().view(-1).tolist()
343 |                          if x < lengths[b]] for b in range(batch_size)]
344 |
345 |             # all possible relation candidates based on predicted events
346 |             rel_locs = [list(combinations(el, 2)) for el in ent_locs]
347 |
348 |             pred_rels = []
349 |             totl = 0
350 |             # use the smallest positive sample id as the start of the negative ids
351 |             # this may not be perfect, but the exact negative ids do not matter
352 |             neg_counter = min([int(x[0][1:]) for rel in rels for x in rel])
353 |
354 |             for i, rl in enumerate(rel_locs):
355 |                 temp_rels, temp_ids = [], []
356 |                 for r in rl:
357 |                     sent_segs = len([x for x in poss[i] if x == '[SEP]'])
358 |                     in_seg = [x for x in poss[i][r[0]:r[1]] if x == '[SEP]']
359 |                     ### when the context holds two segments, skip same-sentence pairs: they belong to a different unique (single-sentence) context
360 |                     if (sent_segs > 1) and (len(in_seg) == 0):
361 |                         continue
362 |                     else:
363 |                         totl += 1
364 |                         gold_match = [x for x in rels[i] if _is_gold(r[0], x[5][:2]) and _is_gold(r[1], x[5][2:])]
365 |                         # multiple tokens could indicate the same event;
366 |                         # simply pick the one that occurs first
367 |                         if len(gold_match) > 0 and gold_match[0][0] not in temp_ids:
368 |                             temp_rels.append(gold_match[0])
369 |                             temp_ids.append(gold_match[0][0])
370 |                         else:
371 |                             ## construct a negative relation pair -- 'NONE'
372 |                             neg_id = 'N%s' % neg_counter
373 |                             left_match = [x for x in rels[i] if _is_gold(r[0], x[5][:2])]
374 |                             right_match = [x for x in rels[i] if _is_gold(r[1], x[5][2:])]
375 |                             # provide a random but unique id for a predicted event not matched in gold
376 |                             left_id = left_match[0][1][0] if len(left_match) > 0 else ('e%s' % (neg_counter + 10000))
377 |                             right_id = right_match[0][1][1] if len(right_match) > 0 else ('e%s' % (neg_counter + 20000))
378 |                             a_rel = (neg_id, (left_id, right_id), self._label_to_id['NONE'],
379 |                                      [float(r[1] - r[0])], False, (r[0], r[0], r[1], r[1]), True)
380 |                             temp_rels.append(a_rel)
381 |                             neg_counter += 1
382 |                 nopred_rels.extend([x[2] for x in rels[i] if x[0] not in [tr[0] for tr in temp_rels]])
383 |                 pred_rels.append(temp_rels)
384 |
385 |         # relations are (flattened) lists of features
386 |         # rel_idxs indicates (batch_id, rel_in_batch_id)
387 |         docs, pairs = [], []
388 |         rel_idxs, lidx_start, lidx_end, ridx_start, ridx_end = [], [], [], [], []
389 |         for i, rel in enumerate(pred_rels):
390 |             rel_idxs.extend([(i, ii) for ii, _ in enumerate(rel)])
391 |             lidx_start.append([x[5][0] for x in rel])
392 |             lidx_end.append([x[5][1] for x in rel])
393 |             ridx_start.append([x[5][2] for x in rel])
394 |             ridx_end.append([x[5][3] for x in rel])
395 |             pairs.extend([x[1] for x in rel])
396 |             docs.extend([doc[i] for _ in rel])
397 |         assert len(docs) == len(pairs)
398 |
399 |         rels = [x for rel in pred_rels for x in rel]
400 |         if rels == []:
401 |             labels = torch.FloatTensor([])
402 |             fts = torch.FloatTensor([])
403 |         else:
404 |             labels = torch.LongTensor([x[2] for x in rels])
405 |             fts = torch.cat([torch.FloatTensor(x[3]) for x in rels]).unsqueeze(1)
406 |
407 |         return labels, fts, rel_idxs, docs, pairs, lidx_start, lidx_end, ridx_start, ridx_end, nopred_rels
408 |
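The candidate construction above hinges on `_is_gold`, which tests whether a predicted event token position falls inside a gold event span; matched pairs inherit the gold label, while unmatched ones become 'NONE' pairs. A standalone illustration with made-up positions:

```python
# A predicted event token at position 7 falls inside a gold event span
# (6, 8), so a candidate pair built on it can reuse the gold relation
# label; position 5 falls outside and would yield a NONE pair instead.
def _is_gold(pred_span, gold_rel_span):
    return gold_rel_span[0] <= pred_span <= gold_rel_span[1]

assert _is_gold(7, (6, 8))
assert not _is_gold(5, (6, 8))
```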
409 |     def _train(self, train_data, eval_data, pos_emb, args):
410 |
411 |         model = BertClassifier(args)
412 |
413 |         if args.cuda:
414 |             assert torch.cuda.is_available()
415 |             print("using cuda device: %s" % torch.cuda.current_device())
416 |             model.cuda()
417 |
418 |
419 |         optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
420 |         criterion_e = 
nn.CrossEntropyLoss() 421 | 422 | if args.data_type in ['tbd']: 423 | weights = torch.FloatTensor([1.0, 1.0, 1.0, args.uw, args.uw, args.uw, 1.0]) 424 | 425 | else: 426 | weights = torch.FloatTensor([1.0, 1.0, 1.0, args.uw, 1.0]) 427 | 428 | if args.cuda: 429 | weights = weights.cuda() 430 | 431 | criterion_r = nn.CrossEntropyLoss(weight=weights) 432 | losses = [] 433 | 434 | sents, poss, ftss, labels = [], [], [], [] 435 | if args.load_model == True: 436 | checkpoint = torch.load(args.ilp_dir + args.entity_model_file) 437 | model.load_state_dict(checkpoint['state_dict']) 438 | epoch = checkpoint['epoch'] 439 | best_eval_f1 = checkpoint['f1'] 440 | print("Local best eval f1 is: %s" % best_eval_f1) 441 | 442 | best_eval_f1 = 0.0 443 | best_epoch = 0 444 | 445 | for epoch in range(args.epochs): 446 | print("Training Epoch #%s..." % epoch) 447 | model.train() 448 | count = 1 449 | 450 | loss_hist_t, loss_hist_e = [], [] 451 | 452 | start_time = time.time() 453 | 454 | gold = False if epoch > args.pipe_epoch else True 455 | for doc_id, context_id, sents, keys, ents, poss, rels, lengths in train_data: 456 | 457 | if args.cuda: 458 | sents = sents.cuda() 459 | ents = ents.cuda() 460 | 461 | model.zero_grad() 462 | 463 | ## predict entity first 464 | out_e, prob_e = model(sents, lengths, task='entity') 465 | 466 | labels_r, fts, rel_idxs, _, _, lidx_start, lidx_end, ridx_start, ridx_end, _ = self.construct_relations(prob_e, lengths, rels, list(doc_id), poss, gold=gold) 467 | 468 | if args.cuda: 469 | labels_r = labels_r.cuda() 470 | fts = fts.cuda() 471 | 472 | # retrieve and flatten entity prediction for loss calculation 473 | ent_pred, ent_label = [], [] 474 | 475 | for i,l in enumerate(lengths): 476 | # flatten prediction 477 | ent_pred.append(out_e[i, :l]) 478 | # flatten entity label 479 | ent_label.append(ents[i, :l]) 480 | 481 | ent_pred = torch.cat(ent_pred, 0) 482 | ent_label = torch.cat(ent_label, 0) 483 | 484 | assert ent_pred.size(0) == ent_label.size(0) 485 | 486 | loss_e = criterion_e(ent_pred, ent_label) 487 | 488 | ## predict relations 489 | loss_r = 0 490 | if rel_idxs: 491 | out_r, prob_r = model(sents, lengths, fts=fts, rel_idxs=rel_idxs, lidx_start=lidx_start, 492 | lidx_end=lidx_end, ridx_start=ridx_start, ridx_end=ridx_end) 493 | loss_r = criterion_r(out_r, labels_r) 494 | 495 | loss = args.relation_weight * loss_r + args.entity_weight * loss_e 496 | loss.backward() 497 | 498 | optimizer.step() 499 | 500 | if args.cuda: 501 | if loss_r != 0: 502 | loss_hist_t.append(loss_r.data.cpu().numpy()) 503 | loss_hist_e.append(loss_e.data.cpu().numpy()) 504 | else: 505 | if loss_r != 0: 506 | loss_hist_t.append(loss_r.data.numpy()) 507 | loss_hist_e.append(loss_e.data.numpy()) 508 | 509 | if count % 100 == 0: 510 | print("trained %s samples" % (count * args.batch)) 511 | print("Temporal loss is %.4f" % np.mean(loss_hist_t)) 512 | print("Entity loss is %.4f" % np.mean(loss_hist_e)) 513 | print("%.4f seconds elapsed" % (time.time() - start_time)) 514 | count += 1 515 | 516 | # Evaluate at the end of each epoch 517 | print("*"*50) 518 | if len(eval_data) > 0: 519 | 520 | # need to have a warm-start otherwise there could be no event_pred 521 | # may need to manually pick poch < #, but 0 generally works when ew is large 522 | #eval_gold = True if epoch == 0 else args.eval_gold 523 | eval_gold = gold 524 | eval_preds, eval_loss, eval_labels, _, _, ent_pred, ent_true, ent_corr, nopred_rels = self.predict(model, eval_data, args, gold=eval_gold) 525 | pred_labels = 
eval_preds.max(1)[1].long().view(-1) 526 | assert eval_labels.size() == pred_labels.size() 527 | 528 | eval_correct = (pred_labels == eval_labels).sum() 529 | eval_acc = float(eval_correct) / float(len(eval_labels)) 530 | 531 | pred_labels = list(pred_labels.numpy()) 532 | eval_labels = list(eval_labels.numpy()) 533 | 534 | # Append non-predicted labels as label: Gold; Pred: None 535 | if not eval_gold: 536 | print(len(nopred_rels)) 537 | pred_labels.extend([self._label_to_id['NONE'] for _ in nopred_rels]) 538 | eval_labels.extend(nopred_rels) 539 | 540 | if args.data_type in ['red', 'caters']: 541 | pred_labels = [pred_labels[k] if v == 1 else self._label_to_id['NONE'] for k,v in enumerate(pred_inds)] 542 | 543 | # select model only based on entity + relation F1 score 544 | eval_f1 = self.weighted_f1(pred_labels, eval_labels, ent_corr, ent_pred, ent_true, 545 | args.relation_weight, args.entity_weight) 546 | 547 | # args.pipe_epoch <= args.epochs if pipeline (joint) training is used 548 | if eval_f1 > best_eval_f1 and (epoch > args.pipe_epoch or args.pipe_epoch >= 1000): 549 | best_eval_f1 = eval_f1 550 | self.model = copy.deepcopy(model) 551 | best_epoch = epoch 552 | 553 | print("Evaluation loss: %.4f; Evaluation F1: %.4f" % (eval_loss, eval_f1)) 554 | print("*"*50) 555 | 556 | print("Final Evaluation F1: %.4f at Epoch %s" % (best_eval_f1, best_epoch)) 557 | print("*"*50) 558 | 559 | if len(eval_data) == 0 or args.load_model: 560 | self.model = copy.deepcopy(model) 561 | best_epoch = epoch 562 | 563 | if args.save_model == True: 564 | torch.save({'epoch': epoch, 565 | 'args': args, 566 | 'state_dict': self.model.cpu().state_dict(), 567 | 'f1': best_eval_f1, 568 | 'optimizer' : optimizer.state_dict() 569 | }, "%s%s.pth.tar" % (args.ilp_dir, args.save_stamp)) 570 | 571 | return best_eval_f1, best_epoch 572 | 573 | def train_epoch(self, train_data, dev_data, args, test_data = None): 574 | 575 | if args.data_type == "matres": 576 | label_map = matres_label_map 577 | if args.data_type == "tbd": 578 | label_map = tbd_label_map 579 | assert len(label_map) > 0 580 | 581 | all_labels = list(OrderedDict.fromkeys(label_map.values())) 582 | ## append negative pair label 583 | all_labels.append('NONE') 584 | 585 | self._label_to_id = OrderedDict([(all_labels[l],l) for l in range(len(all_labels))]) 586 | self._id_to_label = OrderedDict([(l,all_labels[l]) for l in range(len(all_labels))]) 587 | 588 | print(self._label_to_id) 589 | print(self._id_to_label) 590 | 591 | args.label_to_id = self._label_to_id 592 | 593 | ### pos embdding is not used for now, but can be added later 594 | pos_emb= np.zeros((len(args.pos2idx) + 1, len(args.pos2idx) + 1)) 595 | for i in range(pos_emb.shape[0]): 596 | pos_emb[i, i] = 1.0 597 | 598 | best_f1, best_epoch = self._train(train_data, dev_data, pos_emb, args) 599 | print("Final Dev F1: %.4f" % best_f1) 600 | return best_f1, best_epoch 601 | 602 | def weighted_f1(self, pred_labels, true_labels, ent_corr, ent_pred, ent_true, rw=0.0, ew=0.0): 603 | def safe_division(numr, denr, on_err=0.0): 604 | return on_err if denr == 0.0 else numr / denr 605 | 606 | assert len(pred_labels) == len(true_labels) 607 | 608 | weighted_f1_scores = {} 609 | if 'NONE' in self._label_to_id.keys(): 610 | num_tests = len([x for x in true_labels if x != self._label_to_id['NONE']]) 611 | else: 612 | num_tests = len([x for x in true_labels]) 613 | 614 | print("Total positive samples to eval: %s" % num_tests) 615 | total_true = Counter(true_labels) 616 | total_pred = Counter(pred_labels) 617 | 618 | 
labels = list(self._id_to_label.keys()) 619 | 620 | n_correct = 0 621 | n_true = 0 622 | n_pred = 0 623 | 624 | if rw > 0: 625 | # f1 score is used for tcr and matres and hence exclude vague 626 | exclude_labels = ['VAGUE', 'NONE'] if len(self._label_to_id) == 5 else ['NONE'] 627 | 628 | for label in labels: 629 | if self._id_to_label[label] not in exclude_labels: 630 | 631 | true_count = total_true.get(label, 0) 632 | pred_count = total_pred.get(label, 0) 633 | 634 | n_true += true_count 635 | n_pred += pred_count 636 | 637 | correct_count = len([l for l in range(len(pred_labels)) 638 | if pred_labels[l] == true_labels[l] and pred_labels[l] == label]) 639 | n_correct += correct_count 640 | if ew > 0: 641 | # add entity prediction results before calculating precision, recall and f1 642 | n_correct += ent_corr 643 | n_pred += ent_pred 644 | n_true += ent_true 645 | 646 | precision = safe_division(n_correct, n_pred) 647 | recall = safe_division(n_correct, n_true) 648 | f1_score = safe_division(2.0 * precision * recall, precision + recall) 649 | print("Overall Precision: %.4f\tRecall: %.4f\tF1: %.4f" % (precision, recall, f1_score)) 650 | 651 | return(f1_score) 652 | 653 | class EventEvaluator: 654 | def __init__(self, model): 655 | self.model = model 656 | 657 | def evaluate(self, test_data, args): 658 | # load test data first since it needs to be executed twice in this function 659 | print("start testing...") 660 | if args.model == "singletask/pipeline": 661 | model_r = BertClassifier(args) 662 | if args.cuda: 663 | print("using cuda device: %s" % torch.cuda.current_device()) 664 | assert torch.cuda.is_available() 665 | model_r.cuda() 666 | checkpoint = torch.load(args.ilp_dir + args.relation_model_file) 667 | model_r.load_state_dict(checkpoint['state_dict']) 668 | preds, loss, true_labels, docs, pairs, ent_f1, nopred_rels = self.model.predict(self.model.model, 669 | test_data, 670 | args, 671 | test = True, 672 | gold = False, 673 | model_r = model_r) 674 | else: 675 | preds, loss, true_labels, docs, pairs, ent_f1, nopred_rels \ 676 | = self.model.predict(self.model.model, test_data, args, test = True, gold = args.eval_gold) 677 | 678 | preds = (preds.max(1)[1]).long().view(-1) 679 | 680 | pred_labels = preds.numpy().tolist() 681 | true_labels = true_labels.tolist() 682 | if not args.eval_gold: 683 | print(len(nopred_rels)) 684 | pred_labels.extend([self.model._label_to_id['NONE'] for _ in nopred_rels]) 685 | true_labels.extend(nopred_rels) 686 | 687 | rel_f1 = self.model.weighted_f1(pred_labels, true_labels, 0, 0, 0, rw=1.0) 688 | 689 | pred_labels = [self.model._id_to_label[x] for x in pred_labels] 690 | true_labels = [self.model._id_to_label[x] for x in true_labels] 691 | 692 | print(len(pred_labels), len(true_labels), len(pairs), len(docs)) 693 | out = ClassificationReport(args.model, true_labels, pred_labels) 694 | print(out) 695 | print("F1 Excluding Vague: %.4f" % rel_f1) 696 | return rel_f1, ent_f1 697 | 698 | def main(args): 699 | 700 | data_dir = args.data_dir 701 | opt_args = {} 702 | 703 | params = {'batch_size': args.batch, 704 | 'shuffle': False, 705 | 'collate_fn': pad_collate} 706 | 707 | type_dir = "/all_context/" 708 | test_data = EventDataset(args.data_dir + type_dir, "test") 709 | test_generator = data.DataLoader(test_data, **params) 710 | 711 | train_data = EventDataset(args.data_dir + type_dir, "train") 712 | train_generator = data.DataLoader(train_data, **params) 713 | 714 | dev_data = EventDataset(args.data_dir + type_dir, "dev") 715 | dev_generator = 
data.DataLoader(dev_data, **params) 716 | 717 | model = NNClassifier() 718 | print(f"======={args.model}=====\n") 719 | best_f1, best_epoch = model.train_epoch(train_generator, dev_generator, args) 720 | evaluator = EventEvaluator(model) 721 | rel_f1, ent_f1 = evaluator.evaluate(test_generator, args) 722 | print(rel_f1, ent_f1) 723 | 724 | if __name__ == '__main__': 725 | p = argparse.ArgumentParser() 726 | # arguments for data processing 727 | p.add_argument('-data_dir', type=str, default = '../data') 728 | p.add_argument('-other_dir', type=str, default = '../other') 729 | # select model 730 | p.add_argument('-model', type=str, default='multitask/pipeline')#, 'multitask/gold', 'multitask/pipeline' 731 | # arguments for RNN model 732 | p.add_argument('-emb', type=int, default=100) 733 | p.add_argument('-hid', type=int, default=100) 734 | p.add_argument('-num_layers', type=int, default=1) 735 | p.add_argument('-batch', type=int, default=2) 736 | p.add_argument('-data_type', type=str, default="matres") 737 | p.add_argument('-epochs', type=int, default=5) 738 | p.add_argument('-pipe_epoch', type=int, default=1000) # 1000: no pipeline training; otherwise <= epochs 739 | p.add_argument('-seed', type=int, default=123) 740 | p.add_argument('-lr', type=float, default=0.0005) 741 | p.add_argument('-num_classes', type=int, default=2) # get updated in main() 742 | p.add_argument('-dropout', type=float, default=0.1) 743 | p.add_argument('-ngbrs', type=int, default = 15) 744 | p.add_argument('-pos2idx', type=dict, default = {}) 745 | p.add_argument('-w2i', type=OrderedDict) 746 | p.add_argument('-glove', type=OrderedDict) 747 | p.add_argument('-cuda', action='store_true') 748 | p.add_argument('-refit_all', type=bool, default=False) 749 | p.add_argument('-uw', type=float, default=1.0) 750 | p.add_argument('-params', type=dict, default={}) 751 | p.add_argument('-n_splits', type=int, default=5) 752 | p.add_argument('-pred_win', type=int, default=200) 753 | p.add_argument('-n_fts', type=int, default=1) 754 | p.add_argument('-relation_weight', type=float, default=1.0) 755 | p.add_argument('-entity_weight', type=float, default=15.0) 756 | p.add_argument('-save_model', type=bool, default=False) 757 | p.add_argument('-save_stamp', type=str, default="matres_entity_best") 758 | p.add_argument('-entity_model_file', type=str, default="") 759 | p.add_argument('-relation_model_file', type=str, default="") 760 | p.add_argument('-load_model', type=bool, default=False) 761 | p.add_argument('-bert_config', type=dict, default={}) 762 | p.add_argument('-fine_tune', type=bool, default=False) 763 | p.add_argument('-eval_gold',type=bool, default=True) 764 | args = p.parse_args() 765 | args.save_stamp = "%s_hid%s_dropout%s_ew%s" % (args.save_stamp, args.hid, args.dropout, args.entity_weight) 766 | #args.eval_gold = True if args.pipe_epoch >= 1000 else False 767 | 768 | # if training with pipeline, ensure train / eval pipe epoch are the same 769 | #if args.pipe_epoch < 1000: 770 | # assert args.pipe_epoch == args.eval_pipe_epoch 771 | 772 | args.eval_list = [] 773 | args.data_dir += args.data_type 774 | 775 | # create pos_tag and vocabulary dictionaries 776 | # make sure raw data files are stored in the same directory as train/dev/test data 777 | tags = open(args.other_dir + "/pos_tags.txt") 778 | pos2idx = {} 779 | idx = 0 780 | for tag in tags: 781 | tag = tag.strip() 782 | pos2idx[tag] = idx 783 | idx += 1 784 | args.pos2idx = pos2idx 785 | 786 | args.idx2pos = {v+1:k for k,v in pos2idx.items()} 787 | 788 | print(args.hid, 
args.dropout, args.entity_weight, args.relation_weight)
789 | main(args)
790 |
791 |
792 |
--------------------------------------------------------------------------------
/code/joint_model_global.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import pickle
3 | import sys
4 | import argparse
5 | from collections import defaultdict, Counter, OrderedDict
6 | from itertools import combinations
7 | from typing import Iterator, List, Mapping, Union, Optional, Set
8 | import logging as log
9 | import abc
10 | from dataclasses import dataclass
11 | from datetime import datetime
12 | import numpy as np
13 | import random
14 | import torch
15 | import torch.autograd as autograd
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torch.optim as optim
19 | from torch.nn import Parameter
20 | import math
21 | import time
22 | import copy
23 | import json
24 | from torch.autograd import Variable
25 | from torch.utils import data
26 | from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack
27 | from featureFuncs import *
28 | from featurize_data import matres_label_map, tbd_label_map
29 | from functools import partial
30 | from sklearn.model_selection import KFold, ParameterGrid, train_test_split
31 | from joint_model import pad_collate, EventDataset, BertClassifier
32 | from gurobi_inference import Global_Inference
33 | from utils import ClassificationReport
34 |
35 | torch.backends.cudnn.deterministic = True
36 | torch.backends.cudnn.benchmark = False
37 |
38 | torch.manual_seed(123)
39 |
40 |
41 | class NNClassifier(nn.Module):
42 | def __init__(self):
43 | super(NNClassifier, self).__init__()
44 |
45 | def predict(self, model, data, args, test=False, gold=False, model_r=None):
46 | model.eval()
47 |
48 | criterion = nn.CrossEntropyLoss()
49 | count = 1
50 | labels_r, probs_r, losses, losses_e = [], [], [], []
51 | pred_inds = []
52 |
53 | # store non-predicted rels in a list
54 | nopred_rels = []
55 |
56 | ent_pred_map, ent_label_map, ent_idx_map = {}, {}, {}
57 | all_ent_key, all_ent_pos, all_label_e, all_prob_e = [], [], [], []
58 |
59 | all_pairs, all_label_r, all_prob_r = [], [], []
60 | all_lidx_start, all_ridx_start = [], []
61 | l_idx, r_idx, context_start = 0, 0, 0
62 |
63 | for doc_id, context_id, sents, ent_keys, ents, poss, rels, lengths in data:
64 |
65 | if args.cuda:
66 | sents = sents.cuda()
67 | ents = ents.cuda()
68 |
69 | ## predict entities first
70 | out_e, prob_e = model(sents, lengths, task='entity')
71 | ## construct candidate relations
72 | rel_label, fts, rel_idxs, doc_id, pairs, lidx_start, lidx_end, ridx_start, ridx_end, none_rel \
73 | = self.construct_relations(prob_e, lengths, rels, list(doc_id), poss, gold=False,
74 | ent_thresh=args.ent_thresh, rel_thresh=args.rel_thresh, ent_keys=ent_keys, test=test)
75 |
76 | nopred_rels.extend(none_rel)
77 |
78 | if args.cuda:
79 | rel_label = rel_label.cuda()
80 | fts = fts.cuda()
81 |
82 | all_label_r.append(rel_label)
83 | all_pairs.extend([(doc + "_" + x, doc + "_" + y) for doc, (x, y) in zip(doc_id, pairs)])
84 |
85 | ## predict relations
86 | out_r, prob_r = model(sents, lengths, fts=fts, rel_idxs=rel_idxs, lidx_start=lidx_start,
87 | lidx_end=lidx_end, ridx_start=ridx_start, ridx_end=ridx_end)
88 |
89 | loss_e = []
90 | ## global inference for each unique context
91 | b_str = 0
92 | pred_ent_labels, pred_rel_labels = [], []
93 | for b, l in enumerate(lengths):
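# The loop below flattens per-batch predictions into one global table:
# lidx_start / ridx_start are token indices local to each context, so they are
# shifted by context_start (a running token offset) before being collected.
# E.g., with hypothetical context lengths [12, 20], a left event at local
# index 3 in the second context gets global index 12 + 3 = 15, which then
# indexes correctly into the concatenated all_prob_e table.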
all_lidx_start.extend([x + context_start for x in lidx_start[b]]) 95 | all_ridx_start.extend([x + context_start for x in ridx_start[b]]) 96 | context_start += l 97 | 98 | all_prob_e.append(prob_e[b, :l, :]) 99 | all_label_e.append(ents[b, :l]) 100 | 101 | # flatten entity key - a list of original (extend) 102 | assert len(ent_keys[b]) == l 103 | all_ent_key.extend([p for p in ent_keys[b]]) 104 | # flatten pos tags 105 | all_ent_pos.extend([p for p in poss[b]]) 106 | all_prob_r.append(prob_r) 107 | 108 | all_prob_e = torch.cat(all_prob_e) 109 | all_label_e = torch.cat(all_label_e) 110 | assert all_prob_e.size(0) == all_label_e.size(0) 111 | assert all_label_e.size(0) == len(all_ent_key) 112 | 113 | all_label_r = torch.cat(all_label_r) 114 | all_prob_r = torch.cat(all_prob_r) 115 | assert len(all_pairs) == all_prob_r.size(0) 116 | assert len(all_pairs) == all_label_r.size(0) 117 | assert len(all_pairs) == len(all_lidx_start) 118 | assert len(all_pairs) == len(all_ridx_start) 119 | 120 | # global inference for relation with transitivity 121 | best_pred_idx_e, best_pred_idx_r, pred_ent_labels, pred_rel_labels \ 122 | = self.global_prediction(all_prob_e, all_prob_r, all_lidx_start, 123 | all_ridx_start, all_pairs, args.entity_weight, evaluate=True) 124 | # Compute Loss 125 | loss_r = self.loss_func_rel(best_pred_idx_r, all_label_r, all_prob_r, args.margin) 126 | loss_e = self.loss_func_ent(best_pred_idx_e, all_label_e, all_prob_e, args.margin) 127 | loss = args.relation_weight * loss_r + args.entity_weight * loss_e 128 | 129 | all_label_e = all_label_e.tolist() 130 | for i, v in enumerate(all_ent_key): 131 | label_e = all_label_e[i] 132 | 133 | # exclude sent_start and sent_sep 134 | if v in ["[SEP]", "[CLS]"]: 135 | assert all_ent_pos[i] in ["[SEP]", "[CLS]"] 136 | 137 | if v not in ent_pred_map: 138 | # store global assignments 139 | ent_pred_map[v] = [pred_ent_labels[i]] 140 | ent_label_map[v] = (label_e, all_ent_pos[i]) 141 | else: 142 | # if key stored already, append another prediction 143 | ent_pred_map[v].append(pred_ent_labels[i]) 144 | # and ensure label is the same 145 | assert ent_label_map[v][0] == label_e 146 | assert ent_label_map[v][1] == all_ent_pos[i] 147 | 148 | assert all_label_r.size(0) == len(pred_rel_labels) 149 | 150 | # calculate entity F1 score here 151 | # update ent_pred_map with [mean > 0.5 --> 1] 152 | 153 | ent_pred_map_agg = {k:1 if np.mean(v) >= 0.5 else 0 for k,v in ent_pred_map.items()} 154 | #ent_pred_map_agg = {k:max(v) for k,v in ent_pred_map.items()} 155 | n_correct = 0 156 | n_pred = 0 157 | 158 | pos_keys = OrderedDict([(k, v) for k, v in ent_label_map.items() if v[0] == 1]) 159 | n_true = len(pos_keys) 160 | 161 | for k,v in ent_label_map.items(): 162 | if ent_pred_map_agg[k] == 1: 163 | n_pred += 1 164 | if ent_pred_map_agg[k] == 1 and ent_label_map[k][0] == 1: 165 | n_correct += 1 166 | 167 | print(n_pred, n_true, n_correct) 168 | 169 | def safe_division(numr, denr, on_err=0.0): 170 | return on_err if denr == 0.0 else float(numr) / float(denr) 171 | 172 | precision = safe_division(n_correct, n_pred) 173 | recall = safe_division(n_correct, n_true) 174 | f1_score = safe_division(2.0 * precision * recall, precision + recall) 175 | 176 | print("Evaluation temporal relation loss: %.4f" % loss_r.data) 177 | 178 | print("Evaluation temporal entity loss: %.4f; F1: %.4f" % (loss_e.data, f1_score)) 179 | 180 | if test: 181 | return pred_rel_labels, all_label_r, f1_score, nopred_rels 182 | else: 183 | return pred_rel_labels, all_label_r, n_pred, n_true, 
n_correct, nopred_rels
184 |
185 | def construct_relations(self, prob_e, lengths, rels, doc, poss, gold=True, ent_thresh=0.1, rel_thresh=0.5, ent_keys=[], test=False):
186 | # many relation properties such as rev and pred_ind are not used for now
187 |
188 | ## Case 1: only use gold relations
189 | if gold:
190 | pred_rels = rels
191 |
192 | ## Case 2: use candidate relations predicted by the entity model
193 | else:
194 | def _is_gold(pred_span, gold_rel_span):
195 | return gold_rel_span[0] <= pred_span <= gold_rel_span[1]
196 |
197 | # gold relations that end up not predicted should be NONE; kept for a sanity check
198 | nopred_rels = []
199 |
200 | batch_size = len(lengths)
201 |
202 | # eliminate ent_pred > context length
203 | # add filter for events with certain pos tags (based on train set)
204 | #include_pos = [6, 11, 12, 26, 27, 28, 29, 30, 31] # tbd
205 | include_pos = [26, 27, 28, 29, 30, 31] # matres
206 | ent_locs = [[x for x in range(l) if poss[b][x] in include_pos and prob_e[b, x, 1] > ent_thresh]
207 | for b,l in enumerate(lengths)]
208 |
209 | # all possible relation candidates based on pred_ent
210 | rel_locs = [list(combinations(el, 2)) for el in ent_locs]
211 |
212 | pred_rels = []
213 | totl = 0
214 | # use the smallest positive sample id as the start of neg ids;
215 | # this may not be perfect, but the exact neg ids do not matter
216 | neg_counter = min([int(x[0][1:]) for rel in rels for x in rel])
217 |
218 | for i, rl in enumerate(rel_locs):
219 | temp_rels, temp_ids = [], []
220 | for r in rl:
221 | # skip pairs where both events have local prob < rel_thresh
222 | if prob_e[i, r[0], 1] < rel_thresh and prob_e[i, r[1], 1] < rel_thresh:
223 | #print(i, r)
224 | continue
225 | sent_segs = len([x for x in poss[i] if x == '[SEP]'])
226 | in_seg = [x for x in poss[i][r[0] : r[1]] if x == '[SEP]']
227 | ### exclude pairs in the same sentence when the context has two segments, i.e. keep only cross-segment pairs per unique input context
228 | if (sent_segs > 1) and (len(in_seg) == 0):
229 | continue
230 | else:
231 | totl += 1
232 | gold_match = [x for x in rels[i] if _is_gold(r[0], x[5][:2]) and _is_gold(r[1], x[5][2:])]
233 | # multiple tokens could indicate the same event;
234 | # simply pick the one that occurs first
235 | if len(gold_match) > 0 and gold_match[0][0] not in temp_ids:
236 | temp_rels.append(gold_match[0])
237 | temp_ids.append(gold_match[0][0])
238 | else:
239 | ## construct a negative relation pair -- 'NONE'
240 | neg_id = 'N%s' % neg_counter
241 | left_match = [x for x in rels[i] if _is_gold(r[0], x[5][:2])]
242 | right_match = [x for x in rels[i] if _is_gold(r[1], x[5][2:])]
243 | # provide a random but unique id if a predicted event is not matched in gold
244 | left_id = left_match[0][1][0] if len(left_match) > 0 else ('n%s' % (neg_counter + 10000))
245 | right_id = right_match[0][1][1] if len(right_match) > 0 else ('n%s' % (neg_counter + 20000))
246 | a_rel = (neg_id, (left_id, right_id), self._label_to_id['NONE'],
247 | [float(r[1] - r[0])], False, (r[0], r[0], r[1], r[1]), True)
248 | temp_rels.append(a_rel)
249 | neg_counter += 1
250 |
251 | nopred_rels.extend([x[2] for x in rels[i] if x[0] not in [tr[0] for tr in temp_rels]])
252 |
253 | pred_rels.append(temp_rels)
254 |
255 | # relations are flattened lists of features
256 | # rel_idxs indicates (batch_id, rel_in_batch_id)
257 | docs, pairs = [], []
258 | rel_idxs, lidx_start, lidx_end, ridx_start, ridx_end = [],[],[],[],[]
259 | for i, rel in enumerate(pred_rels):
260 | rel_idxs.extend([(i, ii) for ii, _ in enumerate(rel)])
261 | lidx_start.append([x[5][0] for x in rel])
262 | lidx_end.append([x[5][1] for x in rel])
263 | ridx_start.append([x[5][2] for x in rel])
264 | ridx_end.append([x[5][3] for x in rel])
265 | pairs.extend([x[1] for x in rel])
266 | docs.extend([doc[i] for _ in rel])
267 | assert len(docs) == len(pairs)
268 |
269 | rels = [x for rel in pred_rels for x in rel]
270 | if not rels:
271 | labels = torch.FloatTensor([])
272 | fts = torch.FloatTensor([])
273 | else:
274 | labels = torch.LongTensor([x[2] for x in rels])
275 | fts = torch.cat([torch.FloatTensor(x[3]) for x in rels]).unsqueeze(1)
276 |
277 | return labels, fts, rel_idxs, docs, pairs, lidx_start, lidx_end, ridx_start, ridx_end, nopred_rels
278 |
279 |
280 | def global_prediction(self, prob_table_ents, prob_table_rels, lidx, ridx, pairs, ew, evaluate=False, true_labels=[]):
281 | # input (for each context):
282 | # 1. prob_table_ents: (context) local event predictions: N * 2, N: number of entities
283 | # 2. prob_table_rels: (context) local rel candidate predictions: M (# of can rels) * R (# of rel class)
284 | # 3. left idx: a vector of length M - left entity index
285 | # 4. right idx: a vector of length M - right entity index
286 | # 5. pairs: relation pairs for the transitivity rule
287 | # output:
288 | # 1. global_ent_idx: best global assignment of entity in matrix form
289 | # 2. 
global_rel_idx: best global assignment of relation in matrix form 290 | 291 | # initialize entity table 292 | N, Nc = prob_table_ents.shape 293 | global_ent_idx = np.zeros((N, Nc), dtype=int) 294 | 295 | # initialize relation table 296 | M, Mc = prob_table_rels.shape 297 | global_rel_idx = np.zeros((M, Mc), dtype=int) 298 | 299 | cand_rels = list(zip(lidx, ridx)) 300 | assert M == len(cand_rels) 301 | 302 | global_model = Global_Inference(prob_table_ents.detach().numpy(), 303 | prob_table_rels.detach().numpy(), 304 | cand_rels, self._label_to_id, pairs, ew) 305 | global_model.run() 306 | global_model.predict() 307 | 308 | # entity global assignment 309 | for n in range(N): 310 | global_ent_idx[n, global_model.pred_ent_labels[n]] = 1 311 | 312 | # relation global assignment 313 | for m in range(M): 314 | global_rel_idx[m, global_model.pred_rel_labels[m]] = 1 315 | 316 | if evaluate: 317 | #assert len(true_labels) == N + M 318 | #global_model.evaluate(true_labels) 319 | return global_ent_idx, global_rel_idx, global_model.pred_ent_labels, global_model.pred_rel_labels 320 | else: 321 | return global_ent_idx, global_rel_idx 322 | 323 | def _train(self, train_data, eval_data, pos_emb, args): 324 | 325 | model = BertClassifier(args) 326 | 327 | if args.cuda: 328 | print("using cuda device: %s" % torch.cuda.current_device()) 329 | assert torch.cuda.is_available() 330 | model.cuda() 331 | 332 | #optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) 333 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 334 | lr = args.lr, momentum=args.momentum, weight_decay=args.decay) 335 | #criterion = nn.CrossEntropyLoss() 336 | losses = [] 337 | 338 | sents, poss, ftss, labels = [], [], [], [] 339 | if args.load_model == True: 340 | checkpoint = torch.load(args.load_model_file, map_location='cpu') 341 | model.load_state_dict(checkpoint['state_dict']) 342 | epoch = checkpoint['epoch'] 343 | best_eval_f1 = checkpoint['f1'] 344 | print("Local best eval f1 is: %s" % best_eval_f1) 345 | 346 | best_eval_f1 = 0.0 347 | best_epoch = 0 348 | 349 | for epoch in range(args.epochs): 350 | print("Training Epoch #%s..." 
% epoch) 351 | model.train() 352 | count = 1 353 | 354 | loss_hist_t, loss_hist_e = [], [] 355 | 356 | start_time = time.time() 357 | 358 | gold = False #if epoch > args.pipe_epoch else True 359 | #event_pos_counter = Counter() 360 | 361 | all_pairs, all_label_r, all_prob_r = [], [], [] 362 | all_label_e, all_prob_e = [], [] 363 | 364 | # record unique context start idx for both rel and ent 365 | #context_ent, context_rel = [], [] 366 | all_lidx_start, all_ridx_start = [], [] 367 | l_idx, r_idx, context_start = 0, 0, 0 368 | for doc_id, context_id, sents, keys, ents, poss, rels, lengths in train_data: 369 | 370 | if args.cuda: 371 | sents = sents.cuda() 372 | ents = ents.cuda() 373 | 374 | model.zero_grad() 375 | 376 | ## entity detection 377 | out_e, prob_e = model(sents, lengths, task='entity') 378 | 379 | ## construct candidate relations 380 | rel_label, fts, rel_idxs, doc_id, pairs, lidx_start, lidx_end, ridx_start, ridx_end, non_preds \ 381 | = self.construct_relations(prob_e, lengths, rels, list(doc_id), poss, gold=False) 382 | 383 | if args.cuda: 384 | rel_label = rel_label.cuda() 385 | fts = fts.cuda() 386 | 387 | all_label_r.append(rel_label) 388 | all_pairs.extend([(doc + "_" + x, doc + "_" + y) for doc, (x, y) in zip(doc_id, pairs)]) 389 | 390 | ## predict relations 391 | out_r, prob_r = model(sents, lengths, fts=fts, rel_idxs=rel_idxs, lidx_start=lidx_start, 392 | lidx_end=lidx_end, ridx_start=ridx_start, ridx_end=ridx_end) 393 | 394 | ## collect all unique contexts for joint inference 395 | for b, l in enumerate(lengths): 396 | all_lidx_start.extend([x + context_start for x in lidx_start[b]]) 397 | all_ridx_start.extend([x + context_start for x in ridx_start[b]]) 398 | context_start += l 399 | 400 | all_prob_e.append(prob_e[b, :l, :]) 401 | all_label_e.append(ents[b, :l]) 402 | 403 | all_prob_r.append(prob_r) 404 | count += 1 405 | 406 | all_prob_e = torch.cat(all_prob_e) 407 | all_label_e = torch.cat(all_label_e) 408 | assert all_prob_e.size(0) == all_label_e.size(0) 409 | 410 | all_label_r = torch.cat(all_label_r) 411 | all_prob_r = torch.cat(all_prob_r) 412 | assert len(all_pairs) == all_prob_r.size(0) 413 | assert len(all_pairs) == all_label_r.size(0) 414 | assert len(all_pairs) == len(all_lidx_start) 415 | assert len(all_pairs) == len(all_ridx_start) 416 | 417 | # global inference for relation with transitivity 418 | best_pred_idx_e, best_pred_idx_r = self.global_prediction(all_prob_e, all_prob_r, all_lidx_start, 419 | all_ridx_start, all_pairs, args.entity_weight) 420 | loss_r = self.loss_func_rel(best_pred_idx_r, all_label_r, all_prob_r, args.margin) 421 | loss_e = self.loss_func_ent(best_pred_idx_e, all_label_e, all_prob_e, args.margin) 422 | 423 | # combine 424 | #loss = args.relation_weight * loss_r + args.entity_weight * loss_e 425 | loss = loss_r + loss_e 426 | loss.backward() 427 | 428 | optimizer.step() 429 | 430 | if args.cuda: 431 | if loss_r != 0: 432 | loss_hist_t.append(loss_r.data.cpu().numpy()) 433 | loss_hist_e.append(loss_e.data.cpu().numpy()) 434 | else: 435 | if loss_r != 0: 436 | loss_hist_t.append(loss_r.data.numpy()) 437 | loss_hist_e.append(loss_e.data.numpy()) 438 | print("Temporal loss is %.4f" % np.mean(loss_hist_t)) 439 | print("Entity loss is %.4f" % np.mean(loss_hist_e)) 440 | print("%.4f seconds elapsed" % (time.time() - start_time)) 441 | # Evaluate at the end of each epoch 442 | print("*"*50) 443 | 444 | if len(eval_data) > 0: 445 | 446 | # need to have a warm-start otherwise there could be no event_pred 447 | # may need to manually pick 
epoch < #, but 0 generally works when ew is large
448 | #eval_gold = True if epoch == 0 else args.eval_gold
449 | eval_gold = gold
450 | pred_labels, eval_labels, ent_pred, ent_true, ent_corr, nopred_rels = self.predict(model, eval_data, args, gold=eval_gold)
451 |
452 | eval_labels = list(eval_labels.numpy())
453 | assert len(eval_labels) == len(pred_labels)
454 |
455 | pred_labels.extend([self._label_to_id['NONE'] for _ in nopred_rels])
456 | eval_labels.extend(nopred_rels)
457 | # select model only based on entity + relation F1 score
458 | eval_f1 = self.weighted_f1(pred_labels, eval_labels, ent_corr, ent_pred, ent_true,
459 | args.relation_weight, args.entity_weight)
460 |
461 | # args.pipe_epoch <= args.epochs if pipeline (joint) training is used
462 | if eval_f1 > best_eval_f1 and (epoch > args.pipe_epoch or args.pipe_epoch >= 1000):
463 | best_eval_f1 = eval_f1
464 | self.model = copy.deepcopy(model)
465 | best_epoch = epoch
466 |
467 | print("Evaluation F1: %.4f" % (eval_f1))
468 | print("*"*50)
469 |
470 | print("Final Evaluation F1: %.4f at Epoch %s" % (best_eval_f1, best_epoch))
471 | print("*"*50)
472 |
473 | if args.epochs == 0: # evaluate once without training (e.g. with a loaded model)
474 | pred_labels, eval_labels, ent_pred, ent_true, ent_corr, nopred_rels = self.predict(model, eval_data, args, gold=False)
475 |
476 | eval_labels = list(eval_labels.numpy())
477 | assert len(eval_labels) == len(pred_labels)
478 |
479 | pred_labels.extend([self._label_to_id['NONE'] for _ in nopred_rels])
480 | eval_labels.extend(nopred_rels)
481 | # select model only based on entity + relation F1 score
482 | eval_f1 = self.weighted_f1(pred_labels, eval_labels, ent_corr, ent_pred, ent_true,
483 | args.relation_weight, args.entity_weight)
484 |
485 | # args.pipe_epoch <= args.epochs if pipeline (joint) training is used
486 | if eval_f1 > best_eval_f1 and (epoch > args.pipe_epoch or args.pipe_epoch >= 1000):
487 | best_eval_f1 = eval_f1
488 | self.model = copy.deepcopy(model)
489 | best_epoch = epoch
490 |
491 | if args.save_model:
492 | torch.save({'epoch': epoch,
493 | 'args': args,
494 | 'state_dict': self.model.state_dict(),
495 | 'f1': best_eval_f1,
496 | 'optimizer' : optimizer.state_dict()
497 | }, "%s%s.pth.tar" % (args.ilp_dir, args.save_stamp))
498 |
499 | return best_eval_f1, best_epoch
500 |
501 | def loss_func_ent(self, best_pred_idx_e, all_label_e, prob_e, margin):
502 |
503 | ## max prediction scores
504 | mask_e = torch.from_numpy(best_pred_idx_e).bool() # boolean mask; uint8 masks are deprecated for masked_select in newer PyTorch
505 |
506 | assert mask_e.size() == prob_e.size()
507 |
508 | max_score_e = torch.masked_select(prob_e, mask_e)
509 |
510 | #globalNlocal = (probs.data.max(1)[0].view(-1) != max_scores.data.view(-1)).numpy()
511 |
512 | ## Entity true label scores
513 | N, Nc = prob_e.size()
514 | idx_mat_e = np.zeros((N, Nc), dtype=int)
515 |
516 | for n in range(N):
517 | idx_mat_e[n][all_label_e[n]] = 1
518 | mask_e = torch.from_numpy(idx_mat_e).bool()
519 | assert mask_e.size() == prob_e.size()
520 | label_score_e = torch.masked_select(prob_e, mask_e)
521 |
522 | ## Entity SSVM loss
523 | # distance measure: try Hamming Distance later
524 | #delta = torch.FloatTensor([margin for _ in range(N)])
525 | delta = Variable(torch.FloatTensor([0.00000001 if label_score_e[n].data == max_score_e[n].data else margin for n in range(N)]), requires_grad=True)
526 | diff = delta + max_score_e - label_score_e
527 |
528 | # loss should be non-negative
529 | losses_e = []
530 | for n in range(N):
531 | if diff[n].data.numpy() <= 0.0:
532 | losses_e.append(Variable(torch.FloatTensor([0.0])))
533 | else:
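# The SSVM hinge assembled here, per entity n:
#   loss_n = max(0, delta_n + s(y_hat_n) - s(y_n))
# where s(y_hat_n) is the model score of the global (ILP) assignment and
# s(y_n) the score of the gold label; only violated margins contribute.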
534 | losses_e.append(diff[n].reshape(1,))
535 |
536 | return torch.mean(torch.cat(losses_e))
537 |
538 |
539 | def loss_func_rel(self, best_pred_idx_r, all_label_r, prob_r, margin):
540 |
541 | mask_r = torch.from_numpy(best_pred_idx_r).bool() # boolean mask, as in loss_func_ent
542 | assert mask_r.size() == prob_r.size()
543 | max_score_r = torch.masked_select(prob_r, mask_r)
544 |
545 | ## Relation true label scores
546 | M, Mc = prob_r.size()
547 |
548 | idx_mat_r = np.zeros((M, Mc), dtype=int)
549 |
550 | for m in range(M):
551 | idx_mat_r[m][all_label_r[m]] = 1
552 |
553 | mask_r = torch.from_numpy(idx_mat_r).bool()
554 | assert mask_r.size() == prob_r.size()
555 | label_score_r = torch.masked_select(prob_r, mask_r)
556 |
557 | ## Relation loss
558 | #delta = torch.FloatTensor([margin for _ in range(M)])
559 | delta = Variable(torch.FloatTensor([0.00000001 if label_score_r[m].data == max_score_r[m].data else margin for m in range(M)]), requires_grad=True)
560 | diff = delta + max_score_r - label_score_r
561 |
562 | count = 0
563 | losses_r = []
564 | for m in range(M):
565 | if diff[m].data.numpy() <= 0.0:
566 | losses_r.append(Variable(torch.FloatTensor([0.0])))
567 | else:
568 | count += 1
569 | losses_r.append(diff[m].reshape(1,))
570 |
571 | return torch.mean(torch.cat(losses_r))
572 |
573 |
574 | def train_epoch(self, train_data, dev_data, args, test_data = None):
575 |
576 | if args.data_type == "matres":
577 | label_map = matres_label_map
578 | if args.data_type == "tbd":
579 | label_map = tbd_label_map
580 |
581 | assert len(label_map) > 0
582 |
583 | all_labels = list(OrderedDict.fromkeys(label_map.values()))
584 | ## append the negative pair label
585 | all_labels.append('NONE')
586 |
587 | if args.joint:
588 | label_map_c = causal_label_map # requires causal_label_map (TCR causal pairs); not imported above
589 | # in order to preserve the order of unique keys
590 | all_labels_c = list(OrderedDict.fromkeys(label_map_c.values()))
591 | self._label_to_id_c = OrderedDict([(all_labels_c[l],l) for l in range(len(all_labels_c))])
592 | self._id_to_label_c = OrderedDict([(l,all_labels_c[l]) for l in range(len(all_labels_c))])
593 | print(self._label_to_id_c)
594 | print(self._id_to_label_c)
595 |
596 | self._label_to_id = OrderedDict([(all_labels[l],l) for l in range(len(all_labels))])
597 | self._id_to_label = OrderedDict([(l,all_labels[l]) for l in range(len(all_labels))])
598 |
599 | print(self._label_to_id)
600 | print(self._id_to_label)
601 |
602 | args.label_to_id = self._label_to_id
603 |
604 | ### pos embedding is not used for now, but can be added later
605 | pos_emb = np.zeros((len(args.pos2idx) + 1, len(args.pos2idx) + 1))
606 | for i in range(pos_emb.shape[0]):
607 | pos_emb[i, i] = 1.0
608 |
609 | best_f1, best_epoch = self._train(train_data, dev_data, pos_emb, args)
610 | print("Final Dev F1: %.4f" % best_f1)
611 | return best_f1, best_epoch
612 |
613 | def weighted_f1(self, pred_labels, true_labels, ent_corr, ent_pred, ent_true, rw=0.0, ew=0.0):
614 | def safe_division(numr, denr, on_err=0.0):
615 | return on_err if denr == 0.0 else numr / denr
616 |
617 | assert len(pred_labels) == len(true_labels)
618 |
619 | weighted_f1_scores = {}
620 | if 'NONE' in self._label_to_id.keys():
621 | num_tests = len([x for x in true_labels if x != self._label_to_id['NONE']])
622 | else:
623 | num_tests = len(true_labels)
624 |
625 | print("Total relations to evaluate: %s" % len(true_labels))
626 | print("Total positive relation samples to eval: %s" % num_tests)
627 | total_true = Counter(true_labels)
628 | total_pred = Counter(pred_labels)
629 |
630 | labels = list(self._id_to_label.keys())
631 |
632 | n_correct = 0
633 | n_true = 0
634 | n_pred = 0
635 |
636 | if rw > 0:
637 | # f1 score is used for tcr and matres and hence excludes VAGUE
638 | exclude_labels = ['NONE', 'VAGUE'] if len(self._label_to_id) == 5 else ['NONE']
639 |
640 | for label in labels:
641 | if self._id_to_label[label] not in exclude_labels:
642 |
643 | true_count = total_true.get(label, 0)
644 | pred_count = total_pred.get(label, 0)
645 |
646 | n_true += true_count
647 | n_pred += pred_count
648 |
649 | correct_count = len([l for l in range(len(pred_labels))
650 | if pred_labels[l] == true_labels[l] and pred_labels[l] == label])
651 | n_correct += correct_count
652 | if ew > 0:
653 | # add entity prediction results before calculating precision, recall and f1
654 | n_correct += ent_corr
655 | n_pred += ent_pred
656 | n_true += ent_true
657 |
658 | precision = safe_division(n_correct, n_pred)
659 | recall = safe_division(n_correct, n_true)
660 | f1_score = safe_division(2.0 * precision * recall, precision + recall)
661 | print("Overall Precision: %.4f\tRecall: %.4f\tF1: %.4f" % (precision, recall, f1_score))
662 |
663 | return f1_score
664 |
665 |
666 | class EventEvaluator:
667 | def __init__(self, model):
668 | self.model = model
669 |
670 | def evaluate(self, test_data, args):
671 | # evaluate the saved model on the test set
672 | print("start testing...")
673 |
674 | pred_labels, true_labels, ent_f1, nopred_rels \
675 | = self.model.predict(self.model.model, test_data, args, test = True, gold = False)
676 |
677 |
678 | pred_labels.extend([self.model._label_to_id['NONE'] for _ in nopred_rels])
679 | true_labels = true_labels.tolist()
680 | true_labels.extend(nopred_rels)
681 |
682 | rel_f1 = self.model.weighted_f1(pred_labels, true_labels, 0, 0, 0, rw=1.0)
683 |
684 | print("Gold pairs labeled as None: %s" % len(nopred_rels))
685 |
686 | pred_labels = [self.model._id_to_label[x] for x in pred_labels]
687 | true_labels = [self.model._id_to_label[x] for x in true_labels]
688 |
689 | print(len(pred_labels), len(true_labels))
690 |
691 | out = ClassificationReport(args.model, true_labels, pred_labels)
692 | print(out)
693 | print("F1 Excluding Vague: %.4f" % rel_f1)
694 | return rel_f1, ent_f1
695 |
696 | def main(args):
697 |
698 | data_dir = args.data_dir
699 | opt_args = {}
700 |
701 | params = {'batch_size': args.batch,
702 | 'shuffle': False,
703 | 'collate_fn': pad_collate}
704 |
705 | type_dir = "/all_context/"
706 | test_data = EventDataset(args.data_dir + type_dir, "test")
707 | test_generator = data.DataLoader(test_data, **params)
708 |
709 | train_data = EventDataset(args.data_dir + type_dir, "train")
710 | train_generator = data.DataLoader(train_data, **params)
711 |
712 | dev_data = EventDataset(args.data_dir + type_dir, "dev")
713 | dev_generator = data.DataLoader(dev_data, **params)
714 |
715 | model = NNClassifier()
716 | print(f"======={args.model}=====\n")
717 | best_f1, best_epoch = model.train_epoch(train_generator, dev_generator, args)
718 | evaluator = EventEvaluator(model)
719 | rel_f1, ent_f1 = evaluator.evaluate(test_generator, args)
720 |
721 | print(rel_f1, ent_f1)
722 |
723 | return
724 | if __name__ == '__main__':
725 | p = argparse.ArgumentParser()
726 | # arguments for data processing
727 | p.add_argument('-data_dir', type=str, default = '../data')
728 | p.add_argument('-other_dir', type=str, default = '../other')
729 | # select model
730 | p.add_argument('-model', type=str, default='joint/global') #'multitask/gold', 'multitask/pipeline'
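# NOTE: the type=bool flags below (-save_model, -joint, -load_model,
# -fine_tune, -eval_gold) are parsed with bool(str), so any non-empty value,
# including "False", becomes True; pass an empty string to get False.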
# arguments for RNN model 732 | p.add_argument('-emb', type=int, default=100) 733 | p.add_argument('-hid', type=int, default=90) 734 | p.add_argument('-num_layers', type=int, default=1) 735 | p.add_argument('-batch', type=int, default=1) 736 | p.add_argument('-data_type', type=str, default="red") 737 | p.add_argument('-epochs', type=int, default=0) 738 | p.add_argument('-pipe_epoch', type=int, default=1000) # 1000: no pipeline training; otherwise <= epochs 739 | p.add_argument('-seed', type=int, default=123) 740 | p.add_argument('-lr', type=float, default=0.1) # 0.0005, 0.001 741 | p.add_argument('-num_classes', type=int, default=2) # get updated in main() 742 | p.add_argument('-dropout', type=float, default=0.6) 743 | p.add_argument('-pos2idx', type=dict, default = {}) 744 | p.add_argument('-w2i', type=OrderedDict) 745 | p.add_argument('-glove', type=OrderedDict) 746 | p.add_argument('-cuda', action='store_true') 747 | p.add_argument('-params', type=dict, default={}) 748 | p.add_argument('-n_fts', type=int, default=1) 749 | p.add_argument('-relation_weight', type=float, default=1.0) 750 | p.add_argument('-entity_weight', type=float, default=1.0) 751 | p.add_argument('-save_model', type=bool, default=False) 752 | p.add_argument('-save_stamp', type=str, default="relation_best") 753 | p.add_argument('-load_model_file', type=str, default="matres_pipeline_best.pt") 754 | p.add_argument('-joint', type=bool, default=False) # Note: this is for tcr causal pairs 755 | p.add_argument('-load_model', type=bool, default=True) 756 | p.add_argument('-num_causal', type=int, default=2) 757 | p.add_argument('-bert_config', type=dict, default={}) 758 | p.add_argument('-loss_u', type=str, default="") 759 | p.add_argument('-fine_tune', type=bool, default=False) 760 | p.add_argument('-eval_with_timex', type=str, default=False) 761 | p.add_argument('-eval_gold',type=bool, default=False) 762 | p.add_argument('-margin', type=float, default=0.3) 763 | p.add_argument('-momentum', type=float, default=0.9) 764 | p.add_argument('-decay', type=float, default=0.9) 765 | p.add_argument('-ent_thresh',type=float, default=0.2) 766 | p.add_argument('-rel_thresh',type=float, default=0.5) 767 | args = p.parse_args() 768 | 769 | #args.eval_gold = True if args.pipe_epoch >= 1000 else False 770 | 771 | # if training with pipeline, ensure train / eval pipe epoch are the same 772 | #if args.pipe_epoch < 1000: 773 | # assert args.pipe_epoch == args.eval_pipe_epoch 774 | 775 | args.data_dir += args.data_type 776 | # create pos_tag and vocabulary dictionaries 777 | # make sure raw data files are stored in the same directory as train/dev/test data 778 | tags = open(args.other_dir + "/pos_tags.txt") 779 | pos2idx = {} 780 | idx = 0 781 | for tag in tags: 782 | tag = tag.strip() 783 | pos2idx[tag] = idx 784 | idx += 1 785 | args.pos2idx = pos2idx 786 | 787 | args.idx2pos = {v+1:k for k,v in pos2idx.items()} 788 | 789 | args.bert_config = { 790 | "attention_probs_dropout_prob": 0.1, 791 | "hidden_act": "gelu", 792 | "hidden_dropout_prob": 0.1, 793 | "hidden_size": 768, 794 | "initializer_range": 0.02, 795 | "intermediate_size": 3072, 796 | "max_position_embeddings": 512, 797 | "num_attention_heads": 12, 798 | "num_hidden_layers": 12, 799 | "type_vocab_size": 2, 800 | "vocab_size_or_config_json_file": 30522 801 | } 802 | print(args.momentum, args.decay) 803 | main(args) 804 | 805 | 806 | -------------------------------------------------------------------------------- /code/utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Mapping, Union, Optional, Set 2 | from collections import defaultdict, Counter, OrderedDict 3 | from datetime import datetime 4 | 5 | class ClassificationReport: 6 | def __init__(self, name, true_labels: List[Union[int, str]], 7 | pred_labels: List[Union[int, str]]): 8 | 9 | assert len(true_labels) == len(pred_labels) 10 | self.num_tests = len([x for x in true_labels if x != 'NONE']) 11 | self.total_truths = Counter(true_labels) 12 | self.total_predictions = Counter(pred_labels) 13 | self.name = name 14 | self.labels = sorted(set(true_labels) | set(pred_labels)) 15 | self.confusion_mat = self.confusion_matrix(true_labels, pred_labels) 16 | self.accuracy = sum(y == y_ for y, y_ in zip(true_labels, pred_labels)) / len(true_labels) 17 | self.trim_label_width = 15 18 | self.rel_f1 = 0.0 19 | 20 | @staticmethod 21 | def confusion_matrix(true_labels: List[str], predicted_labels: List[str]) \ 22 | -> Mapping[str, Mapping[str, int]]: 23 | mat = defaultdict(lambda: defaultdict(int)) 24 | for truth, prediction in zip(true_labels, predicted_labels): 25 | mat[truth][prediction] += 1 26 | return mat 27 | 28 | def __repr__(self): 29 | res = f'Name: {self.name}\t Created: {datetime.now().isoformat()}\t' 30 | res += f'Total Labels: {len(self.labels)} \t Total Tests: {self.num_tests}\n' 31 | display_labels = [label[:self.trim_label_width] for label in self.labels] 32 | label_widths = [len(l) + 1 for l in display_labels] 33 | max_label_width = max(label_widths) 34 | header = [l.ljust(w) for w, l in zip(label_widths, display_labels)] 35 | header.insert(0, ''.ljust(max_label_width)) 36 | res += ''.join(header) + '\n' 37 | for true_label, true_disp_label in zip(self.labels, display_labels): 38 | predictions = self.confusion_mat[true_label] 39 | row = [true_disp_label.ljust(max_label_width)] 40 | for pred_label, width in zip(self.labels, label_widths): 41 | row.append(str(predictions[pred_label]).ljust(width)) 42 | res += ''.join(row) + '\n' 43 | res += '\n' 44 | 45 | def safe_division(numr, denr, on_err=0.0): 46 | return on_err if denr == 0.0 else numr / denr 47 | 48 | def num_to_str(num): 49 | return '0' if num == 0 else str(num) if type(num) is int else f'{num:.4f}' 50 | 51 | n_correct = 0 52 | n_true = 0 53 | n_pred = 0 54 | 55 | all_scores = [] 56 | header = ['Total ', 'Predictions', 'Correct', 'Precision', 'Recall ', 'F1-Measure'] 57 | res += ''.ljust(max_label_width + 2) + ' '.join(header) + '\n' 58 | head_width = [len(h) for h in header] 59 | 60 | for label, width, display_label in zip(self.labels, label_widths, display_labels): 61 | if label not in ['NONE']: 62 | total_count = self.total_truths.get(label, 0) 63 | pred_count = self.total_predictions.get(label, 0) 64 | 65 | #if label != 'VAGUE': 66 | n_true += total_count 67 | 68 | n_pred += pred_count 69 | 70 | correct_count = self.confusion_mat[label][label] 71 | n_correct += correct_count 72 | 73 | precision = safe_division(correct_count, pred_count) 74 | recall = safe_division(correct_count, total_count) 75 | f1_score = safe_division(2 * precision * recall, precision + recall) 76 | all_scores.append((precision, recall, f1_score)) 77 | 78 | row = [total_count, pred_count, correct_count, precision, recall, f1_score] 79 | row = [num_to_str(cell).ljust(w) for cell, w in zip(row, head_width)] 80 | row.insert(0, display_label.rjust(max_label_width)) 81 | res += ' '.join(row) + '\n' 82 | 83 | # weighing by the truth label's frequency 84 | 
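# The weighted averages printed below use per-label weights
#   w_l = n_true(l) / num_tests   (with 'NONE' excluded),
# i.e. P_w = sum_l w_l * P_l, and likewise for recall and F1.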
label_weights = [safe_division(self.total_truths.get(label, 0), self.num_tests) 85 | for label in self.labels if label not in ['NONE']] 86 | weighted_scores = [(w * p, w * r, w * f) for w, (p, r, f) in zip(label_weights, all_scores)] 87 | 88 | assert len(label_weights) == len(weighted_scores) 89 | 90 | res += '\n' 91 | res += ' '.join(['Weighted Avg'.rjust(max_label_width), 92 | ''.ljust(head_width[0]), 93 | ''.ljust(head_width[1]), 94 | ''.ljust(head_width[2]), 95 | num_to_str(sum(p for p, _, _ in weighted_scores)).ljust(head_width[3]), 96 | num_to_str(sum(r for _, r, _ in weighted_scores)).ljust(head_width[4]), 97 | num_to_str(sum(f for _, _, f in weighted_scores)).ljust(head_width[5])]) 98 | 99 | print(n_correct, n_pred, n_true) 100 | 101 | precision = safe_division(n_correct, n_pred) 102 | recall = safe_division(n_correct, n_true) 103 | f1_score = safe_division(2.0 * precision * recall, precision + recall) 104 | 105 | res += f'\n Total Examples: {self.num_tests}' 106 | res += f'\n Overall Precision: {num_to_str(precision)}' 107 | res += f'\n Overall Recall: {num_to_str(recall)}' 108 | res += f'\n Overall F1: {num_to_str(f1_score)} ' 109 | self.rel_f1 = f1_score 110 | return res 111 | -------------------------------------------------------------------------------- /data/matres/test.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlusLabNLP/JointEventTempRel/3a6291f6fd70d85fdc2571b1d7a17904edd1cec9/data/matres/test.pickle -------------------------------------------------------------------------------- /data/matres/test_docs.txt: -------------------------------------------------------------------------------- 1 | bbc_20130322_721 2 | bbc_20130322_332 3 | nyt_20130322_strange_computer 4 | nyt_20130321_cyprus 5 | CNN_20130322_1243 6 | CNN_20130321_821 7 | bbc_20130322_1353 8 | CNN_20130322_1003 9 | CNN_20130322_314 10 | bbc_20130322_1600 11 | WSJ_20130321_1145 12 | WSJ_20130322_804 13 | CNN_20130322_248 14 | WSJ_20130322_159 15 | WSJ_20130318_731 16 | nyt_20130321_sarcozy 17 | AP_20130322 18 | nyt_20130321_china_pollution 19 | nyt_20130321_women_senate 20 | bbc_20130322_1150 21 | -------------------------------------------------------------------------------- /data/matres/train.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlusLabNLP/JointEventTempRel/3a6291f6fd70d85fdc2571b1d7a17904edd1cec9/data/matres/train.pickle -------------------------------------------------------------------------------- /data/matres/train_docs.txt: -------------------------------------------------------------------------------- 1 | APW19980227.0476 2 | PRI19980121.2000.2591 3 | wsj_0132 4 | wsj_0695 5 | wsj_0805 6 | wsj_0760 7 | wsj_1003 8 | wsj_0585 9 | wsj_0135 10 | wsj_0928 11 | wsj_0346 12 | wsj_0568 13 | wsj_0175 14 | SJMN91-06338157 15 | wsj_0172 16 | wsj_1038 17 | wsj_0752 18 | PRI19980303.2000.2550 19 | wsj_0348 20 | wsj_1031 21 | wsj_0660 22 | NYT19980402.0453 23 | wsj_0032 24 | wsj_0325 25 | ea980120.1830.0456 26 | wsj_0667 27 | wsj_0570 28 | wsj_0736 29 | wsj_0612 30 | wsj_0505 31 | wsj_0151 32 | wsj_0187 33 | wsj_0778 34 | APW19980213.1380 35 | wsj_0158 36 | wsj_0189 37 | wsj_0124 38 | wsj_0542 39 | NYT19980206.0466 40 | CNN19980227.2130.0067 41 | wsj_0106 42 | wsj_0670 43 | wsj_0332 44 | APW19980227.0494 45 | NYT19980424.0421 46 | wsj_0173 47 | APW19980219.0476 48 | wsj_0340 49 | wsj_0679 50 | wsj_1039 51 | APW19980306.1001 52 | PRI19980205.2000.1890 53 
| VOA19980303.1600.0917 54 | wsj_0263 55 | wsj_1042 56 | wsj_0713 57 | CNN19980213.2130.0155 58 | wsj_0637 59 | wsj_0520 60 | wsj_0768 61 | wsj_0527 62 | wsj_0471 63 | wsj_0584 64 | wsj_0927 65 | wsj_0555 66 | wsj_0583 67 | ed980111.1130.0089 68 | wsj_0122 69 | wsj_0685 70 | wsj_1014 71 | wsj_0815 72 | AP900815-0044 73 | wsj_0006 74 | wsj_0159 75 | ABC19980304.1830.1636 76 | NYT19980206.0460 77 | wsj_0316 78 | wsj_1013 79 | wsj_0157 80 | wsj_0938 81 | ea980120.1830.0071 82 | wsj_0073 83 | APW19980227.0468 84 | wsj_0150 85 | VOA19980305.1800.2603 86 | wsj_0356 87 | wsj_0165 88 | wsj_0661 89 | wsj_0904 90 | APW19980213.1310 91 | wsj_0324 92 | wsj_0745 93 | APW19980308.0201 94 | wsj_0811 95 | wsj_0650 96 | APW19980418.0210 97 | NYT19980212.0019 98 | wsj_0816 99 | wsj_0991 100 | wsj_0292 101 | CNN19980223.1130.0960 102 | wsj_0706 103 | wsj_0161 104 | VOA19980331.1700.1533 105 | wsj_0610 106 | wsj_0329 107 | wsj_0662 108 | wsj_0575 109 | wsj_0907 110 | wsj_0168 111 | PRI19980115.2000.0186 112 | wsj_1025 113 | APW19980501.0480 114 | wsj_0026 115 | wsj_0781 116 | wsj_0674 117 | wsj_1033 118 | wsj_0786 119 | WSJ910225-0066 120 | wsj_0344 121 | wsj_0918 122 | PRI19980216.2000.0170 123 | CNN19980126.1600.1104 124 | wsj_1073 125 | ABC19980120.1830.0957 126 | wsj_0376 127 | wsj_0558 128 | wsj_1008 129 | wsj_0068 130 | wsj_1006 131 | wsj_0923 132 | wsj_0551 133 | wsj_0762 134 | ABC19980108.1830.0711 135 | wsj_0924 136 | ABC19980114.1830.0611 137 | wsj_0791 138 | wsj_0321 139 | APW19980227.0487 140 | wsj_0906 141 | CNN19980222.1130.0084 142 | wsj_0169 143 | wsj_0167 144 | wsj_0798 145 | wsj_0973 146 | wsj_0160 147 | APW19980301.0720 148 | APW19980227.0489 149 | PRI19980213.2000.0313 150 | wsj_0534 151 | wsj_0184 152 | wsj_0152 153 | wsj_0533 154 | VOA19980501.1800.0355 155 | wsj_0541 156 | wsj_0313 157 | wsj_0810 158 | wsj_0127 159 | wsj_1011 160 | wsj_0709 161 | APW19980213.1320 162 | wsj_0586 163 | wsj_0136 164 | wsj_0557 165 | wsj_0806 166 | PRI19980306.2000.1675 167 | WSJ900813-0157 168 | wsj_0981 169 | AP900816-0139 170 | wsj_0950 171 | VOA19980303.1600.2745 172 | wsj_0144 173 | wsj_0171 174 | wsj_0266 175 | APW19980626.0364 176 | APW19980322.0749 177 | wsj_0176 178 | wsj_1040 179 | wsj_0751 180 | wsj_0027 181 | wsj_0675 182 | wsj_1035 183 | -------------------------------------------------------------------------------- /data/tbd/dev.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlusLabNLP/JointEventTempRel/3a6291f6fd70d85fdc2571b1d7a17904edd1cec9/data/tbd/dev.pickle -------------------------------------------------------------------------------- /data/tbd/dev_docs.txt: -------------------------------------------------------------------------------- 1 | NYT19980212.0019 2 | ed980111.1130.0089 3 | APW19980227.0487 4 | CNN19980223.1130.0960 5 | PRI19980216.2000.0170 6 | -------------------------------------------------------------------------------- /data/tbd/test.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlusLabNLP/JointEventTempRel/3a6291f6fd70d85fdc2571b1d7a17904edd1cec9/data/tbd/test.pickle -------------------------------------------------------------------------------- /data/tbd/test_docs.txt: -------------------------------------------------------------------------------- 1 | APW19980308.0201 2 | APW19980418.0210 3 | CNN19980213.2130.0155 4 | APW19980227.0489 5 | NYT19980402.0453 6 | CNN19980126.1600.1104 7 | PRI19980115.2000.0186 8 | 
PRI19980306.2000.1675 9 | APW19980227.0494 10 | -------------------------------------------------------------------------------- /data/tbd/train.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlusLabNLP/JointEventTempRel/3a6291f6fd70d85fdc2571b1d7a17904edd1cec9/data/tbd/train.pickle -------------------------------------------------------------------------------- /data/tbd/train_docs.txt: -------------------------------------------------------------------------------- 1 | PRI19980121.2000.2591 2 | PRI19980205.2000.1998 3 | CNN19980227.2130.0067 4 | APW19980213.1320 5 | AP900816-0139 6 | PRI19980205.2000.1890 7 | ABC19980304.1830.1636 8 | APW19980227.0476 9 | NYT19980206.0466 10 | ABC19980120.1830.0957 11 | ea980120.1830.0456 12 | APW19980213.1380 13 | ABC19980108.1830.0711 14 | ABC19980114.1830.0611 15 | APW19980219.0476 16 | CNN19980222.1130.0084 17 | AP900815-0044 18 | PRI19980213.2000.0313 19 | APW19980213.1310 20 | ea980120.1830.0071 21 | APW19980227.0468 22 | NYT19980206.0460 23 | -------------------------------------------------------------------------------- /other/pos_tags.txt: -------------------------------------------------------------------------------- 1 | CC 2 | CD 3 | DT 4 | EX 5 | FW 6 | IN 7 | JJ 8 | JJR 9 | JJS 10 | LS 11 | MD 12 | NN 13 | NNS 14 | NNP 15 | NNPS 16 | PDT 17 | POS 18 | PRP 19 | PRP$ 20 | RB 21 | RBR 22 | RBS 23 | RP 24 | SYM 25 | TO 26 | UH 27 | VB 28 | VBD 29 | VBG 30 | VBN 31 | VBP 32 | VBZ 33 | WDT 34 | WP 35 | WP$ 36 | WRB --------------------------------------------------------------------------------
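A note on reusing saved models: both joint_model.py and joint_model_global.py persist checkpoints via torch.save (see the save_model blocks above). A minimal sketch of reloading one, assuming a hypothetical ilp/ directory and the default save stamp:

import torch
from joint_model import BertClassifier

# Checkpoints are written to args.ilp_dir + args.save_stamp + ".pth.tar" with
# keys 'epoch', 'args', 'state_dict', 'f1' and 'optimizer'.
checkpoint = torch.load("ilp/matres_entity_best_hid100_dropout0.1_ew15.0.pth.tar",
                        map_location="cpu")
saved_args = checkpoint["args"]                 # argparse Namespace stored at save time
model = BertClassifier(saved_args)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("dev F1 %.4f at epoch %s" % (checkpoint["f1"], checkpoint["epoch"]))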