├── README.md
└── baseline.py


/README.md:
--------------------------------------------------------------------------------
# A Baseline System For CCKS-2019-IPRE
Data for the shared task is available at https://github.com/SUDA-HLT/IPRE, and the task overview paper is available at https://arxiv.org/abs/1908.11337.

## Introduction
We provide a baseline system based on a convolutional neural network with selective attention.

## Getting Started
### Environment Requirements
* Python 3.6
* numpy
* TensorFlow 1.12.0

### Step 1: Download data
Please download the data from [the competition website](https://biendata.com/competition/ccks_2019_ipre/data/), then unzip the files and put them in the `./data/` folder.

### Step 2: Train the model
You can use the following commands to train models for the Sent-Track or the Bag-Track:
```
python baseline.py --level sent
python baseline.py --level bag
```
The trained model will be stored in the `./model/` folder. We also provide a large-scale unlabeled corpus for training word vectors or language models. The word vectors used in the baseline system are trained with the Python package gensim, with the following parameters:
```
from gensim.models import word2vec
model = word2vec.Word2Vec(sentences, sg=1, size=300, window=5, min_count=10, negative=5, sample=1e-4, workers=10)
```
See the note after the references for exporting the trained vectors in the plain-text format that `baseline.py` expects.

### Step 3: Test the model
You can use the following commands to test models for the Sent-Track or the Bag-Track:
```
python baseline.py --mode test --level sent
python baseline.py --mode test --level bag
```
Predicted results will be written to `result_sent.txt` or `result_bag.txt`.

## Evaluation
We use the F1 score as the basic evaluation metric to measure system performance. With pre-trained word vectors, our baseline system reaches an F1 score of about 0.22 on the Sent-Track and about 0.31 on the Bag-Track.

## References
* Wang H, He Z, Zhu T, et al. CCKS 2019 Shared Task on Inter-Personal Relationship Extraction[J]. arXiv preprint arXiv:1908.11337, 2019.
* Lin Y, Shen S, Liu Z, et al. Neural relation extraction with selective attention over instances[C]//Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2016, 1: 2124-2133.
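## Note: Exporting Word Vectors
`baseline.py` loads `./data/word2vec.txt` as plain text, one word followed by 300 floating-point values per line (the word2vec header line is skipped automatically because its token count differs from `word_dim + 1`). The sketch below shows one way to export the gensim model from Step 2 into that format; the corpus path is a placeholder, and the output path is an assumption matching the default `--data_path`:
```
from gensim.models import word2vec

# Tokenized training corpus, one whitespace-separated sentence per line
# (placeholder path; use the unlabeled corpus provided with the task).
sentences = word2vec.LineSentence('corpus.txt')
model = word2vec.Word2Vec(sentences, sg=1, size=300, window=5, min_count=10,
                          negative=5, sample=1e-4, workers=10)

# Plain-text word2vec format: "<word> <300 floats>" per line, which is what
# load_wordVec() in baseline.py parses.
model.wv.save_word2vec_format('./data/word2vec.txt', binary=False)
```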

--------------------------------------------------------------------------------
/baseline.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import random
import os
import datetime
from collections import Counter


def set_seed():
    os.environ['PYTHONHASHSEED'] = '0'
    np.random.seed(2019)
    random.seed(2019)
    tf.set_random_seed(2019)

set_seed()

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('cuda', '0', 'gpu id')
tf.app.flags.DEFINE_boolean('pre_embed', True, 'load pre-trained word2vec')
tf.app.flags.DEFINE_integer('batch_size', 50, 'batch size')
tf.app.flags.DEFINE_integer('epochs', 200, 'max train epochs')
tf.app.flags.DEFINE_integer('hidden_dim', 300, 'dimension of hidden embedding')
tf.app.flags.DEFINE_integer('word_dim', 300, 'dimension of word embedding')
tf.app.flags.DEFINE_integer('pos_dim', 5, 'dimension of position embedding')
tf.app.flags.DEFINE_integer('pos_limit', 15, 'max distance of position embedding')
tf.app.flags.DEFINE_integer('sen_len', 60, 'sentence length')
tf.app.flags.DEFINE_integer('window', 3, 'window size')
tf.app.flags.DEFINE_string('model_path', './model', 'save model dir')
tf.app.flags.DEFINE_string('data_path', './data', 'data dir to load')
tf.app.flags.DEFINE_string('level', 'bag', 'bag level or sentence level, option:bag/sent')
tf.app.flags.DEFINE_string('mode', 'train', 'train or test')
tf.app.flags.DEFINE_float('dropout', 0.5, 'dropout rate')
tf.app.flags.DEFINE_float('lr', 0.001, 'learning rate')
tf.app.flags.DEFINE_integer('word_frequency', 5, 'minimum word frequency when constructing vocabulary list')


class Baseline:
    def __init__(self, flags):
        self.lr = flags.lr
        self.sen_len = flags.sen_len
        self.pre_embed = flags.pre_embed
        self.pos_limit = flags.pos_limit
        self.pos_dim = flags.pos_dim
        self.window = flags.window
        self.word_dim = flags.word_dim
        self.hidden_dim = flags.hidden_dim
        self.batch_size = flags.batch_size
        self.data_path = flags.data_path
        self.model_path = flags.model_path
        self.mode = flags.mode
        self.epochs = flags.epochs
        self.dropout = flags.dropout
        self.word_frequency = flags.word_frequency

        if flags.level == 'sent':
            self.bag = False
        elif flags.level == 'bag':
            self.bag = True
        else:
            self.bag = True

        self.pos_num = 2 * self.pos_limit + 3
        self.relation2id = self.load_relation()
        self.num_classes = len(self.relation2id)

        if self.pre_embed:
            self.wordMap, word_embed = self.load_wordVec()
            self.word_embedding = tf.get_variable(initializer=word_embed, name='word_embedding', trainable=False)
        else:
            self.wordMap = self.load_wordMap()
            self.word_embedding = tf.get_variable(shape=[len(self.wordMap), self.word_dim], name='word_embedding', trainable=True)

        self.pos_e1_embedding = tf.get_variable(name='pos_e1_embedding', shape=[self.pos_num, self.pos_dim])
        self.pos_e2_embedding = tf.get_variable(name='pos_e2_embedding', shape=[self.pos_num, self.pos_dim])

        self.relation_embedding = tf.get_variable(name='relation_embedding', shape=[self.hidden_dim, self.num_classes])
        self.relation_embedding_b = tf.get_variable(name='relation_embedding_b', shape=[self.num_classes])

        self.sentence_reps = self.CNN_encoder()

        if self.bag:
            self.bag_level()
        else:
            self.sentence_level()
        self._classifier_train_op = tf.train.AdamOptimizer(self.lr).minimize(self.classifier_loss)

    def pos_index(self, x):
        if x < -self.pos_limit:
            return 0
        if x >= -self.pos_limit and x <= self.pos_limit:
            return x + self.pos_limit + 1
        if x > self.pos_limit:
            return 2 * self.pos_limit + 2

    def load_wordVec(self):
        wordMap = {}
        wordMap['PAD'] = len(wordMap)
        wordMap['UNK'] = len(wordMap)
        word_embed = []
        for line in open(os.path.join(self.data_path, 'word2vec.txt')):
            content = line.strip().split()
            if len(content) != self.word_dim + 1:
                continue
            wordMap[content[0]] = len(wordMap)
            word_embed.append(np.asarray(content[1:], dtype=np.float32))

        word_embed = np.stack(word_embed)
        embed_mean, embed_std = word_embed.mean(), word_embed.std()

        pad_embed = np.random.normal(embed_mean, embed_std, (2, self.word_dim))
        word_embed = np.concatenate((pad_embed, word_embed), axis=0)
        word_embed = word_embed.astype(np.float32)
        return wordMap, word_embed

    def load_wordMap(self):
        wordMap = {}
        wordMap['PAD'] = len(wordMap)
        wordMap['UNK'] = len(wordMap)
        all_content = []
        for line in open(os.path.join(self.data_path, 'sent_train.txt')):
            all_content += line.strip().split('\t')[3].split()
        for item in Counter(all_content).most_common():
            if item[1] > self.word_frequency:
                wordMap[item[0]] = len(wordMap)
            else:
                break
        return wordMap

    def load_relation(self):
        relation2id = {}
        for line in open(os.path.join(self.data_path, 'relation2id.txt')):
            relation, id_ = line.strip().split()
            relation2id[relation] = int(id_)
        return relation2id

    def load_sent(self, filename):
        sentence_dict = {}
        with open(os.path.join(self.data_path, filename), 'r') as fr:
            for line in fr:
                id_, en1, en2, sentence = line.strip().split('\t')
                sentence = sentence.split()
                en1_pos = 0
                en2_pos = 0
                for i in range(len(sentence)):
                    if sentence[i] == en1:
                        en1_pos = i
                    if sentence[i] == en2:
                        en2_pos = i
                words = []
                pos1 = []
                pos2 = []

                length = min(self.sen_len, len(sentence))

                for i in range(length):
                    words.append(self.wordMap.get(sentence[i], self.wordMap['UNK']))
                    pos1.append(self.pos_index(i - en1_pos))
                    pos2.append(self.pos_index(i - en2_pos))

                if length < self.sen_len:
                    for i in range(length, self.sen_len):
                        words.append(self.wordMap['PAD'])
                        pos1.append(self.pos_index(i - en1_pos))
                        pos2.append(self.pos_index(i - en2_pos))
                sentence_dict[id_] = np.reshape(np.asarray([words, pos1, pos2], dtype=np.int32), (1, 3, self.sen_len))
        return sentence_dict

    def data_batcher(self, sentence_dict, filename, padding=False, shuffle=True):
        if self.bag:
            all_bags = []
            all_sents = []
            all_labels = []
            with open(os.path.join(self.data_path, filename), 'r') as fr:
                for line in fr:
                    rel = [0] * self.num_classes
                    try:
                        bag_id, _, _, sents, types = line.strip().split('\t')
                        type_list = types.split()
                        for tp in type_list:
                            if len(type_list) > 1 and tp == '0':  # if a bag has multiple relations, we only consider non-NA relations
                                continue
                            rel[int(tp)] = 1
                    except:
                        bag_id, _, _, sents = line.strip().split('\t')

                    sent_list = []
                    for sent in sents.split():
                        sent_list.append(sentence_dict[sent])

                    all_bags.append(bag_id)
                    all_sents.append(np.concatenate(sent_list, axis=0))
                    all_labels.append(np.asarray(rel, dtype=np.float32))

            self.data_size = len(all_bags)
            self.datas = all_bags
            data_order = list(range(self.data_size))
            if shuffle:
                np.random.shuffle(data_order)
            if padding:
                if self.data_size % self.batch_size != 0:
                    data_order += [data_order[-1]] * (self.batch_size - self.data_size % self.batch_size)

            for i in range(len(data_order) // self.batch_size):
                total_sens = 0
                out_sents = []
                out_sent_nums = []
                out_labels = []
                for k in data_order[i * self.batch_size:(i + 1) * self.batch_size]:
                    out_sents.append(all_sents[k])
                    out_sent_nums.append(total_sens)
                    total_sens += all_sents[k].shape[0]
                    out_labels.append(all_labels[k])

                out_sents = np.concatenate(out_sents, axis=0)
                out_sent_nums.append(total_sens)
                out_sent_nums = np.asarray(out_sent_nums, dtype=np.int32)
                out_labels = np.stack(out_labels)

                yield out_sents, out_labels, out_sent_nums
        else:
            all_sent_ids = []
            all_sents = []
            all_labels = []
            with open(os.path.join(self.data_path, filename), 'r') as fr:
                for line in fr:
                    rel = [0] * self.num_classes
                    try:
                        sent_id, types = line.strip().split('\t')
                        type_list = types.split()
                        for tp in type_list:
                            if len(type_list) > 1 and tp == '0':  # if a sentence has multiple relations, we only consider non-NA relations
                                continue
                            rel[int(tp)] = 1
                    except:
                        sent_id = line.strip()

                    all_sent_ids.append(sent_id)
                    all_sents.append(sentence_dict[sent_id])

                    all_labels.append(np.reshape(np.asarray(rel, dtype=np.float32), (-1, self.num_classes)))

            self.data_size = len(all_sent_ids)
            self.datas = all_sent_ids

            all_sents = np.concatenate(all_sents, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)

            data_order = list(range(self.data_size))
            if shuffle:
                np.random.shuffle(data_order)
            if padding:
                if self.data_size % self.batch_size != 0:
                    data_order += [data_order[-1]] * (self.batch_size - self.data_size % self.batch_size)

            for i in range(len(data_order) // self.batch_size):
                idx = data_order[i * self.batch_size:(i + 1) * self.batch_size]
                yield all_sents[idx], all_labels[idx], None

    def CNN_encoder(self):
        self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
        self.input_word = tf.placeholder(dtype=tf.int32, shape=[None, self.sen_len], name='input_word')
        self.input_pos_e1 = tf.placeholder(dtype=tf.int32, shape=[None, self.sen_len], name='input_pos_e1')
        self.input_pos_e2 = tf.placeholder(dtype=tf.int32, shape=[None, self.sen_len], name='input_pos_e2')
        self.input_label = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes], name='input_label')

        inputs_forward = tf.concat(axis=2, values=[tf.nn.embedding_lookup(self.word_embedding, self.input_word),
                                                   tf.nn.embedding_lookup(self.pos_e1_embedding, self.input_pos_e1),
                                                   tf.nn.embedding_lookup(self.pos_e2_embedding, self.input_pos_e2)])
        inputs_forward = tf.expand_dims(inputs_forward, -1)

        with tf.name_scope('conv-maxpool'):
            w = tf.get_variable(name='w', shape=[self.window, self.word_dim + 2 * self.pos_dim, 1, self.hidden_dim])
            b = tf.get_variable(name='b', shape=[self.hidden_dim])
            conv = tf.nn.conv2d(
                inputs_forward,
                w,
                strides=[1, 1, 1, 1],
                padding='VALID',
                name='conv')
            h = tf.nn.bias_add(conv, b)
            pooled = tf.nn.max_pool(
                h,
                ksize=[1, self.sen_len - self.window + 1, 1, 1],
                strides=[1, 1, 1, 1],
                padding='VALID',
                name='pool')
            sen_reps = tf.tanh(tf.reshape(pooled, [-1, self.hidden_dim]))
            sen_reps = tf.nn.dropout(sen_reps, self.keep_prob)
        return sen_reps

    def bag_level(self):
        self.classifier_loss = 0.0
        self.probability = []

        self.bag_sens = tf.placeholder(dtype=tf.int32, shape=[self.batch_size + 1], name='bag_sens')
        self.att_A = tf.get_variable(name='att_A', shape=[self.hidden_dim])
        self.rel = tf.reshape(tf.transpose(self.relation_embedding), [self.num_classes, self.hidden_dim])

        for i in range(self.batch_size):
            sen_reps = tf.reshape(self.sentence_reps[self.bag_sens[i]:self.bag_sens[i + 1]], [-1, self.hidden_dim])

            # selective attention over the sentences in the i-th bag (cf. Lin et al., 2016):
            # score each sentence against every relation and build per-relation bag representations
            att_sen = tf.reshape(tf.multiply(sen_reps, self.att_A), [-1, self.hidden_dim])
            score = tf.matmul(self.rel, tf.transpose(att_sen))
            alpha = tf.nn.softmax(score, 1)
            bag_rep = tf.matmul(alpha, sen_reps)

            out = tf.matmul(bag_rep, self.relation_embedding) + self.relation_embedding_b

            prob = tf.reshape(tf.reduce_sum(tf.nn.softmax(out, 1) * tf.reshape(self.input_label[i], [-1, 1]), 0),
                              [self.num_classes])

            self.probability.append(
                tf.reshape(tf.reduce_sum(tf.nn.softmax(out, 1) * tf.diag([1.0] * (self.num_classes)), 1),
                           [-1, self.num_classes]))
            self.classifier_loss += tf.reduce_sum(
                -tf.log(tf.clip_by_value(prob, 1.0e-10, 1.0)) * tf.reshape(self.input_label[i], [-1]))

        self.probability = tf.concat(axis=0, values=self.probability)
        self.classifier_loss = self.classifier_loss / tf.cast(self.batch_size, tf.float32)

    def sentence_level(self):
        out = tf.matmul(self.sentence_reps, self.relation_embedding) + self.relation_embedding_b
        self.probability = tf.nn.softmax(out, 1)
        self.classifier_loss = tf.reduce_mean(
            tf.reduce_sum(-tf.log(tf.clip_by_value(self.probability, 1.0e-10, 1.0)) * self.input_label, 1))

    def run_train(self, sess, batch):

        sent_batch, label_batch, sen_num_batch = batch

        feed_dict = {}
        feed_dict[self.keep_prob] = self.dropout
        feed_dict[self.input_word] = sent_batch[:, 0, :]
        feed_dict[self.input_pos_e1] = sent_batch[:, 1, :]
        feed_dict[self.input_pos_e2] = sent_batch[:, 2, :]
        feed_dict[self.input_label] = label_batch
        if self.bag:
            feed_dict[self.bag_sens] = sen_num_batch

        _, classifier_loss = sess.run([self._classifier_train_op, self.classifier_loss], feed_dict)

        return classifier_loss

    def run_dev(self, sess, dev_batchers):
        all_labels = []
        all_probs = []
        for batch in dev_batchers:
            sent_batch, label_batch, sen_num_batch = batch
            all_labels.append(label_batch)

            feed_dict = {}
            feed_dict[self.keep_prob] = 1.0
            feed_dict[self.input_word] = sent_batch[:, 0, :]
            feed_dict[self.input_pos_e1] = sent_batch[:, 1, :]
            feed_dict[self.input_pos_e2] = sent_batch[:, 2, :]
            if self.bag:
                feed_dict[self.bag_sens] = sen_num_batch
            prob = sess.run([self.probability], feed_dict)
            all_probs.append(np.reshape(prob, (-1, self.num_classes)))

        all_labels = np.concatenate(all_labels, axis=0)[:self.data_size]
        all_probs = np.concatenate(all_probs, axis=0)[:self.data_size]
        if self.bag:
            # bag-level predictions: label every relation whose probability exceeds 0.9
            all_preds = all_probs
            all_preds[all_probs > 0.9] = 1
            all_preds[all_probs <= 0.9] = 0
        else:
            all_preds = np.eye(self.num_classes)[np.reshape(np.argmax(all_probs, 1), (-1))]

        return all_preds, all_labels

    def run_test(self, sess, test_batchers):
        all_probs = []
        for batch in test_batchers:
            sent_batch, _, sen_num_batch = batch

            feed_dict = {}
            feed_dict[self.keep_prob] = 1.0
            feed_dict[self.input_word] = sent_batch[:, 0, :]
            feed_dict[self.input_pos_e1] = sent_batch[:, 1, :]
            feed_dict[self.input_pos_e2] = sent_batch[:, 2, :]
            if self.bag:
                feed_dict[self.bag_sens] = sen_num_batch
            prob = sess.run([self.probability], feed_dict)
            all_probs.append(np.reshape(prob, (-1, self.num_classes)))

        all_probs = np.concatenate(all_probs, axis=0)[:self.data_size]
        if self.bag:
            all_preds = all_probs
            all_preds[all_probs > 0.9] = 1
            all_preds[all_probs <= 0.9] = 0
        else:
            all_preds = np.eye(self.num_classes)[np.reshape(np.argmax(all_probs, 1), (-1))]

        if self.bag:
            with open('result_bag.txt', 'w') as fw:
                for i in range(self.data_size):
                    rel_one_hot = [int(num) for num in all_preds[i].tolist()]
                    rel_list = []
                    for j in range(0, self.num_classes):
                        if rel_one_hot[j] == 1:
                            rel_list.append(str(j))
                    if len(rel_list) == 0:  # if a bag has no predicted relation, it is considered to have the relation NA
                        rel_list.append('0')
                    fw.write(self.datas[i] + '\t' + ' '.join(rel_list) + '\n')
        else:
            with open('result_sent.txt', 'w') as fw:
                for i in range(self.data_size):
                    rel_one_hot = [int(num) for num in all_preds[i].tolist()]
                    rel_list = []
                    for j in range(0, self.num_classes):
                        if rel_one_hot[j] == 1:
                            rel_list.append(str(j))
                    fw.write(self.datas[i] + '\t' + ' '.join(rel_list) + '\n')

    def run_model(self, sess, saver):
        if self.mode == 'train':
            global_step = 0
            sent_train = self.load_sent('sent_train.txt')
            sent_dev = self.load_sent('sent_dev.txt')
            max_f1 = 0.0

            if not os.path.isdir(self.model_path):
                os.mkdir(self.model_path)

            for epoch in range(self.epochs):
                if self.bag:
                    train_batchers = self.data_batcher(sent_train, 'bag_relation_train.txt', padding=False, shuffle=True)
                else:
                    train_batchers = self.data_batcher(sent_train, 'sent_relation_train.txt', padding=False, shuffle=True)
                for batch in train_batchers:

                    losses = self.run_train(sess, batch)
                    global_step += 1
                    if global_step % 50 == 0:
                        time_str = datetime.datetime.now().isoformat()
                        tempstr = "{}: step {}, classifier_loss {:g}".format(time_str, global_step, losses)
                        print(tempstr)
                    if global_step % 200 == 0:
                        if self.bag:
                            dev_batchers = self.data_batcher(sent_dev, 'bag_relation_dev.txt', padding=True, shuffle=False)
                        else:
                            dev_batchers = self.data_batcher(sent_dev, 'sent_relation_dev.txt', padding=True, shuffle=False)
                        all_preds, all_labels = self.run_dev(sess, dev_batchers)

                        # when calculating the F1 score, we do not consider whether NA is predicted or not
                        # the number of non-NA answers in the test (gold) labels is counted as n_std
                        # the number of non-NA answers among the predicted answers is counted as n_sys
                        # the intersection of the two sets is counted as n_r
                        n_r = int(np.sum(all_preds[:, 1:] * all_labels[:, 1:]))
                        n_std = int(np.sum(all_labels[:, 1:]))
                        n_sys = int(np.sum(all_preds[:, 1:]))
                        try:
                            precision = n_r / n_sys
                            recall = n_r / n_std
                            f1 = 2 * precision * recall / (precision + recall)
                        except ZeroDivisionError:
                            f1 = 0.0

                        if f1 > max_f1:
                            max_f1 = f1
                            print('f1: %f' % f1)
                            print('saving model')
                            path = saver.save(sess, os.path.join(self.model_path, 'ipre_bag_%d' % (self.bag)), global_step=0)
                            tempstr = 'have saved model to ' + path
                            print(tempstr)

        else:
            path = os.path.join(self.model_path, 'ipre_bag_%d' % self.bag) + '-0'
            tempstr = 'load model: ' + path
            print(tempstr)
            try:
                saver.restore(sess, path)
            except:
                raise ValueError('Invalid model name')

            sent_test = self.load_sent('sent_test.txt')
            if self.bag:
                test_batchers = self.data_batcher(sent_test, 'bag_relation_test.txt', padding=True, shuffle=False)
            else:
                test_batchers = self.data_batcher(sent_test, 'sent_relation_test.txt', padding=True, shuffle=False)

            self.run_test(sess, test_batchers)


def main(_):
    tf.reset_default_graph()
    print('build model')
    gpu_options = tf.GPUOptions(visible_device_list=FLAGS.cuda, allow_growth=True)
    with tf.Graph().as_default():
        set_seed()
        sess = tf.Session(
            config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True,
                                  intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
        with sess.as_default():
            initializer = tf.contrib.layers.xavier_initializer()
            with tf.variable_scope('', initializer=initializer):
                model = Baseline(FLAGS)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(max_to_keep=None)
            model.run_model(sess, saver)


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
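A closing note on evaluation: the dev-set F1 computed in `run_model` above is a micro-average over the non-NA relations (column 0 of the binary label/prediction matrices is NA and is ignored). A standalone sketch of that computation, with an illustrative function name:
```
import numpy as np

def micro_f1_ignore_na(preds, labels):
    # preds/labels: arrays of shape (n_instances, num_classes) with 0/1 entries;
    # column 0 is the NA relation and is excluded, mirroring run_model.
    n_r = int(np.sum(preds[:, 1:] * labels[:, 1:]))    # correctly predicted non-NA labels
    n_sys = int(np.sum(preds[:, 1:]))                  # predicted non-NA labels
    n_std = int(np.sum(labels[:, 1:]))                 # gold non-NA labels
    if n_sys == 0 or n_std == 0:
        return 0.0
    precision, recall = n_r / n_sys, n_r / n_std
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```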