├── .gitignore ├── Config.py ├── Dataset.py ├── Dataset_Trigger.py ├── Eval_Trigger.py ├── LICENSE ├── Model.py ├── Model_Trigger.py ├── Preprocess.py ├── README.md ├── Script.py ├── Script_Test.py ├── Serving_Trigger.py ├── Util.py ├── Visualize.py ├── requirements.txt ├── send_script.py ├── test.py └── visualization └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | venv/ 4 | __pycache__/ 5 | *.txt 6 | word_embed/ 7 | debug/ 8 | *.xlsx 9 | *.json 10 | logdir/ 11 | logdir_dev/ 12 | house/ 13 | mnli/ 14 | runs/ 15 | *.html 16 | *.bin 17 | 18 | pid.txt 19 | log.txt 20 | -------------------------------------------------------------------------------- /Config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class MyConfig: 4 | raw_data_path = './data/ace_2005_td_v7/data/English/{}/adj/' 5 | raw_dir_list = os.listdir('./data/ace_2005_td_v7/data/English/') 6 | word_embed_size = 100 7 | glove_txt_path = './data/glove/glove.6B/glove.6B.{}d.txt'.format(word_embed_size) 8 | mark_long_entity_in_pos = True 9 | 10 | 11 | class HyperParams: 12 | batch_size = 30 13 | max_sequence_length = 20 14 | windows = 3 15 | word_embedding_size = MyConfig.word_embed_size 16 | pos_embedding_size = 10 17 | lr = 1e-3 18 | filter_sizes = [3, 4, 5] 19 | filter_num = 100 20 | 21 | num_epochs = 20 22 | 23 | class HyperParams_Tri_classification: 24 | batch_size = 128 25 | max_sequence_length = 30 26 | windows = 3 27 | word_embedding_size = MyConfig.word_embed_size 28 | pos_embedding_size = 10 29 | lr = 0.001 30 | filter_sizes = [3, 4, 5] 31 | filter_num = 128 32 | 33 | num_epochs = 200 # 50 for Identification, 200 for Classification 34 | -------------------------------------------------------------------------------- /Dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from Util import one_hot, find_candidates 3 | 4 | 5 | class Dataset: 6 | def __init__(self, 7 | data_path='', 8 | batch_size=30, 9 | max_sequence_length=30, 10 | windows=3, 11 | eval_num=30, 12 | dtype=None): 13 | assert dtype in ['IDENTIFICATION','CLASSIFICATION'] 14 | 15 | self.windows = windows 16 | self.batch_size = batch_size 17 | self.max_sequence_length = max_sequence_length 18 | self.eval_num = eval_num 19 | self.dtype = dtype 20 | 21 | self.all_words = list() 22 | self.all_pos_taggings = list() 23 | self.all_marks = list() 24 | self.all_labels = list() 25 | self.instances = list() 26 | 27 | self.word_id = dict() 28 | self.pos_taggings_id = dict() 29 | self.mark_id = dict() 30 | self.label_id = dict() 31 | 32 | self.read_dataset() 33 | 34 | self.word_embed = None 35 | 36 | self.train_instances, self.eval_instances = [],[] 37 | self.divide_train_eval_data() 38 | self.batch_nums = len(self.train_instances) // self.batch_size 39 | self.index = np.arange(len(self.train_instances)) 40 | self.point = 0 41 | print('all label for dataset: {}'.format(len(self.all_labels))) 42 | 43 | def divide_train_eval_data(self): 44 | testset_fname = [] 45 | for ins in self.instances: 46 | if 'nw/adj' not in ins['fname']: 47 | self.train_instances.append(ins) 48 | elif ins['fname'] in testset_fname: 49 | self.eval_instances.append(ins) 50 | elif len(testset_fname) > 40: 51 | self.train_instances.append(ins) 52 | else: 53 | testset_fname.append(ins['fname']) 54 | self.eval_instances.append(ins) 55 | 56 | print('TRAIN: {} TEST: {}'.format(len(self.train_instances), 
len(self.eval_instances))) 57 | assert len(self.instances) == (len(self.train_instances) + len(self.eval_instances)) 58 | 59 | def read_dataset(self): 60 | all_words, all_pos_taggings, all_labels, all_marks = [set() for _ in range(4)] 61 | 62 | def read_one(words, marks, label, fname): 63 | # TODO: remove comments mark when use POS tag info for model. `nltk.pos_tag()` method too slow. 64 | #pos_taggings = nltk.pos_tag(words) 65 | #pos_taggings = [pos_tagging[1] for pos_tagging in pos_taggings] 66 | pos_taggings = [None for i in range(10)] 67 | 68 | for word in words: all_words.add(word) 69 | for mark in marks: all_marks.add(mark) 70 | for pos_tag in pos_taggings: all_pos_taggings.add(pos_tag) 71 | all_labels.add(label) 72 | 73 | if len(words) > 80: 74 | # print('len(word) > 80, Goodbye! ', len(words), words) 75 | return 76 | 77 | self.instances.append({ 78 | 'words': words, 79 | 'pos_taggings': pos_taggings, 80 | 'marks': marks, 81 | 'label': label, 82 | 'fname':fname 83 | }) 84 | 85 | # current word: $500 billion 86 | # read_one( 87 | # words=['It', 'could', 'swell', 'to', 'as', 'much', 'as', '$500 billion', 'if', 'we', 'go', 'to', 'war', 'in', 'Iraq'], 88 | # marks=['A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A', 'T', 'A', 'A'], 89 | # label='None', 90 | # ) 91 | # # current word: we 92 | # read_one( 93 | # words=['It', 'could', 'swell', 'to', 'as', 'much', 'as', '$500 billion', 'if', 'we', 'go', 'to', 'war', 'in', 'Iraq'], 94 | # marks=['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'T', 'A', 'A'], 95 | # label='Attacker', 96 | # ) 97 | # # current word: Iraq 98 | # read_one( 99 | # words=['It', 'could', 'swell', 'to', 'as', 'much', 'as', '$500 billion', 'if', 'we', 'go', 'to', 'war', 'in', 'Iraq'], 100 | # marks=['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'T', 'A', 'B'], 101 | # label='Place', 102 | # ) 103 | 104 | from Preprocess import PreprocessManager 105 | man = PreprocessManager() 106 | man.preprocess(tasktype='ARGUMENT', subtasktype=self.dtype) 107 | argument_classification_data = man.arg_task_format_data 108 | for data in argument_classification_data: 109 | read_one(words=data[0], marks=data[1], label=data[2], fname=data[3]) 110 | 111 | all_words.add('') 112 | all_pos_taggings.add('*') 113 | 114 | self.word_id = dict(zip(all_words, range(len(all_words)))) 115 | self.pos_taggings_id = dict(zip(all_pos_taggings, range(len(all_pos_taggings)))) 116 | self.mark_id = dict(zip(all_marks, range(len(all_marks)))) 117 | self.label_id = dict(zip(all_labels, range(len(all_labels)))) 118 | 119 | self.all_words = list(all_words) 120 | self.all_pos_taggings = list(all_pos_taggings) 121 | self.all_labels = list(all_labels) 122 | self.all_marks = list(all_marks) 123 | 124 | def shuffle(self): 125 | np.random.shuffle(self.index) 126 | self.point = 0 127 | 128 | def next_batch(self): 129 | start = self.point 130 | self.point = self.point + self.batch_size 131 | if self.point > len(self.train_instances): 132 | self.shuffle() 133 | start = 0 134 | self.point = self.point + self.batch_size 135 | end = self.point 136 | batch_instances = map(lambda x: self.train_instances[x], self.index[start:end]) 137 | return batch_instances 138 | 139 | def next_train_data(self): 140 | batch_instances = self.next_batch() 141 | pos_tag, y, x, t, c, pos_c, pos_t = [list() for _ in range(7)] 142 | 143 | for instance in batch_instances: 144 | words = instance['words'] 145 | pos_taggings = instance['pos_taggings'] 146 | marks = instance['marks'] 147 | label = instance['label'] 
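# In 'marks', 'B' tags the argument-candidate token, 'T' tags the trigger
# token(s), and 'A' tags every other token (see the commented-out example
# calls in read_dataset above).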
148 | 149 | index_candidates = find_candidates(marks, ['B']) 150 | assert (len(index_candidates)) == 1 151 | index_triggers = find_candidates(marks, ['T']) 152 | # assert (len(index_triggers)) == 1 153 | y.append(label) 154 | marks = marks + ['A'] * (self.max_sequence_length - len(marks)) 155 | words = words + [''] * (self.max_sequence_length - len(words)) 156 | pos_taggings = pos_taggings + ['*'] * (self.max_sequence_length - len(pos_taggings)) 157 | pos_taggings = list(map(lambda x: self.pos_taggings_id[x], pos_taggings)) 158 | pos_tag.append(pos_taggings) 159 | index_words = list(map(lambda x: self.word_id[x], words)) 160 | x.append(index_words) 161 | pos_candidate = [i for i in range(-index_candidates[0], 0)] + [i for i in range(0, self.max_sequence_length - index_candidates[0])] 162 | pos_c.append(pos_candidate) 163 | pos_trigger = [i for i in range(-index_triggers[0], 0)] + [i for i in range(0, self.max_sequence_length - index_triggers[0])] 164 | pos_t.append(pos_trigger) 165 | t.append([index_words[index_triggers[0]]] * self.max_sequence_length) 166 | c.append([index_words[index_candidates[0]]] * self.max_sequence_length) 167 | 168 | # print(len(words), len(marks), len(pos_taggings), len(index_words), len(pos_candidate), len(pos_trigger)) 169 | assert len(words) == len(marks) == len(pos_taggings) == len(index_words) == len(pos_candidate) == len(pos_trigger) 170 | assert len(y) == len(x) == len(t) == len(c) == len(pos_c) == len(pos_t) == len(pos_tag) 171 | return x, t, c, one_hot(y, self.label_id, len(self.all_labels)), pos_c, pos_t, pos_tag 172 | 173 | def eval_data(self): 174 | batch_instances = self.eval_instances 175 | pos_tag, y, x, t, c, pos_c, pos_t = [list() for _ in range(7)] 176 | 177 | for instance in batch_instances: 178 | words = instance['words'] 179 | pos_taggings = instance['pos_taggings'] 180 | marks = instance['marks'] 181 | label = instance['label'] 182 | index_candidates = find_candidates(marks, ['B']) 183 | assert (len(index_candidates)) == 1 184 | index_triggers = find_candidates(marks, ['T']) 185 | # assert (len(index_triggers)) == 1 186 | y.append(label) 187 | marks = marks + ['A'] * (self.max_sequence_length - len(marks)) 188 | words = words + [''] * (self.max_sequence_length - len(words)) 189 | pos_taggings = pos_taggings + ['*'] * (self.max_sequence_length - len(pos_taggings)) 190 | pos_taggings = list(map(lambda x: self.pos_taggings_id[x], pos_taggings)) 191 | pos_tag.append(pos_taggings) 192 | index_words = list(map(lambda x: self.word_id[x], words)) 193 | x.append(index_words) 194 | pos_candidate = [i for i in range(-index_candidates[0], 0)] + [i for i in range(0, self.max_sequence_length - index_candidates[0])] 195 | pos_c.append(pos_candidate) 196 | pos_trigger = [i for i in range(-index_triggers[0], 0)] + [i for i in range(0, self.max_sequence_length - index_triggers[0])] 197 | pos_t.append(pos_trigger) 198 | t.append([index_words[index_triggers[0]]] * self.max_sequence_length) 199 | c.append([index_words[index_candidates[0]]] * self.max_sequence_length) 200 | assert len(words) == len(marks) == len(pos_taggings) == len(index_words) == len(pos_candidate) == len(pos_trigger) 201 | assert len(y) == len(x) == len(t) == len(c) == len(pos_c) == len(pos_t) == len( 202 | pos_tag) 203 | return x, t, c, one_hot(y, self.label_id, len(self.all_labels)), pos_c, pos_t, pos_tag 204 | 205 | if __name__=='__main__': 206 | import pprint 207 | pp = pprint.PrettyPrinter(indent=4) 208 | D = Dataset() 209 | q = D.next_train_data() 210 | for i in q: 211 | pp.pprint(i[0]) 
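# --- Illustrative sketch (not part of the original file; names are hypothetical) ---
# next_train_data() and eval_data() encode, for every token, its signed distance
# to the candidate ('B') index and the trigger ('T') index. The helper below
# reproduces that construction in isolation:
def _relative_position_sketch(anchor_index=3, max_sequence_length=8):
    # anchor at index 3 in a length-8 sentence -> [-3, -2, -1, 0, 1, 2, 3, 4]
    return [i for i in range(-anchor_index, 0)] + \
           [i for i in range(0, max_sequence_length - anchor_index)]
# Model.py shifts these values by (sentence_length - 1) before looking them up
# in the Tri_pos / Can_pos position-embedding tables.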
212 | -------------------------------------------------------------------------------- /Dataset_Trigger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle, os 3 | import numpy as np 4 | import nltk 5 | from Util import one_hot, find_candidates 6 | from Config import MyConfig, HyperParams_Tri_classification 7 | 8 | 9 | class Dataset_Trigger: 10 | def __init__(self, 11 | data_path='', 12 | batch_size=30, 13 | max_sequence_length=30, 14 | windows=3, 15 | eval_num=30, 16 | dtype=None): 17 | assert dtype in ['IDENTIFICATION', 'CLASSIFICATION'] 18 | 19 | self.windows = windows 20 | self.batch_size = batch_size 21 | self.max_sequence_length = max_sequence_length 22 | 23 | self.dtype = dtype 24 | 25 | self.all_words = list() 26 | self.all_pos_taggings = list() 27 | self.all_marks = list() 28 | self.all_labels = list() 29 | self.instances = list() 30 | 31 | self.word_id = dict() 32 | self.id2word = dict() 33 | self.pos_taggings_id = dict() 34 | self.mark_id = dict() 35 | self.label_id = dict() 36 | self.id2label = dict() 37 | 38 | print('read data...', end=' ') 39 | self.read_dataset() 40 | print('complete') 41 | 42 | self.word_embed = self.embed_manager() 43 | 44 | self.valid_instances, self.eval_instances, self.train_instances = [], [], [] 45 | self.divide_train_valid_eval_data() 46 | print('\n\n########### TRAIN: {} VALID: {} TEST: {}'.format(len(self.train_instances), 47 | len(self.valid_instances), 48 | len(self.eval_instances))) 49 | 50 | # self.over_sampling() 51 | self.batch_nums = len(self.train_instances) // self.batch_size 52 | self.index = np.arange(len(self.train_instances)) 53 | self.point = 0 54 | print('all label for dataset: {}'.format(len(self.all_labels))) 55 | 56 | def embed_manager(self): 57 | matrix = np.zeros([len(self.all_words), HyperParams_Tri_classification.word_embedding_size]) 58 | word_map = self.read_glove() 59 | 60 | special_key_dump_fname = './data/special_key_emblen_{}.bin'.format(HyperParams_Tri_classification.word_embedding_size) 61 | 62 | if os.path.exists(special_key_dump_fname): 63 | with open(special_key_dump_fname,'rb') as f: 64 | dumped_skey = pickle.load(f) 65 | for k in self.special_key: 66 | matrix[self.word_id[k]] = dumped_skey[k] 67 | else: 68 | dumped_skey = dict() 69 | for k in self.special_key: 70 | tmp_val = np.random.normal(0, 0.001, HyperParams_Tri_classification.word_embedding_size) 71 | dumped_skey[k] = tmp_val 72 | matrix[self.word_id[k]] = tmp_val 73 | with open(special_key_dump_fname,'wb') as f: 74 | pickle.dump(dumped_skey, f) 75 | 76 | for idx, word in enumerate(self.all_words): 77 | if word in word_map.keys(): 78 | matrix[idx] = word_map[word] 79 | else: 80 | if len(word.split()) == 1: # OOV case 81 | 82 | if word.lower() in word_map.keys():# even 'Did' is OOV! 
83 | matrix[idx] = word_map[word.lower()] 84 | else: 85 | print('oov: {}'.format(word)) 86 | matrix[idx] = matrix[self.word_id['']] 87 | else: # multiple word as one word, maybe Entity case 88 | pass # Do it after iterating all voca once 89 | 90 | for idx, word in enumerate(self.all_words): 91 | if word not in word_map.keys() and (len(word.split()) != 1): ##Entity case 92 | for subword in word.split(): 93 | if subword in word_map: 94 | matrix[idx] += word_map[subword] 95 | else: 96 | matrix[idx] += matrix[self.word_id['']] 97 | return matrix 98 | 99 | @staticmethod 100 | def read_glove(): 101 | word_map = dict() 102 | with open(MyConfig.glove_txt_path, 'r', encoding='utf8') as f: 103 | ls = f.readlines() 104 | for l in ls: 105 | l = l.split() 106 | word_map[l[0]] = [float(el) for el in l[1:]] 107 | return word_map 108 | 109 | def over_sampling(self): 110 | label_instance = dict() 111 | for label in self.all_labels: 112 | label_instance[label] = [] 113 | 114 | label_max_count = 0 115 | for instance in self.train_instances: 116 | label_instance[instance['label']].append(instance) 117 | for label in label_instance: 118 | if label_max_count < len(label_instance[label]): label_max_count = len(label_instance[label]) 119 | 120 | new_train_instances = [] 121 | for label in self.all_labels: 122 | more = label_max_count - len(label_instance[label]) 123 | instances = label_instance[label] 124 | for i in range(more): 125 | instances.append(instances[i]) 126 | new_train_instances = new_train_instances + instances 127 | 128 | print('label_max_count : ', label_max_count) 129 | print('before_train_instances: ', len(self.train_instances)) 130 | print('new_train_instances :', len(new_train_instances)) 131 | self.train_instances = new_train_instances 132 | 133 | 134 | def divide_train_valid_eval_data(self): 135 | tdv_instance_fname = './data/trigger_TDV_divide_{}_maxlen_{}_instance.bin'.format(self.dtype, HyperParams_Tri_classification.max_sequence_length) 136 | train_ins, valid_ins, test_ins = [], [], [] 137 | 138 | if os.path.exists(tdv_instance_fname): 139 | with open(tdv_instance_fname,'rb') as f: 140 | train_ins,valid_ins,test_ins = pickle.load(f) 141 | else: 142 | validset_fname, testset_fname = [], [] 143 | random.shuffle(self.instances) 144 | # select test set randomly 145 | # for ins in self.instances: 146 | # if 'nw/adj' not in ins['fname']: 147 | # train_ins.append(ins) 148 | # elif ins['fname'] in testset_fname: 149 | # test_ins.append(ins) 150 | # elif ins['fname'] in validset_fname: 151 | # valid_ins.append(ins) 152 | # elif len(testset_fname) >= 40 and len(validset_fname)>= 30: 153 | # train_ins.append(ins) 154 | # elif len(validset_fname)<30: 155 | # validset_fname.append(ins['fname']) 156 | # valid_ins.append(ins) 157 | # elif len(testset_fname)<: 158 | # testset_fname.append(ins['fname']) 159 | # test_ins.append(ins) 160 | # else: 161 | # raise ValueError 162 | for ins in self.instances: 163 | if ins['fname'] in testset_fname: 164 | test_ins.append(ins) 165 | elif ins['fname'] in validset_fname: 166 | valid_ins.append(ins) 167 | elif len(validset_fname)<35: 168 | validset_fname.append(ins['fname']) 169 | valid_ins.append(ins) 170 | elif len(testset_fname)<35: 171 | testset_fname.append(ins['fname']) 172 | test_ins.append(ins) 173 | else: 174 | train_ins.append(ins) 175 | with open(tdv_instance_fname, 'wb') as f: 176 | pickle.dump([train_ins, valid_ins, test_ins],f) 177 | 178 | self.train_instances, self.valid_instances, self.eval_instances = train_ins, valid_ins, test_ins 179 | 
random.shuffle(self.train_instances) 180 | assert len(self.instances) == (len(self.train_instances) + len(self.eval_instances) + len(self.valid_instances)) 181 | 182 | def manage_entity_in_POS(self, poss, entity_mark): 183 | new_pos = [] 184 | assert len(poss) == len(entity_mark) 185 | for pos, ent in zip(poss, entity_mark): 186 | if ent == '*': 187 | new_pos.append(pos[1]) 188 | elif len(pos[0].split()) == 1: 189 | new_pos.append(pos[1]) 190 | else: 191 | new_pos.append('ENTITY') 192 | return new_pos 193 | 194 | def read_dataset(self): 195 | all_words, all_pos_taggings, all_labels, all_marks = [set() for _ in range(4)] 196 | 197 | def read_one(words, marks, label, fname, entity_mark): 198 | pos_taggings = nltk.pos_tag(words) 199 | if MyConfig.mark_long_entity_in_pos: 200 | pos_taggings = self.manage_entity_in_POS(pos_taggings, entity_mark) 201 | # pos_taggings = [pos_tagging[1] for pos_tagging in pos_taggings] 202 | 203 | assert len(pos_taggings) == len(words) 204 | 205 | for word in words: all_words.add(word) 206 | for mark in marks: all_marks.add(mark) 207 | for pos_tag in pos_taggings: all_pos_taggings.add(pos_tag) 208 | all_labels.add(label) 209 | 210 | if len(words) > HyperParams_Tri_classification.max_sequence_length: 211 | # print('len(word) > 80, Goodbye! ', len(words), words) 212 | return None 213 | 214 | res = { 215 | 'words': words, 216 | 'pos_taggings': pos_taggings, 217 | 'marks': marks, 218 | 'label': label, 219 | 'fname': fname 220 | } 221 | return res 222 | 223 | from Preprocess import PreprocessManager 224 | man = PreprocessManager() 225 | man.preprocess(tasktype='TRIGGER', subtasktype=self.dtype) 226 | tri_classification_data = man.tri_task_format_data 227 | 228 | total_instance = [] 229 | dump_instance_fname = './data/trigger_{}_maxlen_{}_instance.bin'.format(self.dtype, HyperParams_Tri_classification.max_sequence_length) 230 | 231 | if os.path.exists(dump_instance_fname): 232 | print('use previous instance data for trigger task') 233 | with open(dump_instance_fname, 'rb') as f: 234 | total_instance = pickle.load(f) 235 | for ins in total_instance: 236 | for word in ins['words']: all_words.add(word) 237 | for mark in ins['marks']: all_marks.add(mark) 238 | for pos_tag in ins['pos_taggings']: all_pos_taggings.add(pos_tag) 239 | all_labels.add(ins['label']) 240 | else: 241 | print('Read {} data....'.format(len(tri_classification_data))) 242 | for idx, data in enumerate(tri_classification_data): 243 | if idx % 1000 == 0: print('{}/{}'.format(idx, len(tri_classification_data))) 244 | res = read_one(words=data[0], marks=data[1], label=data[2], fname=data[3], entity_mark=data[4]) 245 | if res is not None: total_instance.append(res) 246 | with open(dump_instance_fname, 'wb') as f: 247 | pickle.dump(total_instance, f) 248 | 249 | self.instances = total_instance 250 | 251 | all_words.add('') 252 | all_words.add('') 253 | self.special_key = ['', ''] 254 | 255 | all_pos_taggings.add('*') 256 | 257 | self.word_id = dict(zip(all_words, range(len(all_words)))) 258 | for word in self.word_id: self.id2word[self.word_id[word]] = word 259 | self.pos_taggings_id = dict(zip(all_pos_taggings, range(len(all_pos_taggings)))) 260 | self.mark_id = dict(zip(all_marks, range(len(all_marks)))) 261 | self.label_id = dict(zip(all_labels, range(len(all_labels)))) 262 | for label in self.label_id: self.id2label[self.label_id[label]] = label 263 | 264 | self.all_words = list(all_words) 265 | self.all_pos_taggings = list(all_pos_taggings) 266 | self.all_labels = list(all_labels) 267 | self.all_marks = 
list(all_marks) 268 | 269 | def shuffle(self): 270 | np.random.shuffle(self.index) 271 | self.point = 0 272 | 273 | def next_batch(self): 274 | start = self.point 275 | self.point = self.point + self.batch_size 276 | if self.point > len(self.train_instances): 277 | self.shuffle() 278 | start = 0 279 | self.point = self.point + self.batch_size 280 | end = self.point 281 | batch_instances = map(lambda x: self.train_instances[x], self.index[start:end]) 282 | return batch_instances 283 | 284 | def next_train_data(self): 285 | batch_instances = self.next_batch() 286 | pos_tag, y, x, c, pos_c = [list() for _ in range(5)] 287 | 288 | for instance in batch_instances: 289 | words = instance['words'] 290 | pos_taggings = instance['pos_taggings'] 291 | marks = instance['marks'] 292 | label = instance['label'] 293 | 294 | index_candidates = find_candidates(marks, ['B']) 295 | assert (len(index_candidates)) == 1 296 | 297 | y.append(label) 298 | marks = marks + ['A'] * (self.max_sequence_length - len(marks)) 299 | words = words + [''] * (self.max_sequence_length - len(words)) 300 | pos_taggings = pos_taggings + ['*'] * (self.max_sequence_length - len(pos_taggings)) 301 | pos_taggings = list(map(lambda x: self.pos_taggings_id[x], pos_taggings)) 302 | pos_tag.append(pos_taggings) 303 | index_words = list(map(lambda x: self.word_id[x], words)) 304 | x.append(index_words) 305 | pos_candidate = [i for i in range(-index_candidates[0], 0)] + [i for i in range(0, self.max_sequence_length - index_candidates[0])] 306 | pos_c.append(pos_candidate) 307 | c.append([index_words[index_candidates[0]]] * self.max_sequence_length) 308 | assert len(words) == len(marks) == len(pos_taggings) == len(index_words) == len(pos_candidate) 309 | 310 | assert len(y) == len(x) == len(c) == len(pos_c) == len(pos_tag) 311 | return x, c, one_hot(y, self.label_id, len(self.all_labels)), pos_c, pos_tag 312 | 313 | def next_eval_data(self): 314 | batch_instances = self.eval_instances 315 | pos_tag, y, x, c, pos_c = [list() for _ in range(5)] 316 | 317 | for instance in batch_instances: 318 | words = instance['words'] 319 | pos_taggings = instance['pos_taggings'] 320 | marks = instance['marks'] 321 | label = instance['label'] 322 | 323 | index_candidates = find_candidates(marks, ['B']) 324 | assert (len(index_candidates)) == 1 325 | 326 | y.append(label) 327 | marks = marks + ['A'] * (self.max_sequence_length - len(marks)) 328 | words = words + [''] * (self.max_sequence_length - len(words)) 329 | pos_taggings = pos_taggings + ['*'] * (self.max_sequence_length - len(pos_taggings)) 330 | pos_taggings = list(map(lambda x: self.pos_taggings_id[x], pos_taggings)) 331 | pos_tag.append(pos_taggings) 332 | index_words = list(map(lambda x: self.word_id[x], words)) 333 | x.append(index_words) 334 | pos_candidate = [i for i in range(-index_candidates[0], 0)] + [i for i in range(0, 335 | self.max_sequence_length - 336 | index_candidates[0])] 337 | pos_c.append(pos_candidate) 338 | c.append([index_words[index_candidates[0]]] * self.max_sequence_length) 339 | assert len(words) == len(marks) == len(pos_taggings) == len(index_words) == len(pos_candidate) 340 | assert len(y) == len(x) == len(c) == len(pos_c) == len(pos_tag) 341 | return x, c, one_hot(y, self.label_id, len(self.all_labels)), pos_c, pos_tag 342 | 343 | def next_valid_data(self): 344 | batch_instances = self.valid_instances 345 | pos_tag, y, x, c, pos_c = [list() for _ in range(5)] 346 | 347 | for instance in batch_instances: 348 | words = instance['words'] 349 | pos_taggings = 
instance['pos_taggings'] 350 | marks = instance['marks'] 351 | label = instance['label'] 352 | 353 | index_candidates = find_candidates(marks, ['B']) 354 | assert (len(index_candidates)) == 1 355 | 356 | y.append(label) 357 | marks = marks + ['A'] * (self.max_sequence_length - len(marks)) 358 | words = words + [''] * (self.max_sequence_length - len(words)) 359 | pos_taggings = pos_taggings + ['*'] * (self.max_sequence_length - len(pos_taggings)) 360 | pos_taggings = list(map(lambda x: self.pos_taggings_id[x], pos_taggings)) 361 | pos_tag.append(pos_taggings) 362 | index_words = list(map(lambda x: self.word_id[x], words)) 363 | x.append(index_words) 364 | pos_candidate = [i for i in range(-index_candidates[0], 0)] + [i for i in range(0, 365 | self.max_sequence_length - 366 | index_candidates[0])] 367 | pos_c.append(pos_candidate) 368 | c.append([index_words[index_candidates[0]]] * self.max_sequence_length) 369 | assert len(words) == len(marks) == len(pos_taggings) == len(index_words) == len(pos_candidate) 370 | assert len(y) == len(x) == len(c) == len(pos_c) == len(pos_tag) 371 | return x, c, one_hot(y, self.label_id, len(self.all_labels)), pos_c, pos_tag 372 | 373 | 374 | if __name__ == '__main__': 375 | D = Dataset_Trigger() 376 | a = D.next_train_data() 377 | -------------------------------------------------------------------------------- /Eval_Trigger.py: -------------------------------------------------------------------------------- 1 | import datetime, os, time 2 | import numpy as np 3 | import tensorflow as tf 4 | from Dataset_Trigger import Dataset_Trigger as TRIGGER_DATASET 5 | from Config import HyperParams_Tri_classification as hp 6 | import nltk 7 | 8 | def get_batch(sentence, word_id, max_sequence_length): 9 | tokens = [word for word in nltk.word_tokenize(sentence)] 10 | 11 | 12 | words = [] 13 | for i in range(max_sequence_length): 14 | if i < len(tokens): 15 | words.append(tokens[i]) 16 | else: 17 | words.append('') 18 | 19 | word_ids = [] 20 | for word in words: 21 | if word in word_id: 22 | word_ids.append(word_id[word]) 23 | else: 24 | word_ids.append(word_id['']) 25 | 26 | # print('word_ids :', word_ids) 27 | size = len(word_ids) 28 | 29 | x_batch = [] 30 | x_pos_batch = [] 31 | for i in range(size): 32 | x_batch.append(word_ids) 33 | x_pos_batch.append([j - i for j in range(size)]) 34 | 35 | return x_batch, x_pos_batch, tokens 36 | 37 | if __name__ == '__main__': 38 | dataset = TRIGGER_DATASET(batch_size=hp.batch_size, max_sequence_length=hp.max_sequence_length, 39 | windows=hp.windows, dtype='IDENTIFICATION') 40 | 41 | x_batch, x_pos_batch, token = get_batch(sentence = 'It could swell to as much as $500 billion if we go to war in Iraq', 42 | word_id = dataset.word_id, max_sequence_length=hp.max_sequence_length) 43 | 44 | print('x_batch :', x_batch) 45 | print('x_pos_batch :', x_pos_batch) 46 | 47 | checkpoint_dir = './runs/1542831140/checkpoints' 48 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 49 | 50 | graph = tf.Graph() 51 | with graph.as_default(): 52 | sess = tf.Session() 53 | with sess.as_default(): 54 | # Load the saved meta graph and restore variables 55 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 56 | saver.restore(sess, checkpoint_file) 57 | 58 | # Get the placeholders from the graph by name 59 | input_x = graph.get_operation_by_name("input_x").outputs[0] 60 | input_c_pos = graph.get_operation_by_name("input_c_pos").outputs[0] 61 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 62 | 
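# These graph-lookup names ("input_x", "input_c_pos", "dropout_keep_prob",
# "output/predicts") must match the name= arguments used when the graph was
# built in Model_Trigger.py; if the model definition changes, update them here too.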
63 | # Tensors we want to evaluate 64 | predictions = graph.get_operation_by_name("output/predicts").outputs[0] 65 | 66 | feed_dict = { 67 | input_x: x_batch, 68 | input_c_pos: x_pos_batch, 69 | dropout_keep_prob: 1.0, 70 | } 71 | 72 | preds = sess.run(predictions, feed_dict) 73 | print('result!') 74 | for i in range(len(preds)): 75 | print('{}: {}'.format(dataset.id2word[x_batch[0][i]], preds[i])) 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 NLP*CL 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | import time, datetime, os 2 | import tensorflow as tf 3 | from Dataset import Dataset 4 | import numpy as np 5 | 6 | from Config import HyperParams as hp 7 | 8 | """ 9 | Original taken from https://github.com/zhangluoyang/cnn-for-auto-event-extract 10 | """ 11 | 12 | 13 | class Model(): 14 | def __init__(self, 15 | sentence_length=30, 16 | num_labels=10, 17 | windows=3, 18 | vocab_size=2048, 19 | word_embedding_size=100, 20 | pos_embedding_size=10, 21 | filter_sizes=[3, 4, 5], 22 | filter_num=200, 23 | embed_matrx=None 24 | ): 25 | 26 | tf_version_checker = int(tf.__version__.split('.')[0]) 27 | 28 | """ 29 | :param sentence_length 30 | :param num_labels 31 | :param windows 32 | :param vocab_size 33 | :param word_embedding_size 34 | :param pos_embedding_size 35 | :param filter_sizes 36 | :param filter_num 37 | """ 38 | input_x = tf.placeholder(tf.int32, shape=[None, sentence_length], name="input_x") 39 | self.input_x = input_x 40 | input_y = tf.placeholder(tf.float32, shape=[None, num_labels], name="input_y") 41 | self.input_y = input_y 42 | # trigger distance vector 43 | input_t_pos = tf.placeholder(tf.int32, shape=[None, sentence_length], name="input_t_pos") 44 | self.input_t_pos = input_t_pos 45 | # argument candidates distance vector 46 | input_c_pos = tf.placeholder(tf.int32, shape=[None, sentence_length], name="input_c_pos") 47 | self.input_c_pos = input_c_pos 48 | dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 49 | self.dropout_keep_prob = dropout_keep_prob 50 | #with tf.device('/cpu:0'), tf.name_scope("word_embedding_layer"): 51 | with tf.name_scope("word_embedding_layer"): 52 | W_text = tf.Variable(tf.random_normal(shape=[vocab_size, word_embedding_size], mean=0.0, stddev=0.5), name="word_table") 53 | 54 | input_word_vec = tf.nn.embedding_lookup(W_text, input_x) 55 | input_t_pos_t = input_t_pos + (sentence_length - 1) 56 | Tri_pos = tf.Variable( 57 | tf.random_normal(shape=[2 * (sentence_length - 1) + 1, pos_embedding_size], mean=0.0, stddev=0.5), 58 | name="tri_pos_table") 59 | input_t_pos_vec = tf.nn.embedding_lookup(Tri_pos, input_t_pos_t) 60 | input_c_pos_c = input_c_pos + (sentence_length - 1) 61 | Can_pos = tf.Variable( 62 | tf.random_normal(shape=[2 * (sentence_length - 1) + 1, pos_embedding_size], mean=0.0, stddev=0.5), 63 | name="candidate_pos_table") 64 | input_c_pos_vec = tf.nn.embedding_lookup(Can_pos, input_c_pos_c) 65 | # [batch_size, sentence_length, word_embedding_size+2*pos_size] 66 | if tf_version_checker >= 1: 67 | input_sentence_vec = tf.concat([input_word_vec, input_t_pos_vec, input_c_pos_vec],2) 68 | else: 69 | input_sentence_vec = tf.concat(2, [input_word_vec, input_t_pos_vec, input_c_pos_vec]) 70 | # CNN supports 4d input, so increase the one-dimensional vector to indicate the number of input channels. 
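# shape: [batch_size, sentence_length, word_embedding_size + 2*pos_embedding_size]
#     -> [batch_size, sentence_length, word_embedding_size + 2*pos_embedding_size, 1],
# where the trailing 1 is the single input channel expected by tf.nn.conv2d below.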
71 | input_sentence_vec_expanded = tf.expand_dims(input_sentence_vec, -1) 72 | pooled_outputs = [] 73 | for i, filter_size in enumerate(filter_sizes): 74 | with tf.name_scope('conv-maxpool-%s' % filter_size): 75 | # The current word and context of the sentence feature considered here 76 | filter_shape = [filter_size, word_embedding_size + 2 * pos_embedding_size, 1, filter_num] 77 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W") 78 | b = tf.Variable(tf.constant(0.1, shape=[filter_num]), name="b") 79 | conv = tf.nn.conv2d( 80 | input_sentence_vec_expanded, 81 | W, 82 | strides=[1, 1, 1, 1], 83 | padding="VALID", 84 | name="conv") 85 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 86 | pooled = tf.nn.max_pool( 87 | h, 88 | ksize=[1, sentence_length - filter_size + 1, 1, 1], 89 | strides=[1, 1, 1, 1], 90 | padding='VALID', 91 | name="pool") 92 | pooled_outputs.append(pooled) 93 | 94 | num_filters_total = filter_num * len(filter_sizes) 95 | if tf_version_checker >= 1: 96 | h_pool = tf.concat(pooled_outputs, 3) 97 | else: 98 | h_pool = tf.concat(3, pooled_outputs) 99 | h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total]) 100 | lexical_vec = tf.reshape(input_word_vec, shape=(-1, sentence_length * word_embedding_size)) 101 | # Combine lexical level features and sentence level features 102 | if tf_version_checker >= 1: 103 | all_input_features = tf.concat([lexical_vec, h_pool_flat], 1) 104 | else: 105 | all_input_features = tf.concat(1, [lexical_vec, h_pool_flat]) 106 | 107 | #with tf.device('/cpu:0'), tf.name_scope('dropout'): 108 | with tf.name_scope('dropout'): 109 | all_features = tf.nn.dropout(all_input_features, dropout_keep_prob) 110 | 111 | #with tf.device('/cpu:0'), tf.name_scope('softmax'): 112 | with tf.name_scope('softmax'): 113 | W = tf.Variable(tf.truncated_normal([num_filters_total + sentence_length * word_embedding_size, num_labels], stddev=0.1), name="W") 114 | b = tf.Variable(tf.constant(0.1, shape=[num_labels]), name="b") 115 | scores = tf.nn.xw_plus_b(all_features, W, b, name="scores") 116 | predicts = tf.arg_max(scores, dimension=1, name="predicts") 117 | self.scores = scores 118 | self.predicts = predicts 119 | 120 | #with tf.device('/cpu:0'), tf.name_scope('loss'): 121 | with tf.name_scope('loss'): 122 | if tf_version_checker >= 1: 123 | entropy = tf.nn.softmax_cross_entropy_with_logits(labels=input_y, logits=scores) 124 | else: 125 | entropy = tf.nn.softmax_cross_entropy_with_logits(scores, input_y) 126 | loss = tf.reduce_mean(entropy) 127 | self.loss = loss 128 | 129 | #with tf.device('/cpu:0'), tf.name_scope("accuracy"): 130 | with tf.name_scope("accuracy"): 131 | correct = tf.equal(predicts, tf.argmax(input_y, 1)) 132 | accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy") 133 | self.accuracy = accuracy 134 | 135 | -------------------------------------------------------------------------------- /Model_Trigger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time, datetime, os 3 | import tensorflow as tf 4 | from sklearn.metrics import classification_report, precision_score, recall_score, accuracy_score 5 | from Dataset_Trigger import Dataset_Trigger as Dataset 6 | 7 | from Config import HyperParams_Tri_classification as hp 8 | 9 | """ 10 | Trigger Classification is based on the previous argument classification task's code. 
11 | Reference code link is https://github.com/zhangluoyang/cnn-for-auto-event-extract 12 | """ 13 | 14 | 15 | class Model(): 16 | def __init__(self, 17 | sentence_length=30, 18 | num_labels=10, 19 | windows=3, 20 | vocab_size=2048, 21 | pos_tag_max_size = 60, 22 | word_embedding_size=100, 23 | pos_embedding_size=10, 24 | filter_sizes=[3, 4, 5], 25 | filter_num=200, 26 | batch_size=10, 27 | embed_matrx=None 28 | ): 29 | 30 | tf_version_checker = int(tf.__version__.split('.')[0]) 31 | 32 | """ 33 | :param sentence_length 34 | :param num_labels 35 | :param windows 36 | :param vocab_size 37 | :param word_embedding_size 38 | :param pos_embedding_size 39 | :param filter_sizes 40 | :param filter_num 41 | """ 42 | 43 | # TODO: Check whether batch size can determined arbitrary in <1.0.0 version. 44 | batch_size = None 45 | # [batch_size, sentence_length] 46 | input_x = tf.placeholder(tf.int32, shape=[batch_size, sentence_length], name="input_x") 47 | self.input_x = input_x 48 | # [batch_size, num_labels] 49 | input_y = tf.placeholder(tf.float32, shape=[batch_size, num_labels], name="input_y") 50 | self.input_y = input_y 51 | 52 | # input_y_weights = tf.placeholder(tf.float32, shape=[batch_size], name="input_y_weights") 53 | 54 | # input_pos_tag = tf.placeholder(tf.int32, shape=[batch_size, sentence_length], name="input_pos_tag") 55 | # self.input_pos_tag = input_pos_tag 56 | 57 | # argument candidates distance vector 58 | # example: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 59 | input_c_pos = tf.placeholder(tf.int32, shape=[batch_size, sentence_length], name="input_c_pos") 60 | self.input_c_pos = input_c_pos 61 | 62 | dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 63 | self.dropout_keep_prob = dropout_keep_prob 64 | 65 | with tf.name_scope("word_embedding_layer"): 66 | # [vocab_size, embedding_size] 67 | 68 | # TODO: Word2Vec lookup table 69 | if embed_matrx is None: # use randomly initialized matrix as word embedding 70 | W_text = tf.Variable(tf.random_normal(shape=[vocab_size, word_embedding_size], mean=0.0, stddev=0.5), name="word_table") 71 | else: # pre-trained word embedding matrix 72 | W_text = tf.Variable(embed_matrx, trainable=False, dtype=tf.float32, name='word_embedding') 73 | input_word_vec = tf.nn.embedding_lookup(W_text, input_x) 74 | 75 | # Pos_tag = tf.Variable( 76 | # tf.random_normal(shape=[pos_tag_max_size, pos_embedding_size], mean=0.0, stddev=0.5), 77 | # name="input_pos_tag_table") 78 | # input_pos_tag_vec = tf.nn.embedding_lookup(Pos_tag, input_pos_tag) 79 | 80 | input_c_pos_c = input_c_pos + (sentence_length - 1) 81 | Can_pos = tf.Variable( 82 | tf.random_normal(shape=[2 * (sentence_length - 1) + 1, pos_embedding_size], mean=0.0, stddev=0.5), 83 | name="candidate_pos_table") 84 | input_c_pos_vec = tf.nn.embedding_lookup(Can_pos, input_c_pos_c) 85 | 86 | # The feature of the distance and the word features of the sentence constitute a collated feature as an input to the convolutional neural network. 87 | # [batch_size, sentence_length, word_embedding_size+2*pos_size] 88 | 89 | # [input_word_vec, input_c_pos_vec, input_pos_tag_vec] 90 | input_sentence_vec = tf.concat([input_word_vec, input_c_pos_vec], 2) 91 | # CNN supports 4d input, so increase the one-dimensional vector to indicate the number of input channels. 
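# shape: [batch_size, sentence_length, word_embedding_size + pos_embedding_size]
#     -> [batch_size, sentence_length, word_embedding_size + pos_embedding_size, 1];
# unlike Model.py, only the candidate-position embedding is concatenated here,
# which is why filter_shape below uses word_embedding_size + pos_embedding_size.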
92 | input_sentence_vec_expanded = tf.expand_dims(input_sentence_vec, -1) 93 | pooled_outputs = [] 94 | for i, filter_size in enumerate(filter_sizes): 95 | with tf.name_scope('conv-maxpool-%s' % filter_size): 96 | # The current word and context of the sentence feature considered here 97 | # when using pos_tag: [filter_size, word_embedding_size + 2 * pos_embedding_size, 1, filter_num] 98 | filter_shape = [filter_size, word_embedding_size + pos_embedding_size, 1, filter_num] 99 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W") 100 | b = tf.Variable(tf.constant(0.1, shape=[filter_num]), name="b") 101 | # Convolution operation 102 | conv = tf.nn.conv2d( 103 | input_sentence_vec_expanded, 104 | W, 105 | strides=[1, 1, 1, 1], 106 | padding="VALID", 107 | name="conv") 108 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 109 | # Maximize pooling 110 | pooled = tf.nn.max_pool( 111 | h, 112 | ksize=[1, sentence_length - filter_size + 1, 1, 1], 113 | strides=[1, 1, 1, 1], 114 | padding='VALID', 115 | name="pool") 116 | pooled_outputs.append(pooled) 117 | 118 | num_filters_total = filter_num * len(filter_sizes) 119 | # The number of all filters used (number of channels output) 120 | h_pool = tf.concat(pooled_outputs, 3) 121 | # Expand to the next level classifier 122 | h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total]) 123 | lexical_vec = tf.reshape(input_word_vec, shape=(-1, sentence_length * word_embedding_size)) 124 | # Combine lexical level features and sentence level features 125 | # [batch_size, num_filters_total] + [batch_size, sentence_length*word_embedding_size] 126 | all_input_features = tf.concat([lexical_vec, h_pool_flat], 1) 127 | # The overall classifier goes through a layer of dropout and then into softmax 128 | with tf.device('/cpu:0'), tf.name_scope('dropout'): 129 | all_features = tf.nn.dropout(all_input_features, dropout_keep_prob) 130 | 131 | with tf.name_scope('output'): 132 | W = tf.Variable(tf.truncated_normal([num_filters_total + sentence_length * word_embedding_size, num_labels], stddev=0.1), name="W") 133 | b = tf.Variable(tf.constant(0.1, shape=[num_labels]), name="b") 134 | scores = tf.nn.xw_plus_b(all_features, W, b, name="scores") 135 | predicts = tf.arg_max(scores, dimension=1, name="predicts") 136 | self.scores = scores 137 | self.predicts = predicts 138 | 139 | with tf.name_scope('loss'): 140 | entropy = tf.nn.softmax_cross_entropy_with_logits(labels=input_y, logits=scores) 141 | loss = tf.reduce_mean(entropy) 142 | self.loss = loss 143 | 144 | with tf.name_scope("accuracy"): 145 | correct = tf.equal(predicts, tf.argmax(input_y, 1)) 146 | accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy") 147 | self.accuracy = accuracy 148 | -------------------------------------------------------------------------------- /Preprocess.py: -------------------------------------------------------------------------------- 1 | from string import ascii_letters, digits 2 | import os 3 | import xml.etree.ElementTree as ET 4 | import pickle 5 | from Config import MyConfig, HyperParams_Tri_classification as hp_f 6 | import pprint 7 | from bs4 import BeautifulSoup 8 | import json 9 | 10 | pp = pprint.PrettyPrinter(indent=4) 11 | 12 | 13 | class PreprocessManager(): 14 | def __init__(self): 15 | self.dir_list = MyConfig.raw_dir_list 16 | self.dir_path = MyConfig.raw_data_path 17 | self.dataset = [] 18 | self.tri_task_format_data = [] 19 | self.arg_task_format_data = [] 20 | 21 | def preprocess(self, tasktype, subtasktype): 22 | ''' 23 | 
Overall Iterator for whole dataset 24 | ''' 25 | fnames = self.fname_search() 26 | print('Total XML file: {}'.format(len(fnames))) 27 | total_res = [] 28 | for fname in fnames: 29 | total_res.append(self.process_one_file(fname)) 30 | print('total_event: {}개'.format(len(total_res))) 31 | 32 | for doc in total_res: 33 | self.dataset += self.process_sentencewise(doc) 34 | print("END PREPROCESSING") 35 | print('TOTAL DATA : {}'.format(len(self.dataset))) 36 | if tasktype=='TRIGGER': 37 | self.format_to_trigger(subtasktype) 38 | elif tasktype=='ARGUMENT': 39 | self.format_to_argument(subtasktype) 40 | else: 41 | raise ValueError 42 | 43 | print('TRIGGER DATASET: {}\nARGUMENT DATASET: {}\n'.format(len(self.tri_task_format_data), 44 | len(self.arg_task_format_data))) 45 | 46 | def format_to_trigger(self, subtasktype): 47 | for item in self.dataset: 48 | d = item[0] 49 | fname = item[1] 50 | generated_candi = self.generate_trigger_candidate_pos_list(d['trigger_position'], d['entity_position'], subtasktype) 51 | if len(d['sentence'])>hp_f.max_sequence_length:continue 52 | for candi in generated_candi: 53 | # Whether except the 'None' label at classification 54 | if subtasktype == 'CLASSIFICATION' and candi[1] == 'None': continue 55 | self.tri_task_format_data.append([d['sentence']]+candi+[fname]+[d['entity_position']]) 56 | 57 | def generate_trigger_candidate_pos_list(self, trigger_pos, entity_pos, subtasktype): 58 | cand_list = [] 59 | idx_list = [] 60 | for idx,el in enumerate(trigger_pos): 61 | if el!='*': idx_list.append((idx,el)) 62 | 63 | assert len(entity_pos)==len(trigger_pos) 64 | 65 | for idx in range(len(trigger_pos)): 66 | marks = ['A' for i in range(len(trigger_pos))] 67 | marks[idx]='B' 68 | label = 'None' 69 | for i in idx_list: 70 | if idx == i[0]: 71 | label = i[1] if subtasktype=='CLASSIFICATION' else 'TRIGGER' # else: Identification case 72 | cand_list.append([marks,label]) 73 | return cand_list 74 | 75 | def process_sentencewise(self, doc): 76 | entities, val_timexs, events, xml_fname = doc 77 | datas = [] 78 | for event in events: 79 | for e_mention in event['event_mention']: 80 | tmp = {'TYPE': event['TYPE'], 'SUBTYPE': event['SUBTYPE']} 81 | tmp['raw_sent'] = e_mention['ldc_scope']['text'] 82 | sent_pos = [int(i) for i in e_mention['ldc_scope']['position']] 83 | entities_in_sent = self.search_entity_in_sentence(entities, sent_pos) 84 | val_timexs_in_sent = self.search_valtimex_in_sentence(val_timexs, sent_pos) 85 | e_mention = self.get_argument_head(entities_in_sent, e_mention) 86 | res = self.packing_sentence(e_mention, tmp, sent_pos, entities_in_sent, val_timexs_in_sent) 87 | if res!=1: datas.append([res,xml_fname]) 88 | return datas 89 | 90 | def packing_sentence(self, e_mention, tmp, sent_pos, entities, valtimexes): 91 | packed_data = { 92 | 'sentence': [], 93 | 'EVENT_TYPE' : tmp['TYPE'], 94 | 'EVENT_SUBTYPE' : tmp['SUBTYPE'], 95 | 'entity_position' : [], 96 | } 97 | 98 | # Each Entity, value, timex2 overlap check 99 | assert self.check_entity_overlap(entities, valtimexes) 100 | raw_sent = e_mention['ldc_scope']['text'] 101 | 102 | idx_list = [0 for i in range(len(raw_sent))] 103 | if not (len(idx_list) == (int(e_mention['ldc_scope']['position'][1])-int(e_mention['ldc_scope']['position'][0])+1)): 104 | return 1 105 | sent_start_idx = int(e_mention['ldc_scope']['position'][0]) 106 | 107 | trigger_idx_list = [0 for i in range(len(raw_sent))] 108 | # pp.pprint(e_mention['anchor']) 109 | # input() 110 | # 111 | # for tri in e_mention['anchor']: 112 | # 113 | 114 | # Mark 
Entity position 115 | for ent in entities: 116 | ent_start_idx = int(ent['head']['position'][0]) 117 | for i in range(int(ent['head']['position'][1]) - int(ent['head']['position'][0]) + 1): 118 | if idx_list[ent_start_idx + i - sent_start_idx]==1: raise ValueError('까율~~~~~~~~~~~~~~~~~~') 119 | idx_list[ent_start_idx + i - sent_start_idx] = 1 # entity mark 120 | 121 | dupl_exist = False 122 | # Mark Value&Timex2 position 123 | for val in valtimexes: 124 | ent_start_idx = int(val['position'][0]) 125 | for i in range(int(val['position'][1]) - int(val['position'][0]) + 1): 126 | if idx_list[ent_start_idx + i - sent_start_idx] == 1: # entity mark 127 | dupl_exist = True 128 | if not dupl_exist: 129 | for val in valtimexes: 130 | ent_start_idx = int(val['position'][0]) 131 | for i in range(int(val['position'][1]) - int(val['position'][0]) + 1): 132 | idx_list[ent_start_idx + i - sent_start_idx] = 1 # entity mark 133 | 134 | token_list = [] 135 | entity_mark_list = [] 136 | curr_token = '' 137 | # TODO: save each mark as variable, not to type 'N' or 'E' each time. 138 | for idx, el in enumerate(raw_sent): 139 | if idx==0: 140 | curr_token += el 141 | continue 142 | if idx_list[idx]!=idx_list[idx-1]: 143 | if idx_list[idx-1]==1: entity_mark_list.append('E') 144 | else: entity_mark_list.append('*') 145 | token_list.append(curr_token) 146 | curr_token = el 147 | continue 148 | curr_token += el 149 | if idx == len(e_mention['ldc_scope']['text'])-1: 150 | if idx_list[idx]==1: entity_mark_list.append('E') 151 | else: entity_mark_list.append('*') 152 | token_list.append(curr_token) 153 | 154 | assert len(token_list)==len(entity_mark_list) 155 | splitted_token_list = [] # TODO: The better name.... 156 | splitted_entity_mark_list = [] 157 | 158 | for tok, mark in zip(token_list, entity_mark_list): 159 | if mark == '*': 160 | splitted_tok = tok.split() 161 | splitted_token_list += splitted_tok 162 | splitted_entity_mark_list += ['*' for i in range(len(splitted_tok))] 163 | if mark == 'E': 164 | splitted_token_list.append(tok) 165 | splitted_entity_mark_list.append('E') 166 | assert len(splitted_entity_mark_list)==len(splitted_token_list) 167 | 168 | # Arguement Mark 169 | argument_role_label = ['*' for i in range(len(splitted_entity_mark_list))] 170 | for arg in e_mention['argument']: 171 | if 'text_head' in arg: 172 | arg_text,arg_role = arg['text_head'],arg['ROLE'] 173 | else: 174 | arg_text,arg_role = arg['text'],arg['ROLE'] 175 | # TODO: Move this part to up 176 | arg_idx = None 177 | if arg_text not in splitted_token_list: 178 | for idx,el in enumerate(splitted_token_list): 179 | if arg_text in el: 180 | arg_idx = idx 181 | break 182 | else: 183 | arg_idx = splitted_token_list.index(arg_text) 184 | if arg_idx==None: 185 | # print('Exception') 186 | return 1 187 | argument_role_label[arg_idx] = arg_role 188 | 189 | assert len(splitted_entity_mark_list)==len(splitted_token_list) 190 | 191 | trigger_by_multi_w = False 192 | trigger_idx = None 193 | if e_mention['anchor']['text'] in splitted_token_list: 194 | trigger_idx = [splitted_token_list.index(e_mention['anchor']['text'])] 195 | else: 196 | for idx,tok in enumerate(splitted_token_list): 197 | if e_mention['anchor']['text'] in tok: 198 | if len(e_mention['anchor']['text'].split())>=2: continue 199 | trigger_idx = [idx] 200 | splitted_token_list[idx] = e_mention['anchor']['text'] 201 | 202 | if trigger_idx == None: # multiple trigger like 'blew him up' 203 | triggers = e_mention['anchor']['text'].split() 204 | if len(triggers)==1: 205 | print('##', 
triggers) 206 | return 1 207 | trigger_idx = [] 208 | first_tword = triggers[0] 209 | second_tword = triggers[1] 210 | for tok_idx,tok in enumerate(splitted_token_list): 211 | if first_tword in tok: 212 | if tok_idx!=len(splitted_token_list)-1 and second_tword in splitted_token_list[tok_idx+1]: 213 | for i in range(len(triggers)): 214 | trigger_idx.append(tok_idx+i) 215 | trigger_by_multi_w = True 216 | 217 | if trigger_idx in [None,[]]: 218 | print(123) 219 | return 1 220 | 221 | # Trigger by multiple word as one entity 222 | if trigger_by_multi_w: 223 | new_splited_token_list, new_argument_role_label, new_splited_entity_mark_list = [], [], [] 224 | first_trigger_idx = trigger_idx[0] 225 | for idx,tok in enumerate(splitted_token_list): 226 | if idx in trigger_idx: 227 | if idx==first_trigger_idx: 228 | new_splited_token_list.append(tok) 229 | new_argument_role_label.append(argument_role_label[idx]) 230 | new_splited_entity_mark_list.append(splitted_entity_mark_list[idx]) 231 | else: 232 | new_splited_token_list[-1] += ' '+tok 233 | else: 234 | new_splited_token_list.append(tok) 235 | new_argument_role_label.append(argument_role_label[idx]) 236 | new_splited_entity_mark_list.append(splitted_entity_mark_list[idx]) 237 | 238 | assert len(splitted_token_list) == (len(new_splited_token_list) + len(trigger_idx) - 1) 239 | 240 | splitted_token_list = new_splited_token_list 241 | argument_role_label = new_argument_role_label 242 | splitted_entity_mark_list = new_splited_entity_mark_list 243 | trigger_idx = [first_trigger_idx] 244 | 245 | trigger_type_label = ['*' for i in range(len(splitted_entity_mark_list))] 246 | 247 | for el in trigger_idx: 248 | trigger_type_label[el] = tmp['TYPE']# + '/' + tmp['SUBTYPE'] 249 | for idx, tok in enumerate(splitted_token_list): splitted_token_list[idx] = tok.strip() 250 | 251 | for idx, tok in enumerate(splitted_token_list): 252 | if len(tok) >= 2 and self.is_tail_symbol_only_check(tok): 253 | splitted_token_list[idx] = tok[:-1] 254 | 255 | assert len(splitted_entity_mark_list)==len(splitted_token_list)==len(trigger_type_label)==len(argument_role_label) 256 | 257 | 258 | packed_data['sentence'] = splitted_token_list 259 | packed_data['trigger_position'] = trigger_type_label 260 | packed_data['entity_position'] = splitted_entity_mark_list 261 | packed_data['argument_position'] = argument_role_label 262 | 263 | return packed_data 264 | 265 | @staticmethod 266 | def is_tail_symbol_only_check(str): 267 | if str[-1] in ascii_letters+digits: return False 268 | for c in str[:-1]: 269 | if c not in ascii_letters+digits: return False 270 | return True 271 | 272 | @staticmethod 273 | def check_entity_overlap(entities, valtimexes): 274 | ranges = [] 275 | # TODO: Implement this later 276 | for ent in entities: 277 | ranges.append(None) 278 | return True 279 | 280 | @staticmethod 281 | def search_entity_in_sentence(entities, sent_pos): 282 | headVSextent = 'head' 283 | entities_in_sent = list() 284 | check = dict() 285 | for entity in entities: 286 | for mention in entity['mention']: 287 | if sent_pos[0] <= int(mention[headVSextent]['position'][0]) and int(mention[headVSextent]['position'][1]) <= sent_pos[1]: 288 | if mention[headVSextent]['position'][0] in check: # duplicate entity in one word. 
289 | #print('Duplicate entity position!') 290 | #raise ValueError 291 | continue 292 | check[mention[headVSextent]['position'][0]] = 1 293 | entities_in_sent.append(mention) 294 | return entities_in_sent 295 | 296 | @staticmethod 297 | def search_valtimex_in_sentence(valtimex, sent_pos): 298 | valtimex_in_sent = list() 299 | for item in valtimex: 300 | for mention in item['mention']: 301 | if sent_pos[0] <= int(mention['position'][0]) and sent_pos[1] >= int(mention['position'][1]): 302 | valtimex_in_sent.append(mention) 303 | return valtimex_in_sent 304 | 305 | def format_to_argument(self, subtasktype): 306 | for item in self.dataset: 307 | d = item[0] 308 | fname = item[1] 309 | generated_candi = self.generate_argument_candidate_pos_list(d['argument_position'], d['entity_position'], 310 | d['trigger_position'], subtasktype) 311 | 312 | if len(d['sentence'])>80:continue 313 | 314 | trigger_cnt = 0 315 | for m in d['trigger_position']: 316 | if m=='T':trigger_cnt+=1 317 | if trigger_cnt>1:continue 318 | 319 | for candi in generated_candi: 320 | self.arg_task_format_data.append([d['sentence']]+candi+[fname]) 321 | 322 | def generate_argument_candidate_pos_list(self, arg_pos, enti_pos, trigger_pos, subtasktype): 323 | cand_list = [] 324 | Entity_as_candidate_only = True # use only entities as argument candidates 325 | 326 | assert len(arg_pos)==len(enti_pos)==len(trigger_pos) 327 | for idx,el in enumerate(arg_pos): 328 | if Entity_as_candidate_only: 329 | if enti_pos[idx]!='E': continue 330 | if trigger_pos[idx]!='*': continue 331 | 332 | tri_idx_list = [] 333 | for j,a in enumerate(trigger_pos): 334 | if a != '*': tri_idx_list.append(j) 335 | 336 | marks = ['A' for i in range(len(arg_pos))] 337 | marks[idx]='B' 338 | for i in tri_idx_list: 339 | marks[i]='T' 340 | label = 'None' if arg_pos[idx]=='*' else arg_pos[idx] 341 | 342 | ''' 343 | Map the fine-grained time roles (Time-After, Time-At-End, Time-Ending, Time-Holds, Time-At-Beginning, Time-Before, Time-Within, Time-Starting) to the single label 'Time'. 344 | ''' 345 | if 'Time-' in label: label = 'Time' 346 | if subtasktype=='IDENTIFICATION' and label!='None':label = 'ARGUMENT' 347 | cand_list.append([marks,label]) 348 | return cand_list 349 | 350 | @staticmethod 351 | def get_argument_head(entities, e_mention): 352 | for idx, arg in enumerate(e_mention['argument']): 353 | arg_refID = arg['REFID'] 354 | for entity in entities: 355 | if entity['ID'] == arg_refID: 356 | e_mention['argument'][idx]['position_head'] = entity['head']['position'] 357 | e_mention['argument'][idx]['text_head'] = entity['head']['text'] 358 | return e_mention 359 | 360 | def fname_search(self): 361 | ''' 362 | Search dataset directory & Return list of (sgm fname, apf.xml fname) 363 | ''' 364 | fname_list = list() 365 | for dir in self.dir_list: 366 | # To exclude hidden files 367 | if len(dir) and dir[0] == '.': continue 368 | full_path = self.dir_path.format(dir) 369 | flist = os.listdir(full_path) 370 | for fname in flist: 371 | if '.sgm' not in fname: continue 372 | raw = fname.split('.sgm')[0] 373 | fname_list.append((self.dir_path.format(dir) + raw + '.sgm', self.dir_path.format(dir) + raw + '.apf.xml')) 374 | return fname_list 375 | 376 | def process_one_file(self, fname): 377 | # args fname = (sgm fname(full path), xml fname(full path)) 378 | # returns (entities, values/timex2, event mentions (trigger + argument information), xml fname) for one document 379 | xml_ent_res, xml_valtimex_res, xml_event_res = self.parse_one_xml(fname[1]) 380 | # sgm_ent_res, sgm_event_res = self.parse_one_sgm(fname[0]) 381 | # TODO : merge xml and sgm 
file together if need. 382 | return xml_ent_res, xml_valtimex_res, xml_event_res, fname[1] 383 | 384 | def parse_one_xml(self, fname): 385 | tree = ET.parse(fname) 386 | root = tree.getroot() 387 | entities, val_timex, events = [], [], [] 388 | 389 | for child in root[0]: 390 | if child.tag == 'entity': 391 | entities.append(self.xml_entity_parse(child, fname)) 392 | if child.tag in ['value', 'timex2']: 393 | val_timex.append(self.xml_value_timex_parse(child, fname)) 394 | if child.tag == 'event': 395 | events.append(self.xml_event_parse(child, fname)) 396 | return entities, val_timex, events 397 | 398 | def xml_value_timex_parse(self, item, fname): 399 | child = item.attrib 400 | child['fname'] = fname 401 | child['mention'] = [] 402 | for sub in item: 403 | mention = sub.attrib 404 | mention['position'] = [sub[0][0].attrib['START'], sub[0][0].attrib['END']] 405 | mention['text'] = sub[0][0].text 406 | child['mention'].append(mention) 407 | return child 408 | 409 | def xml_entity_parse(self, item, fname): 410 | entity = item.attrib 411 | entity['fname'] = fname 412 | entity['mention'] = [] 413 | entity['attribute'] = [] # What is this exactly? 414 | for sub in item: 415 | if sub.tag != 'entity_mention': continue 416 | mention = sub.attrib 417 | for el in sub: # charseq and head 418 | mention[el.tag] = dict() 419 | mention[el.tag]['position'] = [el[0].attrib['START'], el[0].attrib['END']] 420 | mention[el.tag]['text'] = el[0].text 421 | entity['mention'].append(mention) 422 | return entity 423 | 424 | def xml_event_parse(self, item, fname): 425 | # event: one event item 426 | event = item.attrib 427 | event['fname'] = fname 428 | event['argument'] = [] 429 | event['event_mention'] = [] 430 | for sub in item: 431 | if sub.tag == 'event_argument': 432 | tmp = sub.attrib 433 | event['argument'].append(tmp) 434 | continue 435 | if sub.tag == 'event_mention': 436 | mention = sub.attrib # init dict with mention ID 437 | mention['argument'] = [] 438 | for el in sub: 439 | if el.tag == 'event_mention_argument': 440 | one_arg = el.attrib 441 | one_arg['position'] = [el[0][0].attrib['START'], el[0][0].attrib['END']] 442 | one_arg['text'] = el[0][0].text 443 | mention['argument'].append(one_arg) 444 | else: # [extent, ldc_scope, anchor] case 445 | for seq in el: 446 | mention[el.tag] = dict() 447 | mention[el.tag]['position'] = [seq.attrib['START'], seq.attrib['END']] 448 | mention[el.tag]['text'] = seq.text 449 | event['event_mention'].append(mention) 450 | return event 451 | 452 | def parse_one_sgm(self, fname): 453 | print('fname :', fname) 454 | with open(fname, 'r') as f: 455 | data = f.read() 456 | soup = BeautifulSoup(data, features='html.parser') 457 | 458 | doc = soup.find('doc') 459 | doc_id = doc.docid.text 460 | doc_type = doc.doctype.text.strip() 461 | date_time = doc.datetime.text 462 | headline = doc.headline.text if doc.headline else '' 463 | 464 | body = [] 465 | 466 | if doc_type == 'WEB TEXT': 467 | posts = soup.findAll('post') 468 | for post in posts: 469 | poster = post.poster.text 470 | post.poster.extract() 471 | post_date = post.postdate.text 472 | post.postdate.extract() 473 | subject = post.subject.text if post.subject else '' 474 | if post.subject: post.subject.extract() 475 | text = post.text 476 | body.append({ 477 | 'poster': poster, 478 | 'post_date': post_date, 479 | 'subject': subject, 480 | 'text': text, 481 | }) 482 | elif doc_type in ['STORY', 'CONVERSATION', 'NEWS STORY']: 483 | turns = soup.findAll('turn') 484 | for turn in turns: 485 | speaker = 
turn.speaker.text if turn.speaker else '' 486 | if turn.speaker: turn.speaker.extract() 487 | text = turn.text 488 | body.append({ 489 | 'speaker': speaker, 490 | 'text': text, 491 | }) 492 | 493 | result = { 494 | 'doc_id': doc_id, 495 | 'doc_type': doc_type, 496 | 'date_time': date_time, 497 | 'headline': headline, 498 | 'body': body, 499 | } 500 | 501 | return result 502 | 503 | def Data2Json(self, data): 504 | pass 505 | 506 | def next_train_data(self): 507 | pass 508 | 509 | def eval_data(self): 510 | pass 511 | 512 | 513 | if __name__ == '__main__': 514 | man = PreprocessManager() 515 | man.preprocess() 516 | 517 | # Example 518 | trigger_classification_data = man.tri_task_format_data 519 | argument_classification_data = man.arg_task_format_data 520 | # print('\n\n') 521 | 522 | 523 | all_labels = set() 524 | total = 0 525 | for data in argument_classification_data: 526 | total += 1 527 | all_labels.add(data[2]) 528 | 529 | print('total :', total) 530 | print('label len:', len(all_labels)) 531 | 532 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Event Extraction 2 | 3 | Tensorflow implementation of a deep learning approach for Event Extraction ([**ACE 2005**](https://catalog.ldc.upenn.edu/LDC2006T06)) via Dynamic Multi-Pooling Convolutional Neural Networks. 4 | 5 | # Requirements 6 | 7 | * Tensorflow 8 | * Scikit-learn 9 | * NLTK 10 | 11 | `pip install -r requirements.txt` may help. 12 | 13 | ## Usage 14 | 15 | ### Train 16 | 17 | * "[GoogleNews-vectors-negative300](https://code.google.com/archive/p/word2vec/)" is used as a pre-trained word2vec model. 18 | 19 | * "[glove.6B](https://nlp.stanford.edu/projects/glove/)" is used as a pre-trained GloVe model. 20 | 21 | * Performance (accuracy and f1-score) outputs during training are **UNOFFICIAL SCORES** of *ACE 2005*. 22 | 23 | ##### Train Example: 24 | ```bash 25 | $ python Script.py {taskID} {subtaskID} 26 | ``` 27 | * `taskID`: 1 for Trigger, 2 for Argument 28 | 29 | * `subtaskID`: 1 for Identification, 2 for Classification 30 | 31 | * After model training, evaluation results will be shown. The full set of task/subtask commands is sketched below. 
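For reference, the four `taskID`/`subtaskID` combinations map to the commands below (a usage sketch; it assumes the ACE 2005 data and GloVe files are laid out as `Config.py` expects):

```bash
$ python Script.py 1 1   # Trigger Identification
$ python Script.py 1 2   # Trigger Classification
$ python Script.py 2 1   # Argument Identification
$ python Script.py 2 2   # Argument Classification
```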
32 | 33 | ```bash 34 | $ python Script.py 1 2 # Script for `Trigger Classification` 35 | ``` 36 | 37 | ## Todo 38 | 39 | - Apply Dymamic Multi-Pooling CNN 40 | - Evaluation Script 41 | 42 | ## Results 43 | 44 | ### Trigger identification performance 45 | ``` 46 | precision recall f1-score support 47 | 48 | TRIGGER 0.59 0.44 0.50 527 49 | None 0.97 0.98 0.98 9151 50 | 51 | micro avg 0.95 0.95 0.95 9678 52 | macro avg 0.78 0.71 0.74 9678 53 | weighted avg 0.95 0.95 0.95 9678 54 | 55 | ``` 56 | 57 | ### Trigger classification performance 58 | 59 | ``` 60 | precision recall f1-score support 61 | 62 | Life 0.75 0.70 0.72 114 63 | Justice 0.78 0.85 0.81 114 64 | Movement 0.69 0.70 0.69 53 65 | Personnel 0.68 0.64 0.66 78 66 | Business 0.75 0.46 0.57 13 67 | Conflict 0.78 0.83 0.80 247 68 | Contact 0.79 0.86 0.83 36 69 | Transaction 1.00 0.48 0.65 27 70 | 71 | micro avg 0.76 0.76 0.76 682 72 | macro avg 0.78 0.69 0.72 682 73 | weighted avg 0.76 0.76 0.76 682 74 | ``` 75 | 76 | ##### with None label 77 | ``` 78 | precision recall f1-score support 79 | 80 | Movement 0.47 0.21 0.29 68 81 | Business 1.00 0.10 0.18 10 82 | Contact 0.67 0.22 0.33 37 83 | Justice 0.32 0.17 0.23 63 84 | None 0.96 0.99 0.98 8348 85 | Conflict 0.70 0.38 0.50 156 86 | Life 0.64 0.38 0.48 65 87 | Transaction 0.75 0.10 0.18 29 88 | Personnel 0.73 0.24 0.36 46 89 | 90 | micro avg 0.95 0.95 0.95 8822 91 | macro avg 0.69 0.31 0.39 8822 92 | weighted avg 0.94 0.95 0.94 8822 93 | ``` 94 | 95 | ### Argument classification performance 96 | ``` 97 | precision recall f1-score support 98 | 99 | Seller 0.00 0.00 0.00 4 100 | Money 0.00 0.00 0.00 13 101 | Target 0.30 0.16 0.21 67 102 | Destination 0.45 0.20 0.28 49 103 | Victim 0.35 0.23 0.28 48 104 | Instrument 0.25 0.10 0.14 31 105 | Crime 0.67 0.14 0.23 43 106 | Adjudicator 0.00 0.00 0.00 20 107 | Origin 0.00 0.00 0.00 34 108 | Time 0.46 0.22 0.30 193 109 | Agent 0.00 0.00 0.00 40 110 | Position 0.00 0.00 0.00 20 111 | Giver 0.00 0.00 0.00 16 112 | Beneficiary 0.00 0.00 0.00 5 113 | Org 0.00 0.00 0.00 6 114 | Artifact 0.00 0.00 0.00 14 115 | Place 0.28 0.21 0.24 149 116 | None 0.75 0.95 0.84 2593 117 | Prosecutor 0.00 0.00 0.00 2 118 | Person 0.25 0.18 0.21 113 119 | Attacker 0.32 0.11 0.16 75 120 | Defendant 0.47 0.14 0.21 51 121 | Sentence 0.67 0.40 0.50 10 122 | Plaintiff 0.00 0.00 0.00 17 123 | Vehicle 0.00 0.00 0.00 10 124 | Entity 0.17 0.02 0.03 110 125 | Recipient 0.00 0.00 0.00 4 126 | Price 0.00 0.00 0.00 1 127 | Buyer 0.00 0.00 0.00 4 128 | 129 | micro avg 0.70 0.70 0.70 3742 130 | macro avg 0.19 0.11 0.13 3742 131 | weighted avg 0.61 0.70 0.64 3742 132 | ``` 133 | 134 | ## References 135 | 136 | * **Event Extraction via Dynamic Multi-Pooling Convolutional Neural Networks** (IJCNLP 2015), Chen, Yubo, et al. 
[[paper]](https://pdfs.semanticscholar.org/ca70/480f908ec60438e91a914c1075b9954e7834.pdf) 137 | * zhangluoyang's cnn-for-auto-event-extract repository [[github]](https://github.com/zhangluoyang/cnn-for-auto-event-extract) 138 | -------------------------------------------------------------------------------- /Script.py: -------------------------------------------------------------------------------- 1 | import datetime, os, time 2 | import numpy as np 3 | import tensorflow as tf 4 | from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score, precision_recall_fscore_support as prf_score 5 | from Util import train_parser 6 | from Dataset import Dataset as ARGUMENT_DATASET 7 | from Dataset_Trigger import Dataset_Trigger as TRIGGER_DATASET 8 | from Config import HyperParams_Tri_classification as hp_trigger, HyperParams as hp_argument 9 | import Visualize 10 | 11 | if __name__ == '__main__': 12 | task, subtask = train_parser() 13 | subtask_type = 'IDENTIFICATION' if subtask == 1 else 'CLASSIFICATION' 14 | hp, dataset, Model = [None for _ in range(3)] 15 | 16 | if task == 1: 17 | hp = hp_trigger 18 | dataset = TRIGGER_DATASET(batch_size=hp.batch_size, max_sequence_length=hp.max_sequence_length, 19 | windows=hp.windows, dtype=subtask_type) 20 | for label in dataset.all_labels: 21 | print(label + ' ' + str(dataset.label_id[label])) 22 | 23 | from Model_Trigger import Model 24 | 25 | print("\n\nTrigger {} start.\n\n".format(subtask_type)) 26 | if task == 2: 27 | hp = hp_argument 28 | dataset = ARGUMENT_DATASET(batch_size=hp.batch_size, max_sequence_length=hp.max_sequence_length, 29 | windows=hp.windows, dtype=subtask_type) 30 | from Model import Model 31 | 32 | print("\n\nArgument {} start.\n\n".format(subtask_type)) 33 | 34 | with tf.Graph().as_default(): 35 | sess = tf.Session() 36 | with sess.as_default(): 37 | model = Model(sentence_length=hp.max_sequence_length, 38 | num_labels=len(dataset.all_labels), 39 | vocab_size=len(dataset.all_words), 40 | word_embedding_size=hp.word_embedding_size, 41 | pos_embedding_size=hp.pos_embedding_size, 42 | filter_sizes=hp.filter_sizes, 43 | pos_tag_max_size=len(dataset.all_pos_taggings), 44 | filter_num=hp.filter_num, 45 | embed_matrx=dataset.word_embed) 46 | 47 | optimizer = tf.train.AdamOptimizer(hp.lr) 48 | grads_and_vars = optimizer.compute_gradients(model.loss) 49 | train_op = optimizer.apply_gradients(grads_and_vars) 50 | 51 | timestamp = str(int(time.time())) 52 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 53 | print("Writing to {}\n".format(out_dir)) 54 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) 55 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 56 | if not os.path.exists(checkpoint_dir): 57 | os.makedirs(checkpoint_dir) 58 | saver = tf.train.Saver(tf.all_variables(), max_to_keep=20) 59 | sess.run(tf.initialize_all_variables()) 60 | 61 | def trigger_train_step(input_x, input_y, input_c, input_c_pos, input_pos_tag, dropout_keep_prob, log=False): 62 | feed_dict = { 63 | model.input_x: input_x, 64 | model.input_y: input_y, 65 | model.input_c_pos: input_c_pos, 66 | # model.input_pos_tag: input_pos_tag, 67 | model.dropout_keep_prob: dropout_keep_prob, 68 | } 69 | _, loss, accuracy = sess.run([train_op, model.loss, model.accuracy], feed_dict) 70 | time_str = datetime.datetime.now().isoformat() 71 | if log: 72 | print("{}: loss {:g}, acc {:g}".format(time_str, loss, accuracy)) 73 | 74 | 75 | def trigger_eval_step(input_x, input_y, input_c, 
input_c_pos, input_pos_tag, dropout_keep_prob, is_test=False): 76 | feed_dict = { 77 | model.input_x: input_x, 78 | model.input_y: input_y, 79 | model.input_c_pos: input_c_pos, 80 | # model.input_pos_tag: input_pos_tag, 81 | model.dropout_keep_prob: dropout_keep_prob, 82 | } 83 | accuracy, predicts = sess.run([model.accuracy, model.predicts], feed_dict) 84 | #print("eval accuracy:{}".format(accuracy)) 85 | 86 | 87 | y_true = [np.argmax(item) for item in input_y] 88 | y_pred = predicts 89 | target_names = dataset.all_labels 90 | 91 | print(classification_report(y_true, y_pred, 92 | target_names=dataset.all_labels)) 93 | 94 | metrics = ['macro','weighted','micro'] 95 | for metric in metrics: 96 | print("\n## {} ##".format(metric)) 97 | res = prf_score(y_true, y_pred,average=metric) 98 | 99 | prf = [round(res[0]*100,2),round(res[1]*100,2),round(res[2]*100,2)] 100 | print('Precision Recall F1') 101 | print('{} {} {}'.format(prf[0], prf[1], prf[2])) 102 | print('Accuracy: {}%'.format(round(100*accuracy,2))) 103 | 104 | Visualize.draw( 105 | epoch=epoch, 106 | input_x=input_x, 107 | input_y=[np.argmax(item) for item in input_y], 108 | predicts=predicts, 109 | input_c_pos=input_c_pos, 110 | id2label = dataset.id2label, 111 | id2word=dataset.id2word, 112 | ) 113 | 114 | return predicts 115 | 116 | 117 | def argument_train_step(input_x, input_y, input_t, input_c, input_t_pos, input_c_pos, dropout_keep_prob): 118 | feed_dict = { 119 | model.input_x: input_x, 120 | model.input_y: input_y, 121 | # model.input_t:input_t, 122 | # model.input_c:input_c, 123 | model.input_t_pos: input_t_pos, 124 | model.input_c_pos: input_c_pos, 125 | model.dropout_keep_prob: dropout_keep_prob, 126 | } 127 | _, loss, accuracy = sess.run([train_op, model.loss, model.accuracy], feed_dict) 128 | time_str = datetime.datetime.now().isoformat() 129 | # print("{}: loss {:g}, acc {:g}".format(time_str, loss, accuracy)) 130 | 131 | 132 | def argument_eval_step(input_x, input_y, input_t, input_c, input_t_pos, input_c_pos, dropout_keep_prob): 133 | feed_dict = { 134 | model.input_x: input_x, 135 | model.input_y: input_y, 136 | # model.input_t:input_t, 137 | # model.input_c:input_c, 138 | model.input_t_pos: input_t_pos, 139 | model.input_c_pos: input_c_pos, 140 | model.dropout_keep_prob: dropout_keep_prob, 141 | } 142 | accuracy, predicts = sess.run([model.accuracy, model.predicts], feed_dict) 143 | from sklearn.metrics import classification_report 144 | print("eval accuracy:{}".format(accuracy)) 145 | # print("input_y : ", [np.argmax(item) for item in input_y], ', predicts :', predicts) 146 | print(classification_report([np.argmax(item) for item in input_y], predicts, 147 | target_names=dataset.all_labels)) 148 | return predicts 149 | 150 | 151 | print("TRAIN START") 152 | for epoch in range(hp.num_epochs): 153 | print('epoch: {}/{}'.format(epoch + 1, hp.num_epochs)) 154 | for j in range(len(dataset.train_instances) // hp.batch_size): 155 | if task == 1: 156 | x, c, y, pos_c, pos_tag = dataset.next_train_data() 157 | if j==0: 158 | trigger_train_step(input_x=x, input_y=y, input_c=c, input_c_pos=pos_c, input_pos_tag=pos_tag, 159 | dropout_keep_prob=0.5, log=True) 160 | else: 161 | trigger_train_step(input_x=x, input_y=y, input_c=c, input_c_pos=pos_c, 162 | input_pos_tag=pos_tag, 163 | dropout_keep_prob=0.5) 164 | 165 | if task == 2: 166 | x, t, c, y, pos_c, pos_t, _ = dataset.next_train_data() 167 | argument_train_step(input_x=x, input_y=y, input_t=t, input_c=c, input_c_pos=pos_c, 168 | input_t_pos=pos_t, 169 | dropout_keep_prob=0.5) 
170 | 171 | if epoch % 5 == 0: 172 | if task == 1: # Trigger 173 | x, c, y, pos_c, pos_tag = dataset.next_valid_data() 174 | trigger_eval_step(input_x=x, input_y=y, input_c=c, input_c_pos=pos_c, input_pos_tag=pos_tag, 175 | dropout_keep_prob=1.0) 176 | path = saver.save(sess, checkpoint_prefix + "-Trigger-Identification", epoch) 177 | print("Saved model checkpoint to {}\n".format(path)) 178 | 179 | x, c, y, pos_c, pos_tag = dataset.next_eval_data() 180 | trigger_eval_step(input_x=x, input_y=y, input_c=c, input_c_pos=pos_c, input_pos_tag=pos_tag, 181 | dropout_keep_prob=1.0, is_test=True) 182 | 183 | 184 | if task == 2: 185 | x, t, c, y, pos_c, pos_t, _ = dataset.eval_data() 186 | argument_eval_step(input_x=x, input_y=y, input_t=t, input_c=c, input_c_pos=pos_c, 187 | input_t_pos=pos_t, 188 | dropout_keep_prob=1.0) 189 | 190 | print("----test results---------------------------------------------------------------------") 191 | if task == 1: 192 | x, c, y, pos_c, pos_tag = dataset.next_eval_data() 193 | predicts = trigger_eval_step(input_x=x, input_y=y, input_c=c, input_c_pos=pos_c, input_pos_tag=pos_tag, dropout_keep_prob=1.0, is_test=True) 194 | if task == 2: 195 | x, t, c, y, pos_c, pos_t, _ = dataset.eval_data() 196 | predicts = argument_eval_step(input_x=x, input_y=y, input_t=t, input_c=c, input_c_pos=pos_c, 197 | input_t_pos=pos_t, 198 | dropout_keep_prob=1.0) 199 | -------------------------------------------------------------------------------- /Script_Test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpcl-lab/event-extraction/05975c348bf7e278cdf086bc3a5171ad28fce14f/Script_Test.py -------------------------------------------------------------------------------- /Serving_Trigger.py: -------------------------------------------------------------------------------- 1 | import datetime, os, sys, json 2 | import tensorflow as tf 3 | from Dataset_Trigger import Dataset_Trigger as TRIGGER_DATASET 4 | from Config import HyperParams_Tri_classification as hp 5 | import nltk 6 | 7 | from flask import Flask, session, g, request, render_template, redirect, Response, jsonify 8 | 9 | app = Flask(__name__) 10 | 11 | def get_batch(sentence, word_id, max_sequence_length): 12 | tokens = [word for word in nltk.word_tokenize(sentence)] 13 | words = [] 14 | for i in range(max_sequence_length): 15 | if i < len(tokens): 16 | words.append(tokens[i]) 17 | else: 18 | words.append('') 19 | 20 | word_ids = [] 21 | for word in words: 22 | if word in word_id: 23 | word_ids.append(word_id[word]) 24 | else: 25 | word_ids.append(word_id['']) 26 | 27 | # print('word_ids :', word_ids) 28 | size = len(word_ids) 29 | 30 | x_batch = [] 31 | x_pos_batch = [] 32 | for i in range(size): 33 | x_batch.append(word_ids) 34 | x_pos_batch.append([j - i for j in range(size)]) 35 | 36 | return x_batch, x_pos_batch, tokens 37 | 38 | 39 | dataset = TRIGGER_DATASET(batch_size=hp.batch_size, max_sequence_length=hp.max_sequence_length, 40 | windows=hp.windows, dtype='IDENTIFICATION') 41 | 42 | checkpoint_dir = './runs/1543232582/checkpoints' 43 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 44 | 45 | graph = tf.Graph() 46 | with graph.as_default(): 47 | sess = tf.Session() 48 | with sess.as_default(): 49 | # Load the saved meta graph and restore variables 50 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 51 | print('restore model from {}.meta'.format(checkpoint_file)) 52 | saver.restore(sess, checkpoint_file) 53 | 54 | 
@app.route('/api/event-extraction/trigger/identification', methods=['POST']) 55 | def serving(): 56 | data = request.get_json() 57 | sentence = data['sentence'] 58 | 59 | x_batch, x_pos_batch, tokens = get_batch(sentence=sentence, word_id=dataset.word_id, max_sequence_length=hp.max_sequence_length) 60 | 61 | print('x_batch :', x_batch) 62 | print('x_pos_batch :', x_pos_batch) 63 | 64 | # Get the placeholders from the graph by name 65 | input_x = graph.get_operation_by_name("input_x").outputs[0] 66 | input_c_pos = graph.get_operation_by_name("input_c_pos").outputs[0] 67 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 68 | 69 | # Tensors we want to evaluate 70 | predictions = graph.get_operation_by_name("output/predicts").outputs[0] 71 | 72 | feed_dict = { 73 | input_x: x_batch, 74 | input_c_pos: x_pos_batch, 75 | dropout_keep_prob: 1.0, 76 | } 77 | 78 | preds = sess.run(predictions, feed_dict) 79 | print('id2label : ', dataset.id2label) 80 | result = '' 81 | for i in range(len(preds)): 82 | word = dataset.id2word[x_batch[0][i]] 83 | if word == '': word = tokens[i] 84 | if word == '': break 85 | print('word: {}, pred: {}'.format(word, str(preds[i]))) 86 | result += '{}/{} '.format(word, dataset.id2label[preds[i]]) 87 | 88 | return Response(json.dumps({'result': result}), status=200, mimetype='application/json') 89 | 90 | base_dir = os.path.abspath(os.path.dirname(__file__) + '/') 91 | sys.path.append(base_dir) 92 | FLASK_DEBUG = os.getenv('FLASK_DEBUG', True) 93 | app.run(host='0.0.0.0', debug=FLASK_DEBUG, port=8085) 94 | -------------------------------------------------------------------------------- /Util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def train_parser(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('task',help='Trigger:1 Argument:2',type=int, choices=[1,2]) 7 | parser.add_argument('subtask',help='Identification:1 Classification:2',type=int, choices=[1,2]) 8 | args = parser.parse_args() 9 | task = args.task 10 | subtask = args.subtask 11 | return task,subtask 12 | 13 | 14 | def find_candidates(items1, items2): 15 | result = [] 16 | for i in range(len(items1)): 17 | if items1[i] in items2: 18 | result.append(i) 19 | return result 20 | 21 | 22 | def one_hot(labels, label_id, label_num): 23 | result = [] 24 | for i in range(0, len(labels)): 25 | one_hot_vec = [0] * label_num 26 | one_hot_vec[label_id[labels[i]]] = 1 27 | result.append(one_hot_vec) 28 | return result 29 | -------------------------------------------------------------------------------- /Visualize.py: -------------------------------------------------------------------------------- 1 | def draw(epoch, input_x, input_y, predicts, input_c_pos, id2label, id2word): 2 | sents_visual_file = './visualization/{}.html'.format(epoch) 3 | 4 | batch_size = len(input_y) 5 | with open(sents_visual_file, "w") as html_file: 6 | html_file.write('') 7 | 8 | for i in range(batch_size): 9 | if input_y[i] == predicts[i]: continue 10 | 11 | sent_size = len(input_x[i]) 12 | current_pos = 0 13 | for j in range(sent_size): 14 | if input_c_pos[i][j] == 0: 15 | current_pos = j 16 | break 17 | 18 | sent = '' 19 | for j in range(sent_size): 20 | word = id2word[input_x[i][j]] 21 | if word == '': continue 22 | if j == current_pos: 23 | sent += '{} '.format(word) 24 | else: 25 | sent += word + ' ' 26 | 27 | html_file.write('
{} '.format(sent)) 28 | html_file.write('Prediction: {} '.format(id2label[predicts[i]])) 29 | html_file.write('Answer: {} 
'.format(id2label[input_y[i]])) 30 | 31 | html_file.write('') 32 | 33 | html_file.write('') 34 | 35 | 36 | if __name__ == '__main__': 37 | draw() 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | astor==0.8.0 3 | beautifulsoup4==4.8.0 4 | gast==0.3.2 5 | google-pasta==0.1.7 6 | grpcio==1.23.0 7 | h5py==2.10.0 8 | joblib==0.13.2 9 | Keras-Applications==1.0.8 10 | Keras-Preprocessing==1.1.0 11 | Markdown==3.1.1 12 | nltk==3.4.5 13 | numpy==1.17.2 14 | protobuf==3.9.1 15 | scikit-learn==0.21.3 16 | scipy==1.3.1 17 | six==1.12.0 18 | soupsieve==1.9.3 19 | tensorboard==1.14.0 20 | tensorflow==1.14.0 21 | tensorflow-estimator==1.14.0 22 | tensorflow-gpu==1.14.0 23 | termcolor==1.1.0 24 | Werkzeug==0.15.6 25 | wrapt==1.11.2 26 | -------------------------------------------------------------------------------- /send_script.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | filenames = ['10_1_mic1.wav','10_1_mic2.wav'] 4 | 5 | files = dict() 6 | for filename in filenames: 7 | f = open(filename, 'rb') 8 | files[filename] = f 9 | 10 | res = requests.post('http://localhost:8080/upload', files=files) 11 | print('res :', res) 12 | 13 | r = requests.get('http://localhost:8080/download') 14 | with open('result.wav', 'wb') as f: 15 | f.write(r.content) 16 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | 3 | if __name__ == '__main__': 4 | """ 5 | To get phrase using START, END offsets 6 | 7 | Secretary of Homeland Security Tom Ridge 8 | 9 | """ 10 | with open('./data/ace_2005_td_v7/data/English/bc/adj/CNN_CF_20030303.1900.00.sgm', 'r') as f: 11 | data = f.read() 12 | soup = BeautifulSoup(data, features='html.parser') 13 | text = soup.text 14 | print(text[754-1:793]) 15 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpcl-lab/event-extraction/05975c348bf7e278cdf086bc3a5171ad28fce14f/visualization/__init__.py --------------------------------------------------------------------------------