├── .gitignore
├── README.md
├── cnn.py
├── cnn_context.py
├── data
│   ├── er
│   │   ├── source.txt
│   │   └── target.txt
│   ├── gold_patterns.tsv
│   ├── mlmi
│   │   ├── source.att
│   │   ├── source.left
│   │   ├── source.middle
│   │   ├── source.right
│   │   ├── source.txt
│   │   └── target.txt
│   ├── negative_candidates.tsv
│   ├── negative_relations.tsv
│   ├── positive_candidates.tsv
│   └── positive_relations.tsv
├── distant_supervision.py
├── eval.py
├── img
│   ├── auc.png
│   ├── cnn.png
│   ├── emb_er.png
│   ├── emb_left.png
│   ├── emb_mlmi.png
│   ├── emb_right.png
│   ├── f1.png
│   ├── loss.png
│   └── pr_curve.png
├── train.py
├── train_context.py
├── util.py
├── visualize.ipynb
└── word2vec
    └── .gitignore
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .DS_Store
3 |
4 | *.pyc
5 | .ipynb_checkpoints/
6 |
7 | data/*.cPickle
8 | data/candidates*.tsv
9 | data/*/ids.*
10 | data/*/vocab.*
11 | data/*/emb.npy
12 |
13 | models/
14 | train/
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convolutional Neural Network for Relation Extraction
2 |
3 | **Note:** This project is mostly based on https://github.com/yuhaozhang/sentence-convnet
4 |
5 | ---
6 |
7 |
8 | ## Requirements
9 |
10 | - [Python 2.7](https://www.python.org/)
11 | - [TensorFlow](https://www.tensorflow.org/) (tested with versions 0.10.0rc0 through 1.0.1)
12 | - [Numpy](http://www.numpy.org/)
13 |
14 | To download Wikipedia articles (`distant_supervision.py`)
15 |
16 | - [Beautifulsoup](https://www.crummy.com/software/BeautifulSoup/bs4/doc/)
17 | - [Pandas](http://pandas.pydata.org/)
18 | - [Stanford NER](http://nlp.stanford.edu/software/CRF-NER.shtml)
19 | *The path to Stanford NER is specified by the `ner_path` variable in `distant_supervision.py`.*
20 |
21 | To visualize the results (`visualize.ipynb`)
22 |
23 | - [Matplotlib](https://matplotlib.org/)
24 | - [Scikit-learn](http://scikit-learn.org/)
25 |
26 |
27 | ## Data
28 | - `data` directory includes preprocessed data:
29 | ```
30 | cnn-re-tf
31 | ├── ...
32 | ├── word2vec
33 | └── data
34 |     ├── er               # binary-classification dataset
35 |     │   ├── source.txt   # source sentences
36 |     │   └── target.txt   # target labels
37 |     └── mlmi             # multi-label multi-instance dataset
38 |         ├── source.att     # attention
39 |         ├── source.left    # left context
40 |         ├── source.middle  # middle context
41 |         ├── source.right   # right context
42 |         ├── source.txt     # source sentences
43 |         └── target.txt     # target labels
44 | ```
45 | To reproduce:
46 | ```sh
47 | python ./distant_supervision.py
48 | ```
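Each `source.txt` appears to be line-aligned with its `target.txt` (one sentence per line, labels on the corresponding line). A minimal sketch for eyeballing a few pairs under that assumption, using the paths shown in the tree above:

```python
# Sketch only: print the first few (label, sentence) pairs of the er dataset.
# Assumes source.txt and target.txt are line-aligned, as the tree above suggests.
from codecs import open
from itertools import islice

with open('data/er/source.txt', encoding='utf-8') as src, \
     open('data/er/target.txt', encoding='utf-8') as tgt:
    for sent, label in islice(zip(src, tgt), 5):
        print(label.strip() + '\t' + sent.strip())
```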
49 |
50 | - `word2vec` directory is empty. Please download the pretrained Google News vectors from
51 |   [this Google Drive link](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit)
52 |   and unzip the archive into this directory; the result is a single `.bin` file.
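`util.py` presumably reads this `.bin` file when building `emb.npy` (used with `--use_pretrain=True`). If you want to sanity-check the download on its own, one option is to load it with gensim; note that gensim is not listed in this repo's requirements and the filename below is just the archive's usual name, so treat this purely as an optional, standalone check:

```python
# Sketch only: verify the downloaded Google News vectors load and have 300 dimensions.
# gensim is not a requirement of this repo; this is an optional check outside its pipeline.
from gensim.models import KeyedVectors

w2v = KeyedVectors.load_word2vec_format(
    'word2vec/GoogleNews-vectors-negative300.bin', binary=True)
print(w2v['relation'].shape)  # expected: (300,)
```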
53 |
54 |
55 |
56 | ## Usage
57 | ### Preprocess
58 |
59 | ```sh
60 | python ./util.py
61 | ```
62 | This creates the `vocab.txt`, `ids.txt`, and `emb.npy` files in each dataset directory.
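The exact file formats are defined in `util.py`, but assuming the usual layout (one token per line in `vocab.txt`, one id sequence per line in `ids.txt`, and an `emb.npy` matrix with one row per vocabulary entry), a quick consistency check looks like:

```python
# Sketch only: rough consistency check of the preprocessing artifacts for the mlmi set.
# Assumes one token per line in vocab.txt and a (vocab_size, emb_size) matrix in emb.npy;
# check util.py for the actual file formats.
import numpy as np
from codecs import open

with open('data/mlmi/vocab.txt', encoding='utf-8') as f:
    vocab_size = sum(1 for _ in f)

emb = np.load('data/mlmi/emb.npy')
print('vocab entries: %d' % vocab_size)
print('embedding matrix: %s' % (emb.shape,))  # expected: (vocab_size, 300)
```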
63 |
64 | ### Training
65 |
66 | - Binary classification (ER-CNN):
67 | ```sh
68 | python ./train.py --sent_len=3 --vocab_size=11208 --num_classes=2 --train_size=15000 \
69 | --data_dir=./data/er --attention=False --multi_label=False --use_pretrain=False
70 | ```
71 |
72 | - Multi-label multi-instance learning (MLMI-CNN):
73 | ```sh
74 | python ./train.py --sent_len=255 --vocab_size=36112 --num_classes=23 --train_size=10000 \
75 | --data_dir=./data/mlmi --attention=True --multi_label=True --use_pretrain=True
76 | ```
77 |
78 | - Multi-label multi-instance Context-wise learning (MLMI-CONT):
79 | ```sh
80 | python ./train_context.py --sent_len=102 --vocab_size=36112 --num_classes=23 --train_size=10000 \
81 | --data_dir=./data/mlmi --attention=True --multi_label=True --use_pretrain=True
82 | ```
83 |
84 | **Caution:** Incorrect values for the input-data-dependent options (`sent_len`, `vocab_size`, and `num_classes`)
85 | may cause errors. If you train the model on another dataset, check these values against your data; a sketch for deriving them is shown below.
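A minimal sketch for estimating these values from a dataset directory, assuming whitespace-tokenized `source.txt`, one token per line in `vocab.txt` (created by the preprocessing step), and one whitespace-separated label vector per line in `target.txt` (these formats are assumptions; verify them against `util.py` before relying on the numbers):

```python
# Sketch only: estimate sent_len, vocab_size and num_classes for a dataset directory.
# Assumes whitespace-tokenized source.txt, one token per line in vocab.txt, and
# whitespace-separated label columns in target.txt; verify against util.py.
from codecs import open

def data_stats(data_dir):
    with open(data_dir + '/source.txt', encoding='utf-8') as f:
        sent_len = max(len(line.split()) for line in f)
    with open(data_dir + '/vocab.txt', encoding='utf-8') as f:
        vocab_size = sum(1 for _ in f)
    with open(data_dir + '/target.txt', encoding='utf-8') as f:
        num_classes = len(f.readline().split())
    return sent_len, vocab_size, num_classes

print(data_stats('./data/mlmi'))  # should roughly match the mlmi values used above
```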
86 |
87 |
88 | ### Evaluation
89 |
90 | ```sh
91 | python ./eval.py --train_dir=./train/1473898241
92 | ```
93 | Replace `--train_dir` with the output directory produced by the training run.
94 |
95 |
96 | ### Run TensorBoard
97 |
98 | ```sh
99 | tensorboard --logdir=./train/1473898241
100 | ```
101 |
102 |
103 | ## Architecture
104 |
105 | 
106 |
107 |
108 | ## Results
109 |
110 | | | P | R | F | AUC |init_lr|l2_reg|
111 | |--------:|:----:|:----:|:----:|:----:|------:|-----:|
112 | | ER-CNN |0.9410|0.8630|0.9003|0.9303| 0.005| 0.05|
113 | | MLMI-CNN|0.8205|0.6406|0.7195|0.7424| 1e-3| 1e-4|
114 | |MLMI-CONT|0.8819|0.7158|0.7902|0.8156| 1e-3| 1e-4|
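Here P, R and F are precision, recall and F1 (the harmonic mean, F = 2PR / (P + R); e.g. for ER-CNN, 2 · 0.9410 · 0.8630 / (0.9410 + 0.8630) ≈ 0.9003), and AUC is the area under the precision-recall curve shown below.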
115 |
116 | 
117 | 
118 | 
119 | 
120 | 
121 | 
122 | 
123 | 
124 |
125 | *As shown above, these models suffer somewhat from overfitting ...*
126 |
127 |
128 | ## References
129 |
130 | * http://github.com/yuhaozhang/sentence-convnet
131 | * http://github.com/dennybritz/cnn-text-classification-tf
132 | * http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/
133 | * http://tkengo.github.io/blog/2016/03/14/text-classification-by-cnn/
134 | * Adel et al. [Comparing Convolutional Neural Networks to Traditional Models for Slot Filling](http://arxiv.org/abs/1603.05157) NAACL 2016
135 | * Nguyen and Grishman. [Relation Extraction: Perspective from Convolutional Neural Networks](http://www.cs.nyu.edu/~thien/pubs/vector15.pdf) NAACL 2015
136 | * Lin et al. [Neural Relation Extraction with Selective Attention over Instances](http://www.aclweb.org/anthology/P/P16/P16-1200.pdf) ACL 2016
137 |
--------------------------------------------------------------------------------
/cnn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##########################################################
4 | #
5 | # Attention-based Convolutional Neural Network
6 | # for Multi-label Multi-instance Learning
7 | #
8 | #
9 | # Note: this implementation is mostly based on
10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/model.py
11 | #
12 | ##########################################################
13 |
14 | import tensorflow as tf
15 |
16 | # model parameters
17 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Training batch size')
18 | tf.app.flags.DEFINE_integer('emb_size', 300, 'Size of word embeddings')
19 | tf.app.flags.DEFINE_integer('num_kernel', 100, 'Number of filters for each window size')
20 | tf.app.flags.DEFINE_integer('min_window', 3, 'Minimum size of filter window')
21 | tf.app.flags.DEFINE_integer('max_window', 5, 'Maximum size of filter window')
22 | tf.app.flags.DEFINE_integer('vocab_size', 40000, 'Vocabulary size')
23 | tf.app.flags.DEFINE_integer('num_classes', 10, 'Number of classes to consider')
24 | tf.app.flags.DEFINE_integer('sent_len', 400, 'Input sentence length.')
25 | tf.app.flags.DEFINE_float('l2_reg', 1e-4, 'l2 regularization weight')
26 | tf.app.flags.DEFINE_boolean('attention', False, 'Whether to use attention or not')
27 | tf.app.flags.DEFINE_boolean('multi_label', False, 'Multilabel or not')
28 |
29 |
30 | def _variable_on_cpu(name, shape, initializer):
31 | with tf.device('/cpu:0'):
32 | var = tf.get_variable(name, shape, initializer=initializer)
33 | return var
34 |
35 | def _variable_with_weight_decay(name, shape, initializer, wd):
36 | var = _variable_on_cpu(name, shape, initializer)
37 | if wd is not None and wd != 0.:
38 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
39 | else:
40 | weight_decay = tf.constant(0.0, dtype=tf.float32)
41 | return var, weight_decay
42 |
43 |
44 | def _auc_pr(true, prob, threshold):
45 | pred = tf.where(prob > threshold, tf.ones_like(prob), tf.zeros_like(prob))
46 | tp = tf.logical_and(tf.cast(pred, tf.bool), tf.cast(true, tf.bool))
47 | fp = tf.logical_and(tf.cast(pred, tf.bool), tf.logical_not(tf.cast(true, tf.bool)))
48 | fn = tf.logical_and(tf.logical_not(tf.cast(pred, tf.bool)), tf.cast(true, tf.bool))
49 | pre = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), tf.reduce_sum(tf.cast(tf.logical_or(tp, fp), tf.int32)))
50 | rec = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), tf.reduce_sum(tf.cast(tf.logical_or(tp, fn), tf.int32)))
51 | return pre, rec
52 |
53 |
54 | class Model(object):
55 |
56 | def __init__(self, config, is_train=True):
57 | self.is_train = is_train
58 | self.emb_size = config['emb_size']
59 | self.batch_size = config['batch_size']
60 | self.num_kernel = config['num_kernel']
61 | self.min_window = config['min_window']
62 | self.max_window = config['max_window']
63 | self.vocab_size = config['vocab_size']
64 | self.num_classes = config['num_classes']
65 | self.sent_len = config['sent_len']
66 | self.l2_reg = config['l2_reg']
67 | self.multi_instance = config['attention']
68 | self.multi_label = config['multi_label']
69 | if is_train:
70 | self.optimizer = config['optimizer']
71 | self.dropout = config['dropout']
72 | self.build_graph()
73 |
74 | def build_graph(self):
75 | """ Build the computation graph. """
76 | self._inputs = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_x')
77 | self._labels = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes], name='input_y')
78 | self._attention = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='attention')
79 | losses = []
80 |
81 | # lookup layer
82 | with tf.variable_scope('embedding') as scope:
83 | self._W_emb = _variable_on_cpu(name='embedding', shape=[self.vocab_size, self.emb_size],
84 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
85 | # sent_batch is of shape: (batch_size, sent_len, emb_size, 1), in order to use conv2d
86 | sent_batch = tf.nn.embedding_lookup(params=self._W_emb, ids=self._inputs)
87 | sent_batch = tf.expand_dims(sent_batch, -1)
88 |
89 | # conv + pooling layer
90 | pool_tensors = []
91 | for k_size in range(self.min_window, self.max_window+1):
92 | with tf.variable_scope('conv-%d' % k_size) as scope:
93 | kernel, wd = _variable_with_weight_decay(
94 | name='kernel-%d' % k_size,
95 | shape=[k_size, self.emb_size, 1, self.num_kernel],
96 | initializer=tf.truncated_normal_initializer(stddev=0.01),
97 | wd=self.l2_reg)
98 | losses.append(wd)
99 | conv = tf.nn.conv2d(input=sent_batch, filter=kernel, strides=[1,1,1,1], padding='VALID')
100 | biases = _variable_on_cpu(name='bias-%d' % k_size,
101 | shape=[self.num_kernel],
102 | initializer=tf.constant_initializer(0.0))
103 | bias = tf.nn.bias_add(conv, biases)
104 | activation = tf.nn.relu(bias, name=scope.name)
105 | # shape of activation: [batch_size, conv_len, 1, num_kernel]
106 | conv_len = activation.get_shape()[1]
107 | pool = tf.nn.max_pool(activation, ksize=[1,conv_len,1,1], strides=[1,1,1,1], padding='VALID')
108 | # shape of pool: [batch_size, 1, 1, num_kernel]
109 | pool_tensors.append(pool)
110 |
111 | # Combine all pooled tensors
112 | num_filters = self.max_window - self.min_window + 1
113 | pool_size = num_filters * self.num_kernel
114 |         pool_layer = tf.concat(pool_tensors, 3, name='pool')  # concatenate pooled features along the channel axis
115 | pool_flat = tf.reshape(pool_layer, [-1, pool_size])
116 |
117 | # drop out layer
118 | if self.is_train and self.dropout > 0:
119 | pool_dropout = tf.nn.dropout(pool_flat, 1 - self.dropout)
120 | else:
121 | pool_dropout = pool_flat
122 |
123 | # fully-connected layer
124 | with tf.variable_scope('output') as scope:
125 | W, wd = _variable_with_weight_decay('W', shape=[pool_size, self.num_classes],
126 | initializer=tf.truncated_normal_initializer(stddev=0.05),
127 | wd=self.l2_reg)
128 | losses.append(wd)
129 | biases = _variable_on_cpu('bias', shape=[self.num_classes],
130 | initializer=tf.constant_initializer(0.01))
131 | self.logits = tf.nn.bias_add(tf.matmul(pool_dropout, W), biases, name='logits')
132 |
133 | # loss
134 | with tf.variable_scope('loss') as scope:
135 | if self.multi_label:
136 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self._labels,
137 | name='cross_entropy_per_example')
138 | else:
139 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self._labels,
140 | name='cross_entropy_per_example')
141 |
142 | if self.is_train and self.multi_instance: # apply attention
143 | cross_entropy_loss = tf.reduce_sum(tf.multiply(cross_entropy, self._attention),
144 | name='cross_entropy_loss')
145 | else:
146 | cross_entropy_loss = tf.reduce_mean(cross_entropy, name='cross_entropy_loss')
147 |
148 | losses.append(cross_entropy_loss)
149 | self._total_loss = tf.add_n(losses, name='total_loss')
150 |
151 | # eval with precision-recall
152 | with tf.variable_scope('evaluation') as scope:
153 | precision = []
154 | recall = []
155 | for threshold in range(10, -1, -1):
156 | pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1)
157 | precision.append(pre)
158 | recall.append(rec)
159 | self._eval_op = zip(precision, recall)
160 |
161 | # f1 score on threshold=0.5
162 | #self._f1_score = tf.truediv(tf.mul(tf.constant(2.0, dtype=tf.float64),
163 | # tf.mul(precision[5], recall[5])), tf.add(precision, recall))
164 |
165 | # train on a batch
166 | self._lr = tf.Variable(0.0, trainable=False)
167 | if self.is_train:
168 | if self.optimizer == 'adadelta':
169 | opt = tf.train.AdadeltaOptimizer(self._lr)
170 | elif self.optimizer == 'adagrad':
171 | opt = tf.train.AdagradOptimizer(self._lr)
172 | elif self.optimizer == 'adam':
173 | opt = tf.train.AdamOptimizer(self._lr)
174 | elif self.optimizer == 'sgd':
175 | opt = tf.train.GradientDescentOptimizer(self._lr)
176 | else:
177 | raise ValueError("Optimizer not supported.")
178 | grads = opt.compute_gradients(self._total_loss)
179 | self._train_op = opt.apply_gradients(grads)
180 |
181 | for var in tf.trainable_variables():
182 | tf.summary.histogram(var.op.name, var)
183 | else:
184 | self._train_op = tf.no_op()
185 |
186 | return
187 |
188 | @property
189 | def inputs(self):
190 | return self._inputs
191 |
192 | @property
193 | def labels(self):
194 | return self._labels
195 |
196 | @property
197 | def attention(self):
198 | return self._attention
199 |
200 | @property
201 | def lr(self):
202 | return self._lr
203 |
204 | @property
205 | def train_op(self):
206 | return self._train_op
207 |
208 | @property
209 | def total_loss(self):
210 | return self._total_loss
211 |
212 | @property
213 | def eval_op(self):
214 | return self._eval_op
215 |
216 | @property
217 | def scores(self):
218 | return tf.sigmoid(self.logits)
219 |
220 | @property
221 | def W_emb(self):
222 | return self._W_emb
223 |
224 | def assign_lr(self, session, lr_value):
225 | session.run(tf.assign(self.lr, lr_value))
226 |
227 | def assign_embedding(self, session, pretrained):
228 | session.run(tf.assign(self.W_emb, pretrained))
229 |
--------------------------------------------------------------------------------
/cnn_context.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##########################################################
4 | #
5 | # Attention-based Convolutional Neural Network
6 | # for Context-wise Learning
7 | #
8 | #
9 | # Note: this implementation is mostly based on
10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/model.py
11 | #
12 | ##########################################################
13 |
14 | import tensorflow as tf
15 |
16 | # model parameters
17 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Training batch size')
18 | tf.app.flags.DEFINE_integer('emb_size', 300, 'Size of word embeddings')
19 | tf.app.flags.DEFINE_integer('num_kernel', 100, 'Number of filters for each window size')
20 | tf.app.flags.DEFINE_integer('min_window', 3, 'Minimum size of filter window')
21 | tf.app.flags.DEFINE_integer('max_window', 5, 'Maximum size of filter window')
22 | tf.app.flags.DEFINE_integer('vocab_size', 40000, 'Vocabulary size')
23 | tf.app.flags.DEFINE_integer('num_classes', 10, 'Number of classes to consider')
24 | tf.app.flags.DEFINE_integer('sent_len', 400, 'Input sentence length.')
25 | tf.app.flags.DEFINE_float('l2_reg', 1e-4, 'l2 regularization weight')
26 | tf.app.flags.DEFINE_boolean('attention', False, 'Whether to use attention or not')
27 | tf.app.flags.DEFINE_boolean('multi_label', False, 'Multilabel or not')
28 |
29 |
30 | def _variable_on_cpu(name, shape, initializer):
31 | with tf.device('/cpu:0'):
32 | var = tf.get_variable(name, shape, initializer=initializer)
33 | return var
34 |
35 |
36 | def _variable_with_weight_decay(name, shape, initializer, wd):
37 | var = _variable_on_cpu(name, shape, initializer)
38 | if wd is not None and wd != 0.:
39 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
40 | else:
41 | weight_decay = tf.constant(0.0, dtype=tf.float32)
42 | return var, weight_decay
43 |
44 |
45 | def _auc_pr(true, prob, threshold):
46 | pred = tf.where(prob > threshold, tf.ones_like(prob), tf.zeros_like(prob))
47 | tp = tf.logical_and(tf.cast(pred, tf.bool), tf.cast(true, tf.bool))
48 | fp = tf.logical_and(tf.cast(pred, tf.bool), tf.logical_not(tf.cast(true, tf.bool)))
49 | fn = tf.logical_and(tf.logical_not(tf.cast(pred, tf.bool)), tf.cast(true, tf.bool))
50 | pre = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)),
51 | tf.reduce_sum(tf.cast(tf.logical_or(tp, fp), tf.int32)))
52 | rec = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)),
53 | tf.reduce_sum(tf.cast(tf.logical_or(tp, fn), tf.int32)))
54 | return pre, rec
55 |
56 |
57 | class Model(object):
58 |
59 | def __init__(self, config, is_train=True):
60 | self.is_train = is_train
61 | self.emb_size = config['emb_size']
62 | self.batch_size = config['batch_size']
63 | self.num_kernel = config['num_kernel']
64 | self.min_window = config['min_window']
65 | self.max_window = config['max_window']
66 | self.vocab_size = config['vocab_size']
67 | self.num_classes = config['num_classes']
68 | self.sent_len = config['sent_len']
69 | self.l2_reg = config['l2_reg']
70 | self.multi_instance = config['attention']
71 | self.multi_label = config['multi_label']
72 | if is_train:
73 | self.optimizer = config['optimizer']
74 | self.dropout = config['dropout']
75 | self.build_graph()
76 |
77 | def conv_layer(self, input, context):
78 | pool_tensors = []
79 | losses = []
80 | for k_size in range(self.min_window, self.max_window+1):
81 | with tf.variable_scope('conv-%d-%s' % (k_size, context)) as scope:
82 | kernel, wd = _variable_with_weight_decay(
83 | name='kernel-%d-%s' % (k_size, context),
84 | shape=[k_size, self.emb_size, 1, self.num_kernel],
85 | initializer=tf.truncated_normal_initializer(stddev=0.01),
86 | wd=self.l2_reg)
87 | losses.append(wd)
88 | conv = tf.nn.conv2d(input=input, filter=kernel, strides=[1,1,1,1], padding='VALID')
89 | biases = _variable_on_cpu('bias-%d-%s' % (k_size, context),
90 | [self.num_kernel], tf.constant_initializer(0.0))
91 | bias = tf.nn.bias_add(conv, biases)
92 | activation = tf.nn.relu(bias, name=scope.name)
93 | # shape of activation: [batch_size, conv_len, 1, num_kernel]
94 | conv_len = activation.get_shape()[1]
95 | pool = tf.nn.max_pool(activation, ksize=[1,conv_len,1,1], strides=[1,1,1,1], padding='VALID')
96 | # shape of pool: [batch_size, 1, 1, num_kernel]
97 | pool_tensors.append(pool)
98 |
99 | # Combine pooled tensors
100 | num_filters = self.max_window - self.min_window + 1
101 | pool_size = num_filters * self.num_kernel # 300
102 |         pool_layer = tf.concat(pool_tensors, 3, name='pool-%s' % context)  # concatenate pooled features along the channel axis
103 | pool_flat = tf.reshape(pool_layer, [-1, pool_size])
104 |
105 | return losses, pool_flat
106 |
107 | def build_graph(self):
108 | """ Build the computation graph. """
109 | self._left = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_left')
110 | self._middle = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_middle')
111 | self._right = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_right')
112 | self._labels = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes], name='input_y')
113 | self._attention = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='attention')
114 | losses = []
115 |
116 | with tf.variable_scope('embedding-left') as scope:
117 | self._W_emb_left = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size],
118 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
119 | sent_batch_left = tf.nn.embedding_lookup(params=self._W_emb_left, ids=self._left)
120 | input_left = tf.expand_dims(sent_batch_left, -1)
121 |
122 | with tf.variable_scope('embedding-middle') as scope:
123 | self._W_emb_middle = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size],
124 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
125 | sent_batch_middle = tf.nn.embedding_lookup(params=self._W_emb_middle, ids=self._middle)
126 | input_middle = tf.expand_dims(sent_batch_middle, -1)
127 |
128 | with tf.variable_scope('embedding-right') as scope:
129 | self._W_emb_right = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size],
130 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
131 | sent_batch_right = tf.nn.embedding_lookup(params=self._W_emb_right, ids=self._right)
132 | input_right = tf.expand_dims(sent_batch_right, -1)
133 |
134 | # conv + pooling layer
135 | contexts = []
136 | for contextwise_input, context in zip([input_left, input_middle, input_right],
137 | ['left', 'middle', 'right']):
138 | conv_losses, pool_flat = self.conv_layer(contextwise_input, context)
139 | losses.extend(conv_losses)
140 | contexts.append(pool_flat)
141 | # Combine context tensors
142 | num_filters = self.max_window - self.min_window + 1
143 | pool_size = num_filters * self.num_kernel # 300
144 | concat_context = tf.concat(contexts, 1, name='concat')
145 | flat_context = tf.reshape(concat_context, [-1, pool_size*3])
146 |
147 | # drop out layer
148 | if self.is_train and self.dropout > 0:
149 | pool_dropout = tf.nn.dropout(flat_context, 1 - self.dropout)
150 | else:
151 | pool_dropout = flat_context
152 |
153 | # fully-connected layer
154 | with tf.variable_scope('output') as scope:
155 | W, wd = _variable_with_weight_decay('W', shape=[pool_size*3, self.num_classes],
156 | initializer=tf.truncated_normal_initializer(stddev=0.05), wd=self.l2_reg)
157 | losses.append(wd)
158 | biases = _variable_on_cpu('bias', shape=[self.num_classes],
159 | initializer=tf.constant_initializer(0.01))
160 | self.logits = tf.nn.bias_add(tf.matmul(pool_dropout, W), biases, name='logits')
161 |
162 | # loss
163 | with tf.variable_scope('loss') as scope:
164 | if self.multi_label:
165 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self._labels,
166 | name='cross_entropy_per_example')
167 | else:
168 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self._labels,
169 | name='cross_entropy_per_example')
170 |
171 | if self.is_train and self.multi_instance: # apply attention
172 | cross_entropy_loss = tf.reduce_sum(tf.multiply(cross_entropy, self._attention),
173 | name='cross_entropy_loss')
174 | else:
175 | cross_entropy_loss = tf.reduce_mean(cross_entropy, name='cross_entropy_loss')
176 |
177 | losses.append(cross_entropy_loss)
178 | self._total_loss = tf.add_n(losses, name='total_loss')
179 |
180 | # eval with auc-pr metric
181 | with tf.variable_scope('evaluation') as scope:
182 | precision = []
183 | recall = []
184 | for threshold in range(10, -1, -1):
185 | pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1)
186 | precision.append(pre)
187 | recall.append(rec)
188 | self._eval_op = zip(precision, recall)
189 |
190 | # train on a batch
191 | self._lr = tf.Variable(0.0, trainable=False)
192 | if self.is_train:
193 | if self.optimizer == 'adadelta':
194 | opt = tf.train.AdadeltaOptimizer(self._lr)
195 | elif self.optimizer == 'adagrad':
196 | opt = tf.train.AdagradOptimizer(self._lr)
197 | elif self.optimizer == 'adam':
198 | opt = tf.train.AdamOptimizer(self._lr)
199 | elif self.optimizer == 'sgd':
200 | opt = tf.train.GradientDescentOptimizer(self._lr)
201 | else:
202 | raise ValueError("Optimizer not supported.")
203 | grads = opt.compute_gradients(self._total_loss)
204 | self._train_op = opt.apply_gradients(grads)
205 |
206 | for var in tf.trainable_variables():
207 | tf.summary.histogram(var.op.name, var)
208 | else:
209 | self._train_op = tf.no_op()
210 |
211 | return
212 |
213 | @property
214 | def left(self):
215 | return self._left
216 |
217 | @property
218 | def middle(self):
219 | return self._middle
220 |
221 | @property
222 | def right(self):
223 | return self._right
224 |
225 | @property
226 | def labels(self):
227 | return self._labels
228 |
229 | @property
230 | def attention(self):
231 | return self._attention
232 |
233 | @property
234 | def lr(self):
235 | return self._lr
236 |
237 | @property
238 | def train_op(self):
239 | return self._train_op
240 |
241 | @property
242 | def total_loss(self):
243 | return self._total_loss
244 |
245 | @property
246 | def eval_op(self):
247 | return self._eval_op
248 |
249 | @property
250 | def scores(self):
251 | return tf.sigmoid(self.logits)
252 |
253 | @property
254 | def W_emb_left(self):
255 | return self._W_emb_left
256 |
257 | @property
258 | def W_emb_middle(self):
259 | return self._W_emb_middle
260 |
261 | @property
262 | def W_emb_right(self):
263 | return self._W_emb_right
264 |
265 | def assign_lr(self, session, lr_value):
266 | session.run(tf.assign(self.lr, lr_value))
267 |
268 | def assign_embedding(self, session, pretrained):
269 | session.run(tf.assign(self.W_emb_left, pretrained))
270 | session.run(tf.assign(self.W_emb_middle, pretrained))
271 | session.run(tf.assign(self.W_emb_right, pretrained))
272 |
--------------------------------------------------------------------------------
/data/gold_patterns.tsv:
--------------------------------------------------------------------------------
1 | # Original taken from
2 | # https://github.com/beroth/relationfactory/blob/master/resources/manual_annotation/context_patterns2012.txt
3 |
4 |
5 | #per:alternate_names $ARG1 formerly known as $ARG2
6 | #per:alternate_names $ARG1 aka $ARG2
7 | #per:alternate_names $ARG1 a.k.a. $ARG2
8 | #per:alternate_names $ARG1 a. k. a. $ARG2
9 | #per:alternate_names $ARG1 is also known as $ARG2
10 | #per:alternate_names $ARG1 is known as $ARG2
11 | #per:alternate_names $ARG1 , also known as $ARG2
12 | #per:alternate_names $ARG1 , better known as $ARG2
13 | #per:alternate_names $ARG1 , best known as $ARG2
14 | #per:alternate_names $ARG1 , aka $ARG2
15 | #per:alternate_names $ARG1 , a.k.a. $ARG2
16 | #per:alternate_names $ARG1 , a. k. a. $ARG2
17 | #per:alternate_names $ARG1 ( also known as $ARG2 )
18 | #per:alternate_names $ARG1 ( better known as $ARG2 )
19 | # This overgeneralizes too much (e.g. actors playing roles).
20 | #per:alternate_names $ARG1 ( $ARG2 )
21 | #per:alternate_names $ARG1 ( aka $ARG2 )
22 | #per:alternate_names $ARG1 ( a.k.a. $ARG2 )
23 | #per:alternate_names $ARG1 ( a. k. a. $ARG2 )
24 |
25 | #org:alternate_names $ARG1 formerly known as $ARG2
26 | #org:alternate_names $ARG1 aka $ARG2
27 | #org:alternate_names $ARG1 a.k.a. $ARG2
28 | #org:alternate_names $ARG1 is also known as $ARG2
29 | #org:alternate_names $ARG1 is known as $ARG2
30 | #org:alternate_names $ARG1 , also known as $ARG2
31 | #org:alternate_names $ARG1 , better known as $ARG2
32 | #org:alternate_names $ARG1 , best known as $ARG2
33 | #org:alternate_names $ARG1 ( also known as $ARG2 )
34 | #org:alternate_names $ARG1 ( better known as $ARG2 )
35 | #org:alternate_names $ARG1 formerly known as `` $ARG2
36 | #org:alternate_names $ARG1 aka `` $ARG2
37 | #org:alternate_names $ARG1 a.k.a. `` $ARG2
38 | #org:alternate_names $ARG1 is also known as `` $ARG2
39 | #org:alternate_names $ARG1 is known as `` $ARG2
40 | #org:alternate_names $ARG1 , also known as `` $ARG2
41 | #org:alternate_names $ARG1 , better known as `` $ARG2
42 | #org:alternate_names $ARG1 , best known as `` $ARG2
43 | #org:alternate_names $ARG1 ( also known as `` $ARG2 '' )
44 | #org:alternate_names $ARG1 ( better known as `` $ARG2 '' )
45 |
46 | # per:country_of_birth -> P19
47 | P19 $ARG1 , the $ARG2-born
48 | P19 $ARG1 , $ARG2-born
49 | P19 $ARG1 was born in $ARG2
50 | P19 $ARG1 is born in $ARG2
51 | P19 $ARG1 , born in $ARG2
52 | P19 $ARG1 , being born in $ARG2
53 | P19 $ARG1 is a native $ARG2
54 | P19 $ARG1 , a native $ARG2
55 | P19 $ARG1 was born * in $ARG2
56 | P19 $ARG1 was born * in * , $ARG2
57 | P19 $ARG1 , born in * , $ARG2
58 | P19 $ARG1 , born * in * , $ARG2
59 | P19 $ARG1 is a native of $ARG2
60 | P19 $ARG1 , a native of $ARG2
61 | P19 $ARG2 born $ARG1
62 | P19 $ARG2 -born $ARG1
63 | P19 $ARG2 - born $ARG1
64 |
65 | # per:origin -> P27 country of citizenship
66 | P27 $ARG1 , a $ARG2 citizen
67 | P27 $ARG1 , an $ARG2 citizen
68 | P27 $ARG1 , who is a $ARG2 citizen
69 | P27 $ARG1 , who is an $ARG2 citizen
70 | P27 $ARG1 is a $ARG2 citizen
71 | P27 $ARG1 , originally from $ARG2
72 | P27 $ARG1 from $ARG2
73 | P27 $ARG2 citizen $ARG1
74 | P27 $ARG1 , born in $ARG2
75 | P27 $ARG1 was born in $ARG2
76 |
77 | # per:country_of_death -> P20
78 | P20 $ARG1 died in $ARG2
79 | P20 $ARG1 was killed in $ARG2
80 | P20 $ARG1 succumbed to * in $ARG2
81 | P20 $ARG1 passed away $ARG2
82 | P20 $ARG1 who died in $ARG2
83 | P20 $ARG1 who was killed in $ARG2
84 | P20 $ARG1 who succumbed to * in $ARG2
85 | P20 $ARG1 who passed away $ARG2
86 | P20 $ARG1 , who died in $ARG2
87 | P20 $ARG1 , who was killed in $ARG2
88 | P20 $ARG1 , who succumbed to * in $ARG2
89 | P20 $ARG1 , who passed away $ARG2
90 | P20 $ARG1 died * in $ARG2
91 | P20 $ARG1 , who died * in $ARG2
92 | P20 $ARG1 died * in * , $ARG2
93 | P20 $ARG1 , who died * in * , $ARG2
94 |
95 |
96 | # per:countries_of_residence -> P551
97 | # TODO: include frequent 'born-in' patterns
98 | P551 $ARG1 grew up in $ARG2
99 | P551 $ARG1 grew up in * , $ARG2
100 | P551 $ARG1 lives in $ARG2
101 | P551 $ARG1 lived in $ARG2
102 | P551 $ARG1 lives in * $ARG2
103 | P551 $ARG1 lived in * $ARG2
104 | P551 $ARG1 lives in * , $ARG2
105 | P551 $ARG1 lived in * , $ARG2
106 | P551 $ARG1 moves to $ARG2
107 | P551 $ARG1 moved to $ARG2
108 | P551 $ARG1 moves to * $ARG2
109 | P551 $ARG1 moved to * $ARG2
110 | P551 $ARG1 moves to * , $ARG2
111 | P551 $ARG1 moved to * , $ARG2
112 | P551 $ARG1 was raised in $ARG2
113 | P551 $ARG1 has been living in $ARG2
114 | P551 $ARG1 resides in $ARG2
115 | P551 $ARG1 resided in $ARG2
116 | P551 $ARG1 's house in $ARG2
117 | P551 $ARG1 's home in $ARG2
118 | P551 $ARG1 has a house in $ARG2
119 | P551 $ARG1 immigrated to $ARG2
120 | P551 $ARG1 spent his childhood in $ARG2
121 | P551 $ARG1 , born in $ARG2
122 | P551 $ARG1 was born in $ARG2
123 |
124 | P551 $ARG2 citizen $ARG1
125 | P551 $ARG2 , home of $ARG1
126 | P551 $ARG2 , hometown of $ARG1
127 | P551 $ARG2 is the hometown of $ARG1
128 | P551 $ARG2 citizen , $ARG1
129 | P551 $ARG1 is a $ARG2 citizen
130 | P551 $ARG1 is an $ARG2 citizen
131 | P551 $ARG1, a $ARG2 citizen
132 | P551 $ARG1, an $ARG2 citizen
133 |
134 | # capital of -> P1376
135 | P1376 $ARG1 , the capital of $ARG2
136 | P1376 $ARG1 is the capital of $ARG2
137 | P1376 $ARG1 was the capital of $ARG2
138 | P1376 $ARG1 , the * capital of $ARG2
139 | P1376 $ARG1 is the * capital of $ARG2
140 | P1376 $ARG1 was the * capital of $ARG2
141 | P1376 $ARG1 , the capital of * $ARG2
142 | P1376 $ARG1 is the capital of * $ARG2
143 | P1376 $ARG1 was the capital of $ARG2
144 | P1376 $ARG1 , the * capital of * $ARG2
145 | P1376 $ARG1 is the * capital of * $ARG2
146 | P1376 $ARG1 was the * capital of *$ARG2
147 |
148 | # country -> P17
149 | P17 $ARG1 * the state of $ARG2
150 | P17 $ARG1 * the state of * , $ARG2
151 | P17 $ARG1 * the state of * , * $ARG2
152 | P17 $ARG1 * the * state of * , $ARG2
153 | P17 $ARG1 * the * state of * , $ARG2
154 | P17 $ARG1 * the * state of * , * $ARG2
155 | P17 $ARG1 * the province of $ARG2
156 | P17 $ARG1 * the province of * , $ARG2
157 | P17 $ARG1 * the province of * , * $ARG2
158 | P17 $ARG1 * the * province of * , $ARG2
159 | P17 $ARG1 * the * province of * , $ARG2
160 | P17 $ARG1 * the * province of * , * $ARG2
161 | P17 $ARG1 * the prefecture of $ARG2
162 | P17 $ARG1 * the prefecture of * , $ARG2
163 | P17 $ARG1 * the prefecture of * , * $ARG2
164 | P17 $ARG1 * the * prefecture of * , $ARG2
165 | P17 $ARG1 * the * prefecture of * , $ARG2
166 | P17 $ARG1 * the * prefecture of * , * $ARG2
167 |
168 |
169 | # educated at P69
170 | P69 $ARG1 graduated from $ARG2
171 | P69 $ARG1 graduated in * from $ARG2
172 | P69 $ARG1 attended * at $ARG2
173 | P69 $ARG1 attended $ARG2
174 | P69 $ARG1 earned his * at $ARG2
175 | P69 $ARG1 holds a * from $ARG2
176 | P69 $ARG1 holds a * in * from $ARG2
177 | P69 $ARG1 , who holds a * in * from $ARG2
178 | P69 $ARG1 , who holds a * from $ARG2
179 | P69 $ARG1 has a * from $ARG2
180 | P69 $ARG1 has a * in * from $ARG2
181 | P69 $ARG1 , who has a * in * from $ARG2
182 | P69 $ARG1 , who has a * from $ARG2
183 | # This should match the sentence:
184 | #Erraguntla , who has a masters degree in computer engineering from the University of Lousiana
185 | P69 $ARG1 studied at $ARG2
186 | P69 $ARG1 is a student at $ARG2
187 | P69 $ARG1 was a student at $ARG2
188 | P69 $ARG1 is a $ARG2 student
189 | P69 $ARG1, a $ARG2 student
190 | P69 $ARG1 is an $ARG2 student
191 | P69 $ARG1, an $ARG2 student
192 | P69 $ARG1 is a $ARG2 graduate
193 | P69 $ARG1, a $ARG2 graduate
194 | P69 $ARG1 is an $ARG2 graduate
195 | P69 $ARG1, an $ARG2 graduate
196 | P69 $ARG1 is a $ARG2 alumnus
197 | P69 $ARG1, a $ARG2 alumnus
198 | P69 $ARG1 is an $ARG2 alumnus
199 | P69 $ARG1, an $ARG2 alumnus
200 | P69 $ARG2 graduate $ARG1
201 | P69 $ARG2 student $ARG1
202 | P69 $ARG2 alumnus $ARG1
203 |
204 |
205 | # job title
206 | #per:title $ARG1 was appointed $ARG2
207 | #per:title $ARG1 was nominated $ARG2
208 | #per:title $ARG1 was elected as $ARG2
209 | #per:title $ARG1 , a $ARG2
210 | #per:title $ARG1 , an $ARG2
211 | #per:title $ARG1 , the $ARG2
212 | #per:title $ARG1 , a former $ARG2
213 | #per:title $ARG1 , an a former $ARG2
214 | #per:title $ARG1 is a $ARG2
215 | #per:title $ARG1 is an $ARG2
216 | #per:title $ARG1 is the $ARG2
217 | #per:title $ARG2 works as a $ARG1
218 | #per:title $ARG2 works as an $ARG1
219 | #per:title $ARG1 was a $ARG2
220 | #per:title $ARG1 was an $ARG2
221 | #per:title $ARG2 worked as a $ARG1
222 | #per:title $ARG2 worked as an $ARG1
223 | # TODO: high-recall vs. high-precision patterns
224 | #per:title $ARG2 $ARG1
225 |
226 |
227 |
228 | # per:member_of -> P463, P54
229 | P463 $ARG1 is a member of $ARG2
230 | P463 $ARG2 member $ARG1
231 | P463 $ARG1 is a fellow of $ARG2
232 | P463 $ARG2 fellow $ARG1
233 | P463 $ARG1 became a member of $ARG2
234 | P463 $ARG1 became $ARG2 member
235 | P463 $ARG1 became a fellow of $ARG2
236 | P463 $ARG1 became $ARG2 fellow
237 | P463 $ARG1 joined $ARG2
238 | P463 $ARG1 is leaving $ARG2
239 | P463 $ARG1 left $ARG2
240 |
241 | P463 $ARG1 worked for $ARG2
242 | P463 $ARG1 works for $ARG2
243 | P463 $ARG1 has been working for $ARG2
244 | P463 $ARG1 was working for $ARG2
245 | P463 $ARG1 had working for $ARG2
246 | P463 $ARG1 was employed by $ARG2
247 | P463 $ARG1 is employed by $ARG2
248 | P463 $ARG1 , an $ARG2 employee
249 | P463 $ARG1 , a $ARG2 employee
250 | P463 $ARG1 was hired by $ARG2
251 | P463 $ARG1 has been hired by $ARG2
252 | P463 $ARG1 had been hired by $ARG2
253 | P463 $ARG2 manager $ARG1
254 | P463 $ARG2 is a manager of $ARG1
255 | P463 $ARG2 coach $ARG1
256 | P463 $ARG2 is the coach of $ARG1
257 | P463 $ARG1 served $ARG2
258 |
259 | # sports team P54
260 | P54 $ARG1 is a member of $ARG2
261 | P54 $ARG2 member $ARG1
262 | P54 $ARG1 is a fellow of $ARG2
263 | P54 $ARG2 fellow $ARG1
264 | P54 $ARG1 became a member of $ARG2
265 | P54 $ARG1 became $ARG2 member
266 | P54 $ARG1 became a fellow of $ARG2
267 | P54 $ARG1 became $ARG2 fellow
268 | P54 $ARG1 joined $ARG2
269 | P54 $ARG1 is leaving $ARG2
270 | P54 $ARG1 left $ARG2
271 |
272 | P54 $ARG1 worked for $ARG2
273 | P54 $ARG1 works for $ARG2
274 | P54 $ARG1 has been working for $ARG2
275 | P54 $ARG1 was working for $ARG2
276 | P54 $ARG1 had working for $ARG2
277 | P54 $ARG1 was employed by $ARG2
278 | P54 $ARG1 is employed by $ARG2
279 | P54 $ARG1 , an $ARG2 employee
280 | P54 $ARG1 , a $ARG2 employee
281 | P54 $ARG1 was hired by $ARG2
282 | P54 $ARG1 has been hired by $ARG2
283 | P54 $ARG1 had been hired by $ARG2
284 | P54 $ARG2 manager $ARG1
285 | P54 $ARG2 is a manager of $ARG1
286 | P54 $ARG2 coach $ARG1
287 | P54 $ARG2 is the coach of $ARG1
288 | P54 $ARG1 served $ARG2
289 |
290 |
291 | # per:employee_of -> P108
292 | P108 $ARG1 worked for $ARG2
293 | P108 $ARG1 works for $ARG2
294 | P108 $ARG1 has been working for $ARG2
295 | P108 $ARG1 was working for $ARG2
296 | P108 $ARG1 had working for $ARG2
297 | P108 $ARG1 was employed by $ARG2
298 | P108 $ARG1 is employed by $ARG2
299 | P108 $ARG1 , an $ARG2 employee
300 | P108 $ARG1 , a $ARG2 employee
301 | P108 $ARG1 was hired by $ARG2
302 | P108 $ARG1 has been hired by $ARG2
303 | P108 $ARG1 had been hired by $ARG2
304 | P108 $ARG2 manager $ARG1
305 | P108 $ARG2 is a manager of $ARG1
306 | P108 $ARG2 coach $ARG1
307 | P108 $ARG2 is the coach of $ARG1
308 | P108 $ARG1 joined $ARG2
309 | P108 $ARG1 is leaving $ARG2
310 | P108 $ARG1 left $ARG2
311 | P108 $ARG1 served $ARG2
312 |
313 | P108 $ARG1 is a member of $ARG2
314 | P108 $ARG2 member $ARG1
315 | P108 $ARG1 is a fellow of $ARG2
316 | P108 $ARG2 fellow $ARG1
317 | P108 $ARG1 became a member of $ARG2
318 | P108 $ARG1 became $ARG2 member
319 | P108 $ARG1 became a fellow of $ARG2
320 | P108 $ARG1 became $ARG2 fellow
321 | P108 $ARG1 joined $ARG2
322 | P108 $ARG1 is leaving $ARG2
323 | P108 $ARG1 left $ARG2
324 | P108 $ARG1 worked for $ARG2
325 | P108 $ARG1 works for $ARG2
326 | P108 $ARG1 has been working for $ARG2
327 | P108 $ARG1 was working for $ARG2
328 | P108 $ARG1 had working for $ARG2
329 | P108 $ARG1 was employed by $ARG2
330 | P108 $ARG1 is employed by $ARG2
331 | P108 $ARG1 , an $ARG2 employee
332 | P108 $ARG1 , a $ARG2 employee
333 | P108 $ARG1 was hired by $ARG2
334 | P108 $ARG1 has been hired by $ARG2
335 | P108 $ARG1 had been hired by $ARG2
336 | P108 $ARG2 manager $ARG1
337 | P108 $ARG2 is a manager of $ARG1
338 | P108 $ARG2 coach $ARG1
339 | P108 $ARG2 is the coach of $ARG1
340 | P108 $ARG1 joined $ARG2
341 | P108 $ARG1 is leaving $ARG2
342 | P108 $ARG1 left $ARG2
343 | P108 $ARG1 served $ARG2
344 |
345 |
346 | # per:spouse -> P26
347 | P26 $ARG0 and $ARG0 are married
348 | P26 $ARG0 is married to $ARG0
349 | P26 $ARG0 is the wife of $ARG0
350 | P26 $ARG0 's wife , $ARG0
351 | P26 $ARG0' wife , $ARG0
352 | P26 $ARG0 , $ARG0 's wife
353 | P26 $ARG0 , $ARG0' wife
354 | P26 $ARG0 , the wife of $ARG0
355 | P26 $ARG0 and his wife $ARG0
356 | P26 $ARG0 is the husband of $ARG0
357 | P26 $ARG0 's husband , $ARG0
358 | P26 $ARG0' husband , $ARG0
359 | P26 $ARG0 , $ARG0 's husband
360 | P26 $ARG0 , $ARG0' husband
361 | P26 $ARG0 , the husband of $ARG0
362 | P26 $ARG0 and her husband $ARG0
363 |
364 | P40 $ARG1 's child $ARG2
365 | P40 $ARG1 ' child $ARG2
366 | P40 $ARG2 is the child of $ARG1
367 | P40 $ARG1 has a child named $ARG2
368 | P40 $ARG1 's oldest child $ARG2
369 | P40 $ARG1 ' oldest child $ARG2
370 | P40 $ARG2 is the oldest child of $ARG1
371 | P40 $ARG1 's youngest child $ARG2
372 | P40 $ARG1 ' youngest child $ARG2
373 | P40 $ARG2 is the youngest child of $ARG1
374 | P40 $ARG1 's daughter $ARG2
375 | P40 $ARG1 ' daughter $ARG2
376 | P40 $ARG2 is the daughter of $ARG1
377 | P40 $ARG1 has a daughter named $ARG2
378 | P40 $ARG1 's son $ARG2
379 | P40 $ARG1 ' son $ARG2
380 | P40 $ARG2 is the son of $ARG1
381 | P40 $ARG1 has a son named $ARG2
382 | P40 $ARG1 is the father of $ARG2
383 | P40 $ARG2 's father , $ARG1
384 | P40 $ARG2 ' father , $ARG1
385 | P40 $ARG2 's father $ARG1
386 | P40 $ARG2 ' father $ARG1
387 | P40 $ARG1 is the mother of $ARG2
388 | P40 $ARG2 's mother , $ARG1
389 | P40 $ARG2 ' mother , $ARG1
390 | P40 $ARG2 's mother $ARG1
391 | P40 $ARG2 ' mother $ARG1
392 |
393 | P22 $ARG2 's child $ARG1
394 | P22 $ARG2 ' child $ARG1
395 | P22 $ARG1 is the child of $ARG2
396 | P22 $ARG2 has a child named $ARG1
397 | P22 $ARG2 's oldest child $ARG1
398 | P22 $ARG2 ' oldest child $ARG1
399 | P22 $ARG1 is the oldest child of $ARG2
400 | P22 $ARG2 ' youngest child $ARG1
401 | P22 $ARG1 is the youngest child of $ARG2
402 | P22 $ARG2 's daughter $ARG1
403 | P22 $ARG2 ' daughter $ARG1
404 | P22 $ARG1 is the daughter of $ARG2
405 | P22 $ARG2 has a daughter named $ARG1
406 | P22 $ARG2 's son $ARG1
407 | P22 $ARG2 ' son $ARG1
408 | P22 $ARG1 is the son of $ARG2
409 | P22 $ARG2 has a son named $ARG1
410 | P22 $ARG2 is the father of $ARG1
411 | P22 $ARG1 's father , $ARG2
412 | P22 $ARG1 ' father , $ARG2
413 | P22 $ARG1 's father $ARG2
414 | P22 $ARG1 ' father $ARG2
415 |
416 | P25 $ARG2 's child $ARG1
417 | P25 $ARG2 ' child $ARG1
418 | P25 $ARG1 is the child of $ARG2
419 | P25 $ARG2 has a child named $ARG1
420 | P25 $ARG2 's oldest child $ARG1
421 | P25 $ARG2 ' oldest child $ARG1
422 | P25 $ARG1 is the oldest child of $ARG2
423 | P25 $ARG2 ' youngest child $ARG1
424 | P25 $ARG1 is the youngest child of $ARG2
425 | P25 $ARG2 's daughter $ARG1
426 | P25 $ARG2 ' daughter $ARG1
427 | P25 $ARG1 is the daughter of $ARG2
428 | P25 $ARG2 has a daughter named $ARG1
429 | P25 $ARG2 's son $ARG1
430 | P25 $ARG2 ' son $ARG1
431 | P25 $ARG1 is the son of $ARG2
432 | P25 $ARG2 has a son named $ARG1
433 | P25 $ARG2 is the mother of $ARG1
434 | P25 $ARG1 's mother , $ARG2
435 | P25 $ARG1 ' mother , $ARG2
436 | P25 $ARG1 's mother $ARG2
437 | P25 $ARG1 ' mother $ARG2
438 |
439 | P7 $ARG0 is the brother of $ARG0
440 | P7 $ARG0 is a brother of $ARG0
441 | P7 $ARG0 's brother , $ARG0
442 | P7 $ARG0' brother , $ARG0
443 | P7 $ARG0 's brother $ARG0
444 | P7 $ARG0' brother $ARG0
445 | P7 $ARG0 is the younger brother of $ARG0
446 | P7 $ARG0 's younger brother , $ARG0
447 | P7 $ARG0' younger brother , $ARG0
448 | P7 $ARG0 is the older brother of $ARG0
449 | P7 $ARG0 's older brother , $ARG0
450 | P7 $ARG0' older brother , $ARG0
451 | P9 $ARG0 is the sister of $ARG0
452 | P9 $ARG0 is a sister of $ARG0
453 | P9 $ARG0 's sister $ARG0
454 | P9 $ARG0' sister $ARG0
455 | P9 $ARG0 's sister , $ARG0
456 | P9 $ARG0' sister , $ARG0
457 | P9 $ARG0 is the younger sister of $ARG0
458 | P9 $ARG0 's younger sister , $ARG0
459 | P9 $ARG0' younger sister , $ARG0
460 | P9 $ARG0 is the older sister of $ARG0
461 | P9 $ARG0 's older sister , $ARG0
462 | P9 $ARG0' older sister , $ARG0
463 |
464 | P1038 $ARG0 's grandson $ARG0
465 | P1038 $ARG0 's granddaughter $ARG0
466 | P1038 $ARG0 's grandfather $ARG0
467 | P1038 $ARG0 's grandmother $ARG0
468 | P1038 $ARG0 's uncle $ARG0
469 | P1038 $ARG0 's aunt $ARG0
470 | P1038 $ARG0 's cousin $ARG0
471 | P1038 $ARG0 's nephew $ARG0
472 | P1038 $ARG0 's relative $ARG0
473 | P1038 $ARG0 's grandson , $ARG0
474 | P1038 $ARG0 's granddaughter , $ARG0
475 | P1038 $ARG0 's grandfather , $ARG0
476 | P1038 $ARG0 's grandmother , $ARG0
477 | P1038 $ARG0 's uncle , $ARG0
478 | P1038 $ARG0 's aunt , $ARG0
479 | P1038 $ARG0 's cousin , $ARG0
480 | P1038 $ARG0 's nephew , $ARG0
481 | P1038 $ARG0 's relative , $ARG0
482 | P1038 $ARG0 is a grandson of $ARG0
483 | P1038 $ARG0 is the grandson of $ARG0
484 | P1038 $ARG0 is a granddaughter of $ARG0
485 | P1038 $ARG0 is the granddaughter of $ARG0
486 | P1038 $ARG0 is the grandfather of $ARG0
487 | P1038 $ARG0 is the grandmother of $ARG0
488 | P1038 $ARG0 is the uncle of $ARG0
489 | P1038 $ARG0 is the aunt of $ARG0
490 | P1038 $ARG0 is the cousin of $ARG0
491 | P1038 $ARG0 is a cousin of $ARG0
492 | P1038 $ARG0 is the nephew of $ARG0
493 | P1038 $ARG0 is a nephew of $ARG0
494 | P1038 $ARG0 is a relative of $ARG0
495 | P1038 $ARG0 is related to $ARG0
496 |
497 |
498 | # org:political_religious_affiliation -> P102
499 | P102 $ARG1 is affiliated with $ARG2
500 | P102 $ARG1 , affiliated with $ARG2
501 | P102 $ARG1 is connected to $ARG2
502 | P102 $ARG1 , connected to $ARG2
503 | P102 $ARG1 is strongly connected to $ARG2
504 | P102 $ARG1 , strongly connected to $ARG2
505 | P102 $ARG1 is tightly connected to $ARG2
506 | P102 $ARG1 , tightly connected to $ARG2
507 | P102 $ARG1 is a $ARG2 organization
508 | P102 $ARG1 is an $ARG2 organization
509 | P102 $ARG1 is a $ARG2 group
510 | P102 $ARG1 is an $ARG2 group
511 | P102 $ARG1 as a $ARG2 organization
512 | P102 $ARG1 as an $ARG2 organization
513 | P102 $ARG1 as a $ARG2 group
514 | P102 $ARG1 as an $ARG2 group
515 | P102 $ARG1 , a $ARG2 organization
516 | P102 $ARG1 , an $ARG2 organization
517 | P102 $ARG1 , a $ARG2 group
518 | P102 $ARG1 , an $ARG2 group
519 | P102 $ARG1 and other $ARG2 organizations
520 | P102 $ARG1 and other $ARG2 groups
521 | P102 $ARG1 is an affiliate of $ARG2
522 |
523 |
524 | P108 $ARG1 leads $ARG2
525 | P108 $ARG1 commands $ARG2
526 | P108 $ARG1 oversees $ARG2
527 | P108 $ARG1 commanded $ARG2
528 | P108 $ARG2 board member $ARG1
529 | P108 $ARG2 CFO $ARG1
530 | P108 $ARG2 CEO $ARG1
531 | P108 $ARG2 chief executive officer $ARG1
532 | P108 $ARG2 managing director $ARG1
533 | P108 $ARG2 MD $ARG1
534 | P108 $ARG2 executive director $ARG1
535 | P108 $ARG2 ED $ARG1
536 | P108 $ARG2 president $ARG1
537 | P108 $ARG2 vice president $ARG1
538 | P108 $ARG2 director $ARG1
539 | P108 $ARG2 chairman $ARG1
540 | P108 $ARG2 executive vice president $ARG1
541 | P108 $ARG2 provost $ARG1
542 | P108 $ARG2 dean $ARG1
543 | P108 $ARG1 is a board member of $ARG2
544 | P108 $ARG1 is the CFO of $ARG2
545 | P108 $ARG1 is the CEO of $ARG2
546 | P108 $ARG1 is the president of $ARG2
547 | P108 $ARG1 is the vice president of $ARG2
548 | P108 $ARG1 is the director of $ARG2
549 | P108 $ARG1 is the chairman of $ARG2
550 | P108 $ARG1 is the executive vice president of $ARG2
551 | P108 $ARG1 is the provost of $ARG2
552 | P108 $ARG1 is the dean of $ARG2
553 | P108 $ARG1 , board member of $ARG2
554 | P108 $ARG1 , CFO of $ARG2
555 | P108 $ARG1 , CEO of $ARG2
556 | P108 $ARG1 , president of $ARG2
557 | P108 $ARG1 , vice president of $ARG2
558 | P108 $ARG1 , director of $ARG2
559 | P108 $ARG1 , chairman of $ARG2
560 | P108 $ARG1 , executive vice president of $ARG2
561 | P108 $ARG1 , provost of $ARG2
562 | P108 $ARG1 , dean of $ARG2
563 | P108 $ARG1 , a board member of $ARG2
564 | P108 $ARG1 , the CFO of $ARG2
565 | P108 $ARG1 , the CEO of $ARG2
566 | P108 $ARG1 , the president of $ARG2
567 | P108 $ARG1 , the vice president of $ARG2
568 | P108 $ARG1 , the director of $ARG2
569 | P108 $ARG1 , the chairman of $ARG2
570 | P108 $ARG1 , the executive vice president of $ARG2
571 | P108 $ARG1 , the provost of $ARG2
572 | P108 $ARG1 , the dean of $ARG2
573 | P108 $ARG1 , former board member of $ARG2
574 | P108 $ARG1 , former CFO of $ARG2
575 | P108 $ARG1 , former CEO of $ARG2
576 | P108 $ARG1 , former president of $ARG2
577 | P108 $ARG1 , former vice president of $ARG2
578 | P108 $ARG1 , former director of $ARG2
579 | P108 $ARG1 , former chairman of $ARG2
580 | P108 $ARG1 , former executive vice president of $ARG2
581 | P108 $ARG1 , former provost of $ARG2
582 | P108 $ARG1 , former dean of $ARG2
583 |
584 |
585 | # org:member_of -> P463
586 | P463 $ARG1 is a member organization of $ARG2
587 | P463 $ARG1 as a member organization of $ARG2
588 | P463 $ARG1 , as a member organization of $ARG2
589 | P463 $ARG1 , a member organization of $ARG2
590 | P463 $ARG1 is a member of $ARG2
591 | P463 $ARG1 are member of $ARG2
592 | P463 $ARG1 as a member of $ARG2
593 | P463 $ARG1 , as a member of $ARG2
594 | P463 $ARG1 , a member of $ARG2
595 | P463 $ARG1 is part of $ARG2
596 | P463 $ARG1 are part of $ARG2
597 | P463 $ARG1 plays in $ARG2
598 | P463 $ARG1 play in $ARG2
599 | P463 $ARG1 is relegated to $ARG2
600 | P463 $ARG1 are relegated to $ARG2
601 | P463 $ARG1 is promoted to $ARG2
602 | P463 $ARG1 are promoted to $ARG2
603 | P463 $ARG1 joined $ARG2
604 | P463 $ARG1 and other members of $ARG2
605 |
606 |
607 | # org:subsidiaries -> P355
608 | P355 $ARG2 , a subsidiary of $ARG1
609 | P355 $ARG2 is a subsidiary of $ARG1
610 | P355 $ARG2 as a subsidiary of $ARG1
611 | P355 $ARG2 operates as a subsidiary of $ARG1
612 | P355 $ARG2 , a branch of $ARG1
613 | P355 $ARG2 is a branch of $ARG1
614 | P355 $ARG2 as a branch of $ARG1
615 | P355 $ARG2 , a regional branch of $ARG1
616 | P355 $ARG2 is a regional branch of $ARG1
617 | P355 $ARG2 as a regional branch of $ARG1
618 | P355 $ARG2 , owned by $ARG1
619 | P355 $ARG2 is owned by $ARG1
620 | P355 $ARG1 owns $ARG2
621 | P355 $ARG1 owned $ARG2
622 | P355 $ARG2 belongs to $ARG1
623 | P355 $ARG2 , belonging to $ARG1
624 | P355 $ARG1 's $ARG2 subsidiary
625 | # Too low precision.
626 | #org:subsidiaries $ARG1 's $ARG2
627 | #org:subsidiaries $ARG2 of $ARG1
628 | P355 $ARG1 's subsidiary $ARG2
629 | P355 $ARG1 's unit $ARG2
630 | P355 $ARG1 's arm $ARG2
631 | P355 $ARG1 's * subsidiary $ARG2
632 | P355 $ARG1 's * branch $ARG2
633 | P355 $ARG1 's * unit $ARG2
634 | P355 $ARG1 's * arm $ARG2
635 | P355 $ARG2 subsidiary of $ARG1
636 | P355 $ARG2 is the * arm of $ARG1
637 | P355 $ARG2 is the * unit of $ARG1
638 | P355 $ARG2 is the * branch of $ARG1
639 | P355 $ARG2 is the * subsidiary of $ARG1
640 |
641 |
642 | # org:parent -> P749
643 | P749 $ARG1 , a subsidiary of $ARG2
644 | P749 $ARG1 is a subsidiary of $ARG2
645 | P749 $ARG1 as a subsidiary of $ARG2
646 | P749 $ARG1 operates as a subsidiary of $ARG2
647 | P749 $ARG1 , a branch of $ARG2
648 | P749 $ARG1 is a branch of $ARG2
649 | P749 $ARG1 as a branch of $ARG2
650 | P749 $ARG1 , a regional branch of $ARG2
651 | P749 $ARG1 is a regional branch of $ARG2
652 | P749 $ARG1 as a regional branch of $ARG2
653 | P749 $ARG1 , owned by $ARG2
654 | P749 $ARG1 is owned by $ARG2
655 | P749 $ARG2 owns $ARG1
656 | P749 $ARG2 owned $ARG1
657 | P749 $ARG1 belongs to $ARG2
658 | P749 $ARG1 , belonging to $ARG2
659 | P749 $ARG2 's $ARG1 subsidiary
660 | P749 $ARG2 's subsidiary $ARG1
661 | P749 $ARG2 's unit $ARG1
662 | P749 $ARG2 's arm $ARG1
663 | P749 $ARG2 's * subsidiary $ARG1
664 | P749 $ARG2 's * branch $ARG1
665 | P749 $ARG2 's * unit $ARG1
666 | P749 $ARG2 's * arm $ARG1
667 | P749 $ARG1 subsidiary of $ARG2
668 | P749 $ARG1 is the * arm of $ARG2
669 | P749 $ARG1 is the * unit of $ARG2
670 | P749 $ARG1 is the * branch of $ARG2
671 | P749 $ARG1 is the * subsidiary of $ARG2
672 |
673 | P112 $ARG2 was founded by $ARG1
674 | P112 $ARG2 , founded by $ARG1
675 | P112 $ARG1 founded $ARG2
676 | P112 $ARG1 , who founded $ARG2
677 | P112 $ARG1 , the founder $ARG2
678 | P112 $ARG2 was formed by $ARG1
679 | P112 $ARG2 , formed by $ARG1
680 | P112 $ARG1 formed $ARG2
681 | P112 $ARG1 , who formed $ARG2
682 | P112 $ARG2 was established by $ARG1
683 | P112 $ARG2 , established by $ARG1
684 | P112 $ARG1 established $ARG2
685 | P112 $ARG1 , who established $ARG2
686 | #org:founded $ARG1 , established $ARG2
687 | #org:founded $ARG1 , founded $ARG2
688 | #org:founded $ARG1 , incorporated $ARG2
689 | #org:founded $ARG1 was established $ARG2
690 | #org:founded $ARG1 was founded $ARG2
691 | #org:founded $ARG1 was incorporated $ARG2
692 | #org:founded established $ARG1 $ARG2
693 | #org:founded formed $ARG1 $ARG2
694 | #org:founded founded $ARG1 $ARG2
695 |
696 |
697 | # org:country_of_headquarters -> P131
698 | P131 $ARG1 's main complex in $ARG2
699 | P131 $ARG1 's main campus in $ARG2
700 | P131 $ARG1 's head office in $ARG2
701 | P131 $ARG1 's main office in $ARG2
702 | P131 $ARG1 's main offices in $ARG2
703 | P131 $ARG1 's headquarters in $ARG2
704 | P131 $ARG1 's main complex is in $ARG2
705 | P131 $ARG1 's main campus is in $ARG2
706 | P131 $ARG1 's head office is in $ARG2
707 | P131 $ARG1 's main office is in $ARG2
708 | P131 $ARG1 's main offices are in $ARG2
709 | P131 $ARG1 's headquarters are in $ARG2
710 | P131 main complex of $ARG1 is in $ARG2
711 | P131 main campus of $ARG1 is in $ARG2
712 | P131 head office of $ARG1 is in $ARG2
713 | P131 main office of $ARG1 is in $ARG2
714 | P131 main offices of $ARG1 are in $ARG2
715 | P131 headquarters of $ARG1 are in $ARG2
716 | P131 main complex of $ARG1 in $ARG2
717 | P131 main campus of $ARG1 in $ARG2
718 | P131 head office of $ARG1 in $ARG2
719 | P131 main office of $ARG1 in $ARG2
720 | P131 main offices of $ARG1 in $ARG2
721 | P131 headquarters of $ARG1 in $ARG2
722 | P131 $ARG1 is headquartered in $ARG2
723 | P131 $ARG1 is based in $ARG2
724 | P131 $ARG1 is located in $ARG2
725 | P131 $ARG1 , headquartered in $ARG2
726 | P131 $ARG1 , based in $ARG2
727 | P131 $ARG1 , located in $ARG2
728 | P131 $ARG2 main complex of $ARG1
729 | P131 $ARG2 main campus of $ARG1
730 | P131 $ARG2 head office of $ARG1
731 | P131 $ARG2 main office of $ARG1
732 | P131 $ARG2 main offices of $ARG1
733 | P131 $ARG2 headquarters of $ARG1
734 | P131 $ARG1 has its main campus in $ARG2
735 | P131 $ARG1 has its head office in $ARG2
736 | P131 $ARG1 has its main office in $ARG2
737 | P131 $ARG1 has its main offices in $ARG2
738 | P131 $ARG1 has its headquarters in $ARG2
739 | P131 $ARG1 , a $ARG2-based
740 | P131 $ARG1 , an $ARG2-based
741 | P131 $ARG1 is a $ARG2-based
742 | P131 $ARG1 is an $ARG2-based
743 | P131 $ARG1 is a * based in $ARG2
744 | P131 $ARG1 is based in * , $ARG2
745 | P131 $ARG1 , based in * , $ARG2
746 |
747 |
748 |
749 |
750 |
--------------------------------------------------------------------------------
/distant_supervision.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##########################################################
4 | #
5 | # Distant Supervision
6 | #
7 | ###########################################################
8 |
9 | import pandas as pd
10 | import numpy as np
11 | from bs4 import BeautifulSoup as bs
12 |
13 |
14 | import os
15 | import re
16 | import time
17 | import requests
18 | import urllib
19 | import glob
20 | from codecs import open
21 | from itertools import combinations
22 | from collections import Counter
23 |
24 | import util
25 |
26 | import sys
27 | reload(sys)
28 | sys.setdefaultencoding("utf-8")
29 |
30 |
31 | # global variables
32 | data_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
33 | orig_dir = os.path.join(data_dir, 'orig')
34 | ner_dir = os.path.join(data_dir, 'ner')
35 |
36 |
37 | ner_path = "/usr/local/Cellar/stanford-ner/3.5.2/libexec/"
38 | stanford_classifier = os.path.join(ner_path, 'classifiers', 'english.all.3class.distsim.crf.ser.gz')
39 | stanford_ner = os.path.join(ner_path, 'stanford-ner.jar')
40 |
41 | tag_map = {
42 | 'ORGANIZATION': 'Q43229', # https://www.wikidata.org/wiki/Q43229
43 | 'LOCATION': 'Q17334923', # https://www.wikidata.org/wiki/Q17334923
44 | 'PERSON': 'Q5' # https://www.wikidata.org/wiki/Q5
45 | }
46 |
47 | # column names in DataFrame
48 | col = ['doc_id', 'sent_id', 'sent', 'subj', 'subj_begin', 'subj_end', 'subj_tag',
49 | 'rel', 'obj', 'obj_begin', 'obj_end', 'obj_tag']
50 |
51 |
52 | def sanitize(string):
53 | """clean wikipedia article"""
54 | string = re.sub(r"\[\d{1,3}\]", " ", string)
55 | string = re.sub(r"\[edit\]", " ", string)
56 | string = re.sub(r" {2,}", " ", string)
57 | return string.strip()
58 |
59 |
60 | def download_wiki_articles(doc_id, limit=100, retry=False):
61 | """download wikipedia article via Mediawiki API"""
62 | base_path = "http://en.wikipedia.org/w/api.php?format=xml&action=query"
63 | query = base_path + "&list=random&rnnamespace=0&rnlimit=%d" % limit
64 | r = None
65 | try:
66 | r = urllib.urlopen(query).read()
67 | except Exception as e:
68 | if not retry:
69 |             return download_wiki_articles(doc_id, limit, retry=True)
70 | else:
71 | print e.message
72 | return None
73 | pages = bs(r, "html.parser").findAll('page')
74 | if len(pages) < 1:
75 | return None
76 | docs = []
77 | for page in pages:
78 | if int(page['id']) in doc_id:
79 | continue
80 |
81 | link = base_path + "&prop=revisions&pageids=%s&rvprop=content&rvparse" % page['id']
82 | content = urllib.urlopen(link).read()
83 | content = bs(content, "html.parser").find('rev').stripped_strings
84 |
85 | # extract paragraph elements only
86 | text = ''
87 | for p in bs(' '.join(content), "html.parser").findAll('p'):
88 | text += ' '.join(p.stripped_strings) + '\n'
89 | #text = text.encode('utf8')
90 | text = sanitize(text)
91 |
92 | # save
93 | if len(text) > 0:
94 | title = re.sub(r"[ /]", "_", page['title'])
95 | filename = page['id'] + '-' + title + '.txt'
96 | docs.append(filename)
97 | with open(os.path.join(orig_dir, filename), 'w', encoding='utf-8') as f:
98 | f.write(text)
99 | return docs
100 |
101 |
102 | def exec_ner(filenames):
103 | """execute Stanford NER"""
104 | for filename in filenames:
105 | in_path = os.path.join(orig_dir, filename)
106 | out_path = os.path.join(ner_dir, filename)
107 | cmd = 'java -mx700m -cp "%s:" edu.stanford.nlp.ie.crf.CRFClassifier' % stanford_ner
108 | cmd += ' -loadClassifier %s -outputFormat tabbedEntities' % stanford_classifier
109 | cmd += ' -textFile %s > %s' % (in_path, out_path)
110 | os.system(cmd)
111 |
112 |
113 | def read_ner_output(filenames):
114 | """read NER output files and store them in a pandas DataFrame"""
115 | rows = []
116 | for filename in filenames:
117 | path = os.path.join(ner_dir, filename)
118 | if not os.path.exists(path):
119 | continue
120 | with open(path, 'r', encoding='utf-8') as f:
121 | doc_id = filename.split('/')[-1].split('-', 1)[0]
122 | counter = 0
123 | tmp = []
124 | for line in f.readlines():
125 | if len(line.strip()) < 1 and len(tmp) > 2:
126 | ent = [i for i, t in enumerate(tmp) if t[1] in tag_map.keys()]
127 | for c in combinations(ent, 2):
128 | dic = {'sent': u''}
129 | dic['doc_id'] = doc_id
130 | dic['sent_id'] = counter
131 | for j, t in enumerate(tmp):
132 | if j == c[0]:
133 | if len(dic['sent']) > 0:
134 | dic['subj_begin'] = len(dic['sent']) + 1
135 | else:
136 | dic['subj_begin'] = 0
137 | if len(dic['sent']) > 0:
138 | dic['subj_end'] = len(dic['sent']) + len(t[0].strip()) + 1
139 | else:
140 | dic['subj_end'] = len(t[0].strip())
141 | dic['subj'] = t[0].strip()
142 | dic['subj_tag'] = t[1].strip()
143 | elif j == c[1]:
144 | dic['obj_begin'] = len(dic['sent']) + 1
145 | dic['obj_end'] = len(dic['sent']) + len(t[0].strip()) + 1
146 | dic['obj'] = t[0].strip()
147 | dic['obj_tag'] = t[1].strip()
148 |
149 | if len(dic['sent']) > 0:
150 | dic['sent'] += ' ' + t[0].strip()
151 | else:
152 | dic['sent'] += t[0].strip()
153 | if len(dic['sent']) > 0:
154 | dic['sent'] += ' ' + t[2].strip()
155 | else:
156 | dic['sent'] += t[2].strip()
157 | #print '"'+dic['sent']+'"', len(dic['sent'])
158 | rows.append(dic)
159 | #print dic
160 | counter += 1
161 | tmp = []
162 | elif len(line.strip()) < 1 and len(tmp) > 0 and len(tmp) <= 2:
163 | continue
164 | elif len(line.strip()) > 0:
165 | e = line.split('\t')
166 | if len(e) == 1:
167 | e.insert(0, '')
168 | e.insert(0, '')
169 | if len(e) == 2 and e[1].strip() in tag_map.keys():
170 | e.append('')
171 | if len(e) != 3:
172 | print e
173 | raise Exception
174 | tmp.append(e)
175 | else:
176 | continue
177 |
178 | return pd.DataFrame(rows)
179 |
180 |
181 | def name2qid(name, tag, alias=False, retry=False):
182 | """find QID (and Freebase ID if given) by name
183 |
184 | >>> name2qid('Barack Obama', 'PERSON') # perfect match
185 | (u'Q76', u'/m/02mjmr')
186 | >>> name2qid('Obama', 'PERSON', alias=True) # alias match
187 | (u'Q76', u'/m/02mjmr')
188 | """
189 |
190 | label = 'rdfs:label'
191 | if alias:
192 | label = 'skos:altLabel'
193 |
194 | hpCharURL = 'https://query.wikidata.org/sparql?query=\
195 | SELECT DISTINCT ?item ?fid \
196 | WHERE {\
197 | ?item '+label+' "'+name+'"@en.\
198 | ?item wdt:P31 ?_instanceOf.\
199 | ?_instanceOf wdt:P279* wd:'+tag_map[tag]+'.\
200 | OPTIONAL { ?item wdt:P646 ?fid. }\
201 | }\
202 | LIMIT 10'
203 | headers = {"Accept": "application/json"}
204 |
205 | # check response
206 | r = None
207 | try:
208 | r = requests.get(hpCharURL, headers=headers)
209 | except requests.exceptions.ConnectionError:
210 | if not retry:
211 | time.sleep(60)
212 | return name2qid(name, tag, alias, retry=True)
213 | else:
214 | return None
215 | except Exception as e:
216 | print e.message
217 | return None
218 |
219 | # check json format
220 | try:
221 | response = r.json()
222 | except ValueError: # includes JSONDecodeError
223 | return None
224 |
225 | # parse results
226 | results = []
227 | for elm in response['results']['bindings']:
228 | fid = ''
229 | if elm.has_key('fid'):
230 | fid = elm['fid']['value']
231 | results.append((elm['item']['value'].split('/')[-1], fid))
232 |
233 | if len(results) < 1:
234 | return None
235 | else:
236 | return results[0]
237 |
238 |
239 | def search_property(qid1, qid2, retry=False):
240 | """find property (and schema.org relation if given)
241 |
242 | >>> search_property('Q76', 'Q30') # Q76: Barack Obama, Q30: United States
243 | [(u'P27', u'country of citizenship', u'nationality')]
244 | """
245 |
246 | hpCharURL = 'https://query.wikidata.org/sparql?query= \
247 | SELECT DISTINCT ?p ?l ?s \
248 | WHERE {\
249 | wd:'+qid1+' ?p wd:'+qid2+' .\
250 | ?property ?ref ?p .\
251 | ?property a wikibase:Property .\
252 | ?property rdfs:label ?l FILTER (lang(?l) = "en")\
253 | OPTIONAL { ?property wdt:P1628 ?s FILTER (SUBSTR(str(?s), 1, 18) = "http://schema.org/"). }\
254 | }\
255 | LIMIT 10'
256 | headers = {"Accept": "application/json"}
257 |
258 | # check response
259 | r = None
260 | try:
261 | r = requests.get(hpCharURL, headers=headers)
262 | except requests.exceptions.ConnectionError:
263 | if not retry:
264 | time.sleep(60)
265 | return search_property(qid1, qid2, retry=True)
266 | else:
267 | return None
268 | except Exception as e:
269 | print e.message
270 | return None
271 |
272 | # check json format
273 | try:
274 | response = r.json()
275 | except ValueError:
276 | return None
277 |
278 | # parse results
279 | results = []
280 | for elm in response['results']['bindings']:
281 | schema = ''
282 | if elm.has_key('s'):
283 | schema = elm['s']['value'].split('/')[-1]
284 | results.append((elm['p']['value'].split('/')[-1], elm['l']['value'], schema))
285 |
286 | return results
287 |
288 |
289 | def slot_filling(qid, pid, tag, retry=False):
290 | """find slotfiller
291 |
292 | >>> slot_filling('Q76', 'P27', 'LOCATION') # Q76: Barack Obama, P27: country of citizenship
293 | [(u'United States', u'Q30', u'/m/09c7w0')]
294 | """
295 |
296 | hpCharURL = 'https://query.wikidata.org/sparql?query=\
297 | SELECT DISTINCT ?item ?itemLabel ?fid \
298 | WHERE {\
299 | wd:'+qid+' wdt:'+pid+' ?item.\
300 | ?item wdt:P31 ?_instanceOf.\
301 | ?_instanceOf wdt:P279* wd:'+tag_map[tag]+'.\
302 | SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }\
303 | OPTIONAL { ?item wdt:P646 ?fid. }\
304 | }\
305 | LIMIT 100'
306 | headers = {"Accept": "application/json"}
307 |
308 | # check response
309 | r = None
310 | try:
311 | r = requests.get(hpCharURL, headers=headers)
312 | except requests.exceptions.ConnectionError:
313 | if not retry:
314 | time.sleep(60)
315 | return slot_filling(qid, pid, tag, retry=True)
316 | else:
317 | return None
318 | except Exception as e:
319 | print e.message
320 | return None
321 |
322 | # check json format
323 | try:
324 | response = r.json()
325 | except ValueError:
326 | return None
327 |
328 | # parse results
329 | results = []
330 | for elm in response['results']['bindings']:
331 | fid = ''
332 | if elm.has_key('fid'):
333 | fid = elm['fid']['value']
334 | results.append((elm['itemLabel']['value'], elm['item']['value'].split('/')[-1], fid))
335 |
336 | return results
337 |
338 |
339 | def loop(step, doc_id, limit, entities, relations, counter):
340 | """Distant Supervision Loop"""
341 | # Download wiki articles
342 | print '[1/4] Downloading wiki articles ...'
343 | docs = download_wiki_articles(doc_id, limit)
344 | if docs is None:
345 | return None
346 |
347 | # Named Entity Recognition
348 | print '[2/4] Performing named entity recognition ...'
349 | exec_ner(docs)
350 | wiki_data = read_ner_output(docs)
351 | path = os.path.join(data_dir, 'candidates%d.tsv' % step)
352 | wiki_data.to_csv(path, sep='\t', encoding='utf-8')
353 | doc_id.extend([int(s) for s in wiki_data.doc_id.unique()])
354 |
355 | # Prepare Containers
356 | unique_entities = set([])
357 | unique_entity_pairs = set([])
358 | for idx, row in wiki_data.iterrows():
359 | unique_entities.add((row['subj'], row['subj_tag']))
360 | unique_entities.add((row['obj'], row['obj_tag']))
361 | unique_entity_pairs.add((row['subj'], row['obj']))
362 |
363 | # Entity Linkage
364 | print '[3/4] Linking entities ...'
365 | for name, tag in unique_entities:
366 | if not entities.has_key(name) and tag in tag_map.keys():
367 | e = name2qid(name, tag, alias=False)
368 | if e is None:
369 | e = name2qid(name, tag, alias=True)
370 | entities[name] = e
371 | util.dump_to_file(os.path.join(data_dir, "entities.cPickle"), entities)
372 |
373 | # Predicate Linkage
374 | print '[4/4] Linking predicates ...'
375 | for subj, obj in unique_entity_pairs:
376 | if not relations.has_key((subj, obj)):
377 | if entities[subj] is not None and entities[obj] is not None:
378 | if (entities[subj][0] != entities[obj][0]) or (subj != obj):
379 | arg1 = entities[subj][0]
380 | arg2 = entities[obj][0]
381 | relations[(subj, obj)] = search_property(arg1, arg2)
382 | #TODO: alternative name relation
383 | #elif (entities[subj][0] == entities[obj][0]) and (subj != obj):
384 | # relations[(subj, obj)] = 'P'
385 | util.dump_to_file(os.path.join(data_dir, "relations.cPickle"), relations)
386 |
387 | # Assign relation
388 | wiki_data['rel'] = pd.Series(index=wiki_data.index, dtype=str)
389 | for idx, row in wiki_data.iterrows():
390 | entity_pair = (row['subj'], row['obj'])
391 |
392 | if relations.has_key(entity_pair):
393 | rel = relations[entity_pair]
394 | if rel is not None and len(rel) > 0:
395 | counter += 1
396 | wiki_data.set_value(idx, 'rel', ', '.join(set([s[0] for s in rel])))
397 | # Save
398 | path = os.path.join(data_dir, 'candidates%d.tsv' % step)
399 | wiki_data.to_csv(path, sep='\t', encoding='utf-8')
400 |
401 | # Cleanup
402 | for f in glob.glob(os.path.join(orig_dir, '*')):
403 | os.remove(f)
404 | for f in glob.glob(os.path.join(ner_dir, '*')):
405 | os.remove(f)
406 |
407 | return doc_id, entities, relations, counter
408 |
409 |
410 | def extract_relations(entities, relations):
411 | """extract relations"""
412 | rows = []
413 | for k, v in relations.iteritems():
414 | if v is not None and len(v) > 0:
415 | for r in v:
416 | dic = {}
417 | dic['subj_qid'] = entities[k[0]][0]
418 | dic['subj_fid'] = entities[k[0]][1]
419 | dic['subj'] = k[0]
420 | dic['obj_qid'] = entities[k[1]][0]
421 | dic['obj_fid'] = entities[k[1]][1]
422 | dic['obj'] = k[1]
423 | dic['rel_id'] = r[0]
424 | dic['rel'] = r[1]
425 | dic['rel_schema'] = r[2]
426 | #TODO: add number of mentions
427 | #dic['wikidata_idx'] = entity_pairs[k]
428 | rows.append(dic)
429 | return pd.DataFrame(rows)
430 |
431 |
432 | def positive_examples():
433 | entities = {}
434 | relations = {}
435 | counter = 0
436 | limit = 1000
437 | doc_id = []
438 | step = 1
439 |
440 | if not os.path.exists(orig_dir):
441 | os.mkdir(orig_dir)
442 | if not os.path.exists(ner_dir):
443 | os.mkdir(ner_dir)
444 |
445 | #for j in range(1, step):
446 | # wiki_data = pd.read_csv(os.path.join(data_dir, "candidates%d.tsv" % j), sep='\t', index_col=0)
447 | # doc_id.extend([int(s) for s in wiki_data.doc_id.unique()])
448 | # counter += int(wiki_data.rel.count())
449 |
450 | while counter < 10000 and step < 100:
451 | print '===== step %d =====' % step
452 | ret = loop(step, doc_id, limit, entities, relations, counter)
453 | if ret is not None:
454 | doc_id, entities, relations, counter = ret
455 |
456 | step += 1
457 |
458 | # positive candidates
459 | positive_data = []
460 | for f in glob.glob(os.path.join(data_dir, 'candidates*.tsv')):
461 | pos = pd.read_csv(f, sep='\t', encoding='utf-8', index_col=0)
462 | positive_data.append(pos[pd.notnull(pos.rel)])
463 | positive_df = pd.concat(positive_data, axis=0, ignore_index=True)
464 | positive_df[col].to_csv(os.path.join(data_dir, 'positive_candidates.tsv'), sep='\t', encoding='utf-8')
465 |
466 | # save relations
467 | pos_rel = extract_relations(entities, relations)
468 | pos_rel.to_csv(os.path.join(data_dir, 'positive_relations.tsv'), sep='\t', encoding='utf-8')
469 |
470 |
471 | def negative_examples():
472 | negative = {}
473 |
474 | unique_pair = set([])
475 | neg_candidates = []
476 |
477 | #TODO: replace with positive_relations.tsv
478 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle"))
479 | relations = util.load_from_dump(os.path.join(data_dir, "relations.cPickle"))
480 |
481 | rel_counter = Counter([u[0] for r in relations.values() if r is not None and len(r) > 0 for u in r])
482 | most_common_rel = [r[0] for r in rel_counter.most_common(10)]
483 |
484 |
485 | for data_path in glob.glob(os.path.join(data_dir, 'candidates*.tsv')):
486 | neg = pd.read_csv(data_path, sep='\t', encoding='utf-8', index_col=0)
487 | negative_df = neg[pd.isnull(neg.rel)]
488 |
489 | # Assign relation
490 | for idx, row in negative_df.iterrows():
491 | if (entities.has_key(row['subj']) and entities[row['subj']] is not None \
492 | and entities.has_key(row['obj']) and entities[row['obj']] is not None):
493 | qid = entities[row['subj']][0]
494 | target = entities[row['obj']][0]
495 | candidates = []
496 | for pid in most_common_rel:
497 | if (qid, pid) not in unique_pair:
498 | unique_pair.add((qid, pid))
499 | items = slot_filling(qid, pid, row['obj_tag'])
500 | if items is not None and len(items) > 1:
501 | qids = [q[1] for q in items]
502 | if target not in qids:
503 | candidates.append(pid)
504 |
505 | if len(candidates) > 0:
506 | row['rel'] = ', '.join(candidates)
507 | neg_candidates.append(row)
508 |
509 |
510 | neg_examples = pd.DataFrame(neg_candidates)
511 | neg_examples[col].to_csv(os.path.join(data_dir, 'negative_candidates.tsv'), sep='\t', encoding='utf-8')
512 |
513 |
514 | # save relations
515 | #pos_rel = extract_relations(entities, negative)
516 | #pos_rel.to_csv(os.path.join(data_dir, 'positive_relations.tsv'), sep='\t', encoding='utf-8')
517 |
518 |
519 | def load_gold_patterns():
520 | def clean_str(string):
521 | string = re.sub(r", ", " , ", string)
522 | string = re.sub(r"' ", " ' ", string)
523 | string = re.sub(r" \* ", " .* ", string)
524 | string = re.sub(r"\(", "-LRB-", string)
525 | string = re.sub(r"\)", "-RRB-", string)
526 | string = re.sub(r" {2,}", " ", string)
527 | return string.strip()
528 |
529 | g_patterns = []
530 | g_labels = []
531 | with open(os.path.join(data_dir, 'gold_patterns.tsv'), 'r') as f:
532 | for line in f.readlines():
533 | line = line.strip()
534 | if len(line) > 0 and not line.startswith('#'):
535 | e = line.split('\t', 1)
536 | if len(e) > 1:
537 | g_patterns.append(clean_str(e[1]))
538 | g_labels.append(e[0])
539 | else:
540 | print e
541 | raise Exception('Process Error: %s' % os.path.join(data_dir, 'gold_patterns.tsv'))
542 |
543 | return pd.DataFrame({'pattern': g_patterns, 'label': g_labels})
544 |
545 |
546 | def score_reliability(gold_patterns, sent, rel, subj, obj):
547 | for name, group in gold_patterns.groupby('label'):
548 | if name in [r.strip() for r in rel.split(',')]:
549 | for i, g in group.iterrows():
550 | pattern = g['pattern']
551 | pattern = re.sub(r'\$ARG(0|1)', subj, pattern, count=1)
552 | pattern = re.sub(r'\$ARG(0|2)', obj, pattern, count=1)
553 | match = re.search(pattern, sent)
554 | if match:
555 | return 1.0
556 | return 0.0
557 |
558 |
559 | def extract_positive():
560 | if not os.path.exists(os.path.join(data_dir, 'mlmi')):
561 | os.mkdir(os.path.join(data_dir, 'mlmi'))
562 | if not os.path.exists(os.path.join(data_dir, 'er')):
563 | os.mkdir(os.path.join(data_dir, 'er'))
564 |
565 | # read gold patterns to extract attention
566 | gold_patterns = load_gold_patterns()
567 |
568 |
569 | #TODO: replace with negative_relations.tsv
570 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle"))
571 | relations = util.load_from_dump(os.path.join(data_dir, "relations.cPickle"))
572 |
573 | # keep only the (at most 50) most frequent relations that occur at least 50 times
574 | rel_c = Counter([u[0] for r in relations.values() if r is not None and len(r) > 0 for u in r])
575 | rel_c_top = [k for k, v in rel_c.most_common(50) if v >= 50]
576 |
577 | # positive examples
578 | positive_df = pd.read_csv(os.path.join(data_dir, 'positive_candidates.tsv'),
579 | sep='\t', encoding='utf-8', index_col=0)
580 |
581 | positive_df['right'] = pd.Series(index=positive_df.index, dtype=str)
582 | positive_df['middle'] = pd.Series(index=positive_df.index, dtype=str)
583 | positive_df['left'] = pd.Series(index=positive_df.index, dtype=str)
584 | positive_df['clean'] = pd.Series(index=positive_df.index, dtype=str)
585 | positive_df['label'] = pd.Series(index=positive_df.index, dtype=str)
586 | positive_df['attention'] = pd.Series([0.0]*len(positive_df), index=positive_df.index, dtype=np.float32)
587 |
588 | num_er = 0
589 | with open(os.path.join(data_dir, 'er', 'source.txt'), 'w', encoding='utf-8') as f:
590 | for idx, row in positive_df.iterrows():
591 |
592 | # restore relation
593 | rel = ['<' + l.strip() + '>' for l in row['rel'].split(',') if l.strip() in rel_c_top]
594 | if len(rel) > 0:
595 |
596 | s = row['sent']
597 | subj = '<' + entities[row['subj'].encode('utf-8')][0] + '>'
598 | obj = '<' + entities[row['obj'].encode('utf-8')][0] + '>'
599 | left = s[:row['subj_begin']] + subj
600 | middle = s[row['subj_end']:row['obj_begin']]
601 | right = obj + s[row['obj_end']:]
602 | text = left.strip() + ' ' + middle.strip() + ' ' + right.strip()
603 |
604 | # check if begin-end position is correct
605 | assert s[row['subj_begin']:row['subj_end']] == row['subj']
606 | assert s[row['obj_begin']:row['obj_end']] == row['obj']
607 |
608 | # MLMI dataset
609 | # filter out too long sentences
610 | if len(left.split()) < 100 and len(middle.split()) < 100 and len(right.split()) < 100:
611 |
612 | positive_df.set_value(idx, 'right', right.strip())
613 | positive_df.set_value(idx, 'middle', middle.strip())
614 | positive_df.set_value(idx, 'left', left.strip())
615 | positive_df.set_value(idx, 'clean', text.strip())
616 |
617 | # binarize label
618 | label = ['0'] * len(rel_c_top)
619 | for u in row['rel'].split(','):
620 | if u.strip() in rel_c_top:
621 | label[rel_c_top.index(u.strip())] = '1'
622 | positive_df.set_value(idx, 'label', ' '.join(label))
623 |
624 | # score reliability if positive
625 | reliability = score_reliability(gold_patterns, s, row['rel'], row['subj'], row['obj'])
626 | positive_df.set_value(idx, 'attention', reliability)
627 |
628 | # ER dataset
629 | for r in rel:
630 | num_er += 1
631 | f.write(subj + ' ' + r + ' ' + obj + '\n')
632 |
633 | with open(os.path.join(data_dir, 'er', 'target.txt'), 'w', encoding='utf-8') as f:
634 | for _ in range(num_er):
635 | f.write('1 0\n')
636 |
637 | positive_df_valid = positive_df[pd.notnull(positive_df.clean)]
638 | assert len(positive_df_valid['clean']) == len(positive_df_valid['label'])
639 |
640 | positive_df_valid['right'].to_csv(os.path.join(data_dir, 'mlmi', 'source.right'),
641 | sep='\t', index=False, header=False, encoding='utf-8')
642 | positive_df_valid['middle'].to_csv(os.path.join(data_dir, 'mlmi', 'source.middle'),
643 | sep='\t', index=False, header=False, encoding='utf-8')
644 | positive_df_valid['left'].to_csv(os.path.join(data_dir, 'mlmi', 'source.left'),
645 | sep='\t', index=False, header=False, encoding='utf-8')
646 | positive_df_valid['clean'].to_csv(os.path.join(data_dir, 'mlmi', 'source.txt'),
647 | sep='\t', index=False, header=False, encoding='utf-8')
648 | positive_df_valid['label'].to_csv(os.path.join(data_dir, 'mlmi', 'target.txt'),
649 | sep='\t', index=False, header=False, encoding='utf-8')
650 | positive_df_valid['attention'].to_csv(os.path.join(data_dir, 'mlmi', 'source.att'),
651 | sep='\t', index=False, header=False, encoding='utf-8')
652 |
653 |
654 | def extract_negative():
655 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle"))
656 |
657 | # negative examples
658 | negative_df = pd.read_csv(os.path.join(data_dir, 'negative_candidates.tsv'),
659 | sep='\t', encoding='utf-8', index_col=0)
660 |
661 | with open(os.path.join(data_dir, 'er', 'source.txt'), 'a', encoding='utf-8') as source_file:
662 | with open(os.path.join(data_dir, 'er', 'target.txt'), 'a', encoding='utf-8') as target_file:
663 | for idx, row in negative_df.iterrows():
664 | s = row['sent']
665 |
666 | subj = '<' + entities[row['subj'].encode('utf-8')][0] + '>'
667 | obj = '<' + entities[row['obj'].encode('utf-8')][0] + '>'
668 | rel = ['<' + l.strip() + '>' for l in row['rel'].split(',')]
669 |
670 | assert s[row['subj_begin']:row['subj_end']] == row['subj']
671 | assert s[row['obj_begin']:row['obj_end']] == row['obj']
672 |
673 | if len(rel) > 0:
674 | for r in rel:
675 | source_file.write(subj + ' ' + r + ' ' + obj + '\n')
676 | target_file.write('0 1\n')
677 |
678 | def main():
679 | # gather positive examples
680 | if not os.path.exists(os.path.join(data_dir, 'positive_candidates.tsv')):
681 | positive_examples()
682 | extract_positive()
683 |
684 | # gather negative examples
685 | if not os.path.exists(os.path.join(data_dir, 'negative_candidates.tsv')):
686 | negative_examples()
687 | extract_negative()
688 |
689 |
690 |
691 |
692 | if __name__ == '__main__':
693 | main()
694 |
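
A minimal sketch of how `score_reliability` decides the attention score for a positive example, using the P131 pattern from `data/gold_patterns.tsv` above; the sentence and entity mentions are hypothetical:

```python
import re

# P131 pattern from gold_patterns.tsv; load_gold_patterns() rewrites ' * ' to ' .* '
pattern = "$ARG1 , based in .* , $ARG2"

# hypothetical entity pair and NER-tokenized sentence
subj, obj = "Acme Corp", "Illinois"
sent = "Acme Corp , based in Springfield , Illinois"

# fill the placeholders and search, as score_reliability() does
filled = re.sub(r"\$ARG(0|1)", subj, pattern, count=1)
filled = re.sub(r"\$ARG(0|2)", obj, filled, count=1)
print bool(re.search(filled, sent))  # True -> reliability (attention) score 1.0
```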
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##########################################################
4 | #
5 | # Attention-based Convolutional Neural Network
6 | # for Context-wise Learning
7 | #
8 | #
9 | # Note: this implementation is mostly based on
10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/eval.py
11 | #
12 | ##########################################################
13 |
14 | from datetime import datetime
15 | import os
16 | import tensorflow as tf
17 | import numpy as np
18 | import util
19 |
20 |
21 | FLAGS = tf.app.flags.FLAGS
22 | this_dir = os.path.abspath(os.path.dirname(__file__))
23 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'models', 'er-cnn'), 'Directory of the checkpoint files')
24 |
25 |
26 | def evaluate(eval_data, config):
27 | """ Build evaluation graph and run. """
28 |
29 | with tf.Graph().as_default():
30 | with tf.variable_scope('cnn'):
31 | if config.has_key('contextwise') and config['contextwise']:
32 | import cnn_context
33 | m = cnn_context.Model(config, is_train=False)
34 | else:
35 | import cnn
36 | m = cnn.Model(config, is_train=False)
37 | saver = tf.train.Saver(tf.global_variables())
38 |
39 | with tf.Session() as sess:
40 | ckpt = tf.train.get_checkpoint_state(config['train_dir'])
41 | if ckpt and ckpt.model_checkpoint_path:
42 | saver.restore(sess, ckpt.model_checkpoint_path)
43 | else:
44 | raise IOError("Loading checkpoint file failed!")
45 |
46 | print "\nStart evaluation on test set ...\n"
47 | if config.has_key('contextwise') and config['contextwise']:
48 | left_batch, middle_batch, right_batch, y_batch, _ = zip(*eval_data)
49 | feed = {m.left: np.array(left_batch),
50 | m.middle: np.array(middle_batch),
51 | m.right: np.array(right_batch),
52 | m.labels: np.array(y_batch)}
53 | else:
54 | x_batch, y_batch, _ = zip(*eval_data)
55 | feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)}
56 | loss, eval_value = sess.run([m.total_loss, m.eval_op], feed_dict=feed)
57 | pre, rec = zip(*eval_value)
58 |
59 | auc = util.calc_auc_pr(pre, rec)
60 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5])
61 | print '%s: loss = %.6f, p = %.4f, r = %4.4f, f1 = %.4f, auc = %.4f' % (datetime.now(), loss,
62 | pre[5], rec[5], f1, auc)
63 | return pre, rec
64 |
65 |
66 | def main(argv=None):
67 | restore_param = util.load_from_dump(os.path.join(FLAGS.train_dir, 'flags.cPickle'))
68 | restore_param['train_dir'] = FLAGS.train_dir
69 |
70 | if restore_param.has_key('contextwise') and restore_param['contextwise']:
71 | source_path = os.path.join(restore_param['data_dir'], "ids")
72 | target_path = os.path.join(restore_param['data_dir'], "target.txt")
73 | _, data = util.read_data_contextwise(source_path, target_path, restore_param['sent_len'],
74 | train_size=restore_param['train_size'])
75 | else:
76 | source_path = os.path.join(restore_param['data_dir'], "ids.txt")
77 | target_path = os.path.join(restore_param['data_dir'], "target.txt")
78 | _, data = util.read_data(source_path, target_path, restore_param['sent_len'],
79 | train_size=restore_param['train_size'])
80 |
81 | pre, rec = evaluate(data, restore_param)
82 | util.dump_to_file(os.path.join(FLAGS.train_dir, 'results.cPickle'), {'precision': pre, 'recall': rec})
83 |
84 |
85 | if __name__ == '__main__':
86 | tf.app.run()
87 |
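
A minimal invocation sketch: it assumes a finished training run whose checkpoint and `flags.cPickle` sit under the given `--train_dir` (the timestamped directory name below is hypothetical). The script restores the saved flags, reloads the held-out split, and dumps precision/recall to `results.cPickle` in the same directory.

```sh
python ./eval.py --train_dir=./train/1474517494
```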
--------------------------------------------------------------------------------
/img/auc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/auc.png
--------------------------------------------------------------------------------
/img/cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/cnn.png
--------------------------------------------------------------------------------
/img/emb_er.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_er.png
--------------------------------------------------------------------------------
/img/emb_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_left.png
--------------------------------------------------------------------------------
/img/emb_mlmi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_mlmi.png
--------------------------------------------------------------------------------
/img/emb_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_right.png
--------------------------------------------------------------------------------
/img/f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/f1.png
--------------------------------------------------------------------------------
/img/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/loss.png
--------------------------------------------------------------------------------
/img/pr_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/pr_curve.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from datetime import datetime
4 | import time
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | import cnn
10 | import util
11 |
12 |
13 | FLAGS = tf.app.flags.FLAGS
14 |
15 | # train parameters
16 | this_dir = os.path.abspath(os.path.dirname(__file__))
17 | tf.app.flags.DEFINE_string('data_dir', os.path.join(this_dir, 'data'), 'Directory of the data')
18 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'train'),
19 | 'Directory to save training checkpoint files')
20 | tf.app.flags.DEFINE_integer('train_size', 100000, 'Number of training examples')
21 | tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run')
22 | tf.app.flags.DEFINE_boolean('use_pretrain', False, 'Use word2vec pretrained embeddings or not')
23 |
24 | tf.app.flags.DEFINE_string('optimizer', 'adam',
25 | 'Optimizer to use. Must be one of "sgd", "adagrad", "adadelta" and "adam"')
26 | tf.app.flags.DEFINE_float('init_lr', 1e-3, 'Initial learning rate')
27 | tf.app.flags.DEFINE_float('lr_decay', 0.95, 'LR decay rate')
28 | tf.app.flags.DEFINE_integer('tolerance_step', 500,
29 | 'Decay the lr after loss remains unchanged for this number of steps')
30 | tf.app.flags.DEFINE_float('dropout', 0.5, 'Dropout rate. 0 is no dropout.')
31 |
32 | # logging
33 | tf.app.flags.DEFINE_integer('log_step', 10, 'Display log to stdout after this step')
34 | tf.app.flags.DEFINE_integer('summary_step', 50,
35 | 'Write summary (evaluate model on dev set) after this step')
36 | tf.app.flags.DEFINE_integer('checkpoint_step', 100, 'Save model after this step')
37 |
38 |
39 | def train(train_data, test_data):
40 | # train_dir
41 | timestamp = str(int(time.time()))
42 | out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, timestamp))
43 |
44 | # save flags
45 | if not os.path.exists(out_dir):
46 | os.mkdir(out_dir)
47 | FLAGS._parse_flags()
48 | config = dict(FLAGS.__flags.items())
49 |
50 | # Window_size must not be larger than the sent_len
51 | if config['sent_len'] < config['max_window']:
52 | config['max_window'] = config['sent_len']
53 |
54 | util.dump_to_file(os.path.join(out_dir, 'flags.cPickle'), config)
55 | print "Parameters:"
56 | for k, v in config.iteritems():
57 | print '%20s %r' % (k, v)
58 |
59 | num_batches_per_epoch = int(np.ceil(float(len(train_data))/FLAGS.batch_size))
60 | max_steps = num_batches_per_epoch * FLAGS.num_epochs
61 |
62 | with tf.Graph().as_default():
63 | with tf.variable_scope('cnn', reuse=None):
64 | m = cnn.Model(config, is_train=True)
65 | with tf.variable_scope('cnn', reuse=True):
66 | mtest = cnn.Model(config, is_train=False)
67 |
68 | # checkpoint
69 | saver = tf.train.Saver(tf.global_variables())
70 | save_path = os.path.join(out_dir, 'model.ckpt')
71 | summary_op = tf.summary.merge_all()
72 |
73 | # session
74 | with tf.Session().as_default() as sess:
75 | proj_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
76 | embedding = proj_config.embeddings.add()
77 | embedding.tensor_name = m.W_emb.name
78 | embedding.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt')
79 |
80 | train_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "train"), graph=sess.graph)
81 | dev_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "dev"), graph=sess.graph)
82 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_summary_writer, proj_config)
83 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(dev_summary_writer, proj_config)
84 |
85 | sess.run(tf.global_variables_initializer())
86 |
87 | # assign pretrained embeddings
88 | if FLAGS.use_pretrain:
89 | print "Initialize model with pretrained embeddings..."
90 | pretrained_embedding = np.load(os.path.join(FLAGS.data_dir, 'emb.npy'))
91 | m.assign_embedding(sess, pretrained_embedding)
92 |
93 | # initialize parameters
94 | current_lr = FLAGS.init_lr
95 | lowest_loss_value = float("inf")
96 | decay_step_counter = 0
97 | global_step = 0
98 |
99 | # evaluate on dev set
100 | def dev_step(mtest, sess):
101 | dev_loss = []
102 | dev_auc = []
103 | dev_f1_score = []
104 |
105 | # create batch
106 | test_batches = util.batch_iter(test_data, batch_size=FLAGS.batch_size, num_epochs=1, shuffle=False)
107 | for batch in test_batches:
108 | x_batch, y_batch, _ = zip(*batch)
109 | loss_value, eval_value = sess.run([mtest.total_loss, mtest.eval_op],
110 | feed_dict={mtest.inputs: np.array(x_batch), mtest.labels: np.array(y_batch)})
111 | dev_loss.append(loss_value)
112 | pre, rec = zip(*eval_value)
113 | dev_auc.append(util.calc_auc_pr(pre, rec))
114 | dev_f1_score.append((2.0 * pre[5] * rec[5]) / (pre[5] + rec[5])) # threshold = 0.5
115 |
116 | return np.mean(dev_loss), np.mean(dev_auc), np.mean(dev_f1_score)
117 |
118 | # train loop
119 | print "\nStart training (save checkpoints in %s)\n" % out_dir
120 | train_loss = []
121 | train_auc = []
122 | train_f1_score = []
123 | train_batches = util.batch_iter(train_data, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)
124 | for batch in train_batches:
125 | batch_size = len(batch)
126 |
127 | m.assign_lr(sess, current_lr)
128 | global_step += 1
129 |
130 | x_batch, y_batch, a_batch = zip(*batch)
131 | feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)}
132 | if FLAGS.attention:
133 | feed[m.attention] = np.array(a_batch)
134 | start_time = time.time()
135 | _, loss_value, eval_value = sess.run([m.train_op, m.total_loss, m.eval_op], feed_dict=feed)
136 | proc_duration = time.time() - start_time
137 | train_loss.append(loss_value)
138 | pre, rec = zip(*eval_value)
139 | auc = util.calc_auc_pr(pre, rec)
140 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5]) # threshold = 0.5
141 | train_auc.append(auc)
142 | train_f1_score.append(f1)
143 |
144 | assert not np.isnan(loss_value), "Model loss is NaN."
145 |
146 | # print log
147 | if global_step % FLAGS.log_step == 0:
148 | examples_per_sec = batch_size / proc_duration
149 | format_str = '%s: step %d/%d, f1 = %.4f, auc = %.4f, loss = %.4f ' + \
150 | '(%.1f examples/sec; %.3f sec/batch), lr: %.6f'
151 | print format_str % (datetime.now(), global_step, max_steps, f1, auc, loss_value,
152 | examples_per_sec, proc_duration, current_lr)
153 |
154 | # write summary
155 | if global_step % FLAGS.summary_step == 0:
156 | summary_str = sess.run(summary_op)
157 | train_summary_writer.add_summary(summary_str, global_step)
158 | dev_summary_writer.add_summary(summary_str, global_step)
159 |
160 | # summary loss, f1
161 | train_summary_writer.add_summary(
162 | _summary_for_scalar('loss', np.mean(train_loss)), global_step=global_step)
163 | train_summary_writer.add_summary(
164 | _summary_for_scalar('auc', np.mean(train_auc)), global_step=global_step)
165 | train_summary_writer.add_summary(
166 | _summary_for_scalar('f1', np.mean(train_f1_score)), global_step=global_step)
167 |
168 | dev_loss, dev_auc, dev_f1 = dev_step(mtest, sess)
169 | dev_summary_writer.add_summary(
170 | _summary_for_scalar('loss', dev_loss), global_step=global_step)
171 | dev_summary_writer.add_summary(
172 | _summary_for_scalar('auc', dev_auc), global_step=global_step)
173 | dev_summary_writer.add_summary(
174 | _summary_for_scalar('f1', dev_f1), global_step=global_step)
175 |
176 | print "\n===== write summary ====="
177 | print "%s: step %d/%d: train_loss = %.6f, train_auc = %.4f, train_f1 = %.4f" \
178 | % (datetime.now(), global_step, max_steps,
179 | np.mean(train_loss), np.mean(train_auc), np.mean(train_f1_score))
180 | print "%s: step %d/%d: dev_loss = %.6f, dev_auc = %.4f, dev_f1 = %.4f\n" \
181 | % (datetime.now(), global_step, max_steps, dev_loss, dev_auc, dev_f1)
182 |
183 | # reset container
184 | train_loss = []
185 | train_auc = []
186 | train_f1_score = []
187 |
188 | # decay learning rate if necessary
189 | if loss_value < lowest_loss_value:
190 | lowest_loss_value = loss_value
191 | decay_step_counter = 0
192 | else:
193 | decay_step_counter += 1
194 | if decay_step_counter >= FLAGS.tolerance_step:
195 | current_lr *= FLAGS.lr_decay
196 | print '%s: step %d/%d, Learning rate decays to %.5f' % \
197 | (datetime.now(), global_step, max_steps, current_lr)
198 | decay_step_counter = 0
199 |
200 | # stop learning if learning rate is too low
201 | if current_lr < 1e-5:
202 | break
203 |
204 | # save checkpoint
205 | if global_step % FLAGS.checkpoint_step == 0:
206 | saver.save(sess, save_path, global_step=global_step)
207 | saver.save(sess, save_path, global_step=global_step)
208 |
209 |
210 | def _summary_for_scalar(name, value):
211 | return tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=float(value))])
212 |
213 |
214 | def main(argv=None):
215 | if not os.path.exists(FLAGS.train_dir):
216 | os.mkdir(FLAGS.train_dir)
217 |
218 | # load dataset
219 | source_path = os.path.join(FLAGS.data_dir, 'ids.txt')
220 | target_path = os.path.join(FLAGS.data_dir, 'target.txt')
221 | attention_path = None
222 | if FLAGS.attention:
223 | if os.path.exists(os.path.join(FLAGS.data_dir, 'source.att')):
224 | attention_path = os.path.join(FLAGS.data_dir, 'source.att')
225 | else:
226 | raise ValueError("Attention file %s not found." % os.path.join(FLAGS.data_dir, 'source.att'))
227 | train_data, test_data = util.read_data(source_path, target_path, FLAGS.sent_len,
228 | attention_path=attention_path, train_size=FLAGS.train_size)
229 | train(train_data, test_data)
230 |
231 |
232 | if __name__ == '__main__':
233 | tf.app.run()
234 |
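
The F1 and PR-AUC values logged above are derived from the (precision, recall) pairs returned by `m.eval_op`. A minimal sketch with made-up numbers, assuming one pair per decision threshold ordered from high to low so that index 5 corresponds to the 0.5 threshold (as the `# threshold = 0.5` comments suggest), mirroring `util.calc_auc_pr`:

```python
import numpy as np

# hypothetical (precision, recall) pairs for thresholds 1.0, 0.9, ..., 0.1
pre = np.array([1.00, 0.95, 0.90, 0.84, 0.78, 0.71, 0.63, 0.55, 0.42, 0.30])
rec = np.array([0.05, 0.17, 0.30, 0.43, 0.55, 0.66, 0.76, 0.85, 0.92, 0.98])

f1 = 2.0 * pre[5] * rec[5] / (pre[5] + rec[5])  # F1 at threshold 0.5

# PR-AUC as in util.calc_auc_pr: pad the curve with the (recall=0, precision=1)
# and (recall=1, precision=0) end points, then apply the trapezoid rule.
auc = np.trapz(np.append(np.insert(pre, 0, 1.0), 0.0),
               x=np.append(np.insert(rec, 0, 0.0), 1.0))
print 'f1 = %.4f, auc = %.4f' % (f1, auc)
```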
--------------------------------------------------------------------------------
/train_context.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from datetime import datetime
4 | import time
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | import cnn_context
10 | import util
11 |
12 |
13 | FLAGS = tf.app.flags.FLAGS
14 |
15 | # train parameters
16 | this_dir = os.path.abspath(os.path.dirname(__file__))
17 | tf.app.flags.DEFINE_string('data_dir', os.path.join(this_dir, 'data', 'mlmi'), 'Directory of the data')
18 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'train'),
19 | 'Directory to save training checkpoint files')
20 | tf.app.flags.DEFINE_integer('train_size', 100000, 'Number of training examples')
21 | tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run')
22 | tf.app.flags.DEFINE_boolean('use_pretrain', False, 'Use word2vec pretrained embeddings or not')
23 |
24 | tf.app.flags.DEFINE_string('optimizer', 'adam',
25 | 'Optimizer to use. Must be one of "sgd", "adagrad", "adadelta" and "adam"')
26 | tf.app.flags.DEFINE_float('init_lr', 1e-3, 'Initial learning rate')
27 | tf.app.flags.DEFINE_float('lr_decay', 0.95, 'LR decay rate')
28 | tf.app.flags.DEFINE_integer('tolerance_step', 500,
29 | 'Decay the lr after loss remains unchanged for this number of steps')
30 | tf.app.flags.DEFINE_float('dropout', 0.5, 'Dropout rate. 0 is no dropout.')
31 |
32 | # logging
33 | tf.app.flags.DEFINE_integer('log_step', 10, 'Display log to stdout after this step')
34 | tf.app.flags.DEFINE_integer('summary_step', 50,
35 | 'Write summary (evaluate model on dev set) after this step')
36 | tf.app.flags.DEFINE_integer('checkpoint_step', 100, 'Save model after this step')
37 |
38 |
39 | def train(train_data, test_data):
40 | # train_dir
41 | timestamp = str(int(time.time()))
42 | out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, timestamp))
43 |
44 | # save flags
45 | if not os.path.exists(out_dir):
46 | os.mkdir(out_dir)
47 | FLAGS._parse_flags()
48 | config = dict(FLAGS.__flags.items())
49 |
50 | # Window_size must not be larger than the sent_len
51 | if config['sent_len'] < config['max_window']:
52 | config['max_window'] = config['sent_len']
53 |
54 | # flag to restore the contextwise model
55 | config['contextwise'] = True
56 |
57 | # save flags
58 | util.dump_to_file(os.path.join(out_dir, 'flags.cPickle'), config)
59 | print "Parameters:"
60 | for k, v in config.iteritems():
61 | print '%20s %r' % (k, v)
62 |
63 | # max number of steps
64 | num_batches_per_epoch = int(np.ceil(float(len(train_data))/FLAGS.batch_size))
65 | max_steps = num_batches_per_epoch * FLAGS.num_epochs
66 |
67 | with tf.Graph().as_default():
68 | with tf.variable_scope('cnn', reuse=None):
69 | m = cnn_context.Model(config, is_train=True)
70 | with tf.variable_scope('cnn', reuse=True):
71 | mtest = cnn_context.Model(config, is_train=False)
72 |
73 | # checkpoint
74 | saver = tf.train.Saver(tf.global_variables())
75 | save_path = os.path.join(out_dir, 'model.ckpt')
76 | summary_op = tf.summary.merge_all()
77 |
78 | # session
79 | with tf.Session().as_default() as sess:
80 | proj_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
81 | embedding_left = proj_config.embeddings.add()
82 | embedding_middle = proj_config.embeddings.add()
83 | embedding_right = proj_config.embeddings.add()
84 | embedding_left.tensor_name = m.W_emb_left.name
85 | embedding_middle.tensor_name = m.W_emb_middle.name
86 | embedding_right.tensor_name = m.W_emb_right.name
87 | embedding_left.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt')
88 | embedding_middle.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt')
89 | embedding_right.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt')
90 |
91 | train_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "train"), graph=sess.graph)
92 | dev_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "dev"), graph=sess.graph)
93 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_summary_writer, proj_config)
94 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(dev_summary_writer, proj_config)
95 |
96 | sess.run(tf.global_variables_initializer())
97 |
98 | # assign pretrained embeddings
99 | if FLAGS.use_pretrain:
100 | print "Initialize model with pretrained embeddings..."
101 | pretrained_embedding = np.load(os.path.join(FLAGS.data_dir, 'emb.npy'))
102 | m.assign_embedding(sess, pretrained_embedding)
103 |
104 | # initialize parameters
105 | current_lr = FLAGS.init_lr
106 | lowest_loss_value = float("inf")
107 | decay_step_counter = 0
108 | global_step = 0
109 |
110 | # evaluate on dev set
111 | def dev_step(mtest, sess):
112 | dev_loss = []
113 | dev_auc = []
114 | dev_f1_score = []
115 |
116 | # create batch
117 | test_batches = util.batch_iter(test_data, batch_size=FLAGS.batch_size, num_epochs=1, shuffle=False)
118 | for batch in test_batches:
119 | left_batch, middle_batch, right_batch, y_batch, _ = zip(*batch)
120 | feed = {mtest.left: np.array(left_batch),
121 | mtest.middle: np.array(middle_batch),
122 | mtest.right: np.array(right_batch),
123 | mtest.labels: np.array(y_batch)}
124 | loss_value, eval_value = sess.run([mtest.total_loss, mtest.eval_op], feed_dict=feed)
125 | dev_loss.append(loss_value)
126 | pre, rec = zip(*eval_value)
127 | dev_auc.append(util.calc_auc_pr(pre, rec))
128 | dev_f1_score.append((2.0 * pre[5] * rec[5]) / (pre[5] + rec[5])) # threshold = 0.5
129 |
130 | return np.mean(dev_loss), np.mean(dev_auc), np.mean(dev_f1_score)
131 |
132 | # train loop
133 | print "\nStart training (save checkpoints in %s)\n" % out_dir
134 | train_loss = []
135 | train_auc = []
136 | train_f1_score = []
137 | train_batches = util.batch_iter(train_data, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)
138 | for batch in train_batches:
139 | batch_size = len(batch)
140 |
141 | m.assign_lr(sess, current_lr)
142 | global_step += 1
143 |
144 | left_batch, middle_batch, right_batch, y_batch, a_batch = zip(*batch)
145 | feed = {m.left: np.array(left_batch),
146 | m.middle: np.array(middle_batch),
147 | m.right: np.array(right_batch),
148 | m.labels: np.array(y_batch)}
149 | if FLAGS.attention:
150 | feed[m.attention] = np.array(a_batch)
151 | start_time = time.time()
152 | _, loss_value, eval_value = sess.run([m.train_op, m.total_loss, m.eval_op], feed_dict=feed)
153 | proc_duration = time.time() - start_time
154 | train_loss.append(loss_value)
155 | pre, rec = zip(*eval_value)
156 | auc = util.calc_auc_pr(pre, rec)
157 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5]) # threshold = 0.5
158 | train_auc.append(auc)
159 | train_f1_score.append(f1)
160 |
161 | assert not np.isnan(loss_value), "Model loss is NaN."
162 |
163 | # print log
164 | if global_step % FLAGS.log_step == 0:
165 | examples_per_sec = batch_size / proc_duration
166 | format_str = '%s: step %d/%d, f1 = %.4f, auc = %.4f, loss = %.4f ' + \
167 | '(%.1f examples/sec; %.3f sec/batch), lr: %.6f'
168 | print format_str % (datetime.now(), global_step, max_steps, f1, auc, loss_value,
169 | examples_per_sec, proc_duration, current_lr)
170 |
171 | # write summary
172 | if global_step % FLAGS.summary_step == 0:
173 | summary_str = sess.run(summary_op)
174 | train_summary_writer.add_summary(summary_str, global_step)
175 | dev_summary_writer.add_summary(summary_str, global_step)
176 |
177 | # summary loss, f1
178 | train_summary_writer.add_summary(
179 | _summary_for_scalar('loss', np.mean(train_loss)), global_step=global_step)
180 | train_summary_writer.add_summary(
181 | _summary_for_scalar('auc', np.mean(train_auc)), global_step=global_step)
182 | train_summary_writer.add_summary(
183 | _summary_for_scalar('f1', np.mean(train_f1_score)), global_step=global_step)
184 |
185 | dev_loss, dev_auc, dev_f1 = dev_step(mtest, sess)
186 | dev_summary_writer.add_summary(
187 | _summary_for_scalar('loss', dev_loss), global_step=global_step)
188 | dev_summary_writer.add_summary(
189 | _summary_for_scalar('auc', dev_auc), global_step=global_step)
190 | dev_summary_writer.add_summary(
191 | _summary_for_scalar('f1', dev_f1), global_step=global_step)
192 |
193 | print "\n===== write summary ====="
194 | print "%s: step %d/%d: train_loss = %.6f, train_auc = %.4f train_f1 = %.4f" \
195 | % (datetime.now(), global_step, max_steps,
196 | np.mean(train_loss), np.mean(train_auc), np.mean(train_f1_score))
197 | print "%s: step %d/%d: dev_loss = %.6f, dev_auc = %.4f dev_f1 = %.4f\n" \
198 | % (datetime.now(), global_step, max_steps, dev_loss, dev_auc, dev_f1)
199 |
200 | # reset container
201 | train_loss = []
202 | train_auc = []
203 | train_f1_score = []
204 |
205 | # decay learning rate if necessary
206 | if loss_value < lowest_loss_value:
207 | lowest_loss_value = loss_value
208 | decay_step_counter = 0
209 | else:
210 | decay_step_counter += 1
211 | if decay_step_counter >= FLAGS.tolerance_step:
212 | current_lr *= FLAGS.lr_decay
213 | print '%s: step %d/%d, Learning rate decays to %.5f' % \
214 | (datetime.now(), global_step, max_steps, current_lr)
215 | decay_step_counter = 0
216 |
217 | # stop learning if learning rate is too low
218 | if current_lr < 1e-5:
219 | break
220 |
221 | # save checkpoint
222 | if global_step % FLAGS.checkpoint_step == 0:
223 | saver.save(sess, save_path, global_step=global_step)
224 | saver.save(sess, save_path, global_step=global_step)
225 |
226 |
227 | def _summary_for_scalar(name, value):
228 | return tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=float(value))])
229 |
230 |
231 | def main(argv=None):
232 | if not os.path.exists(FLAGS.train_dir):
233 | os.mkdir(FLAGS.train_dir)
234 |
235 | # load contextwise dataset
236 | source_path = os.path.join(FLAGS.data_dir, 'ids')
237 | target_path = os.path.join(FLAGS.data_dir, 'target.txt')
238 | attention_path = None
239 | if FLAGS.attention:
240 | if os.path.exists(os.path.join(FLAGS.data_dir, 'source.att')):
241 | attention_path = os.path.join(FLAGS.data_dir, 'source.att')
242 | else:
243 | raise ValueError("Attention file %s not found." % os.path.join(FLAGS.data_dir, 'source.att'))
244 | train_data, test_data = util.read_data_contextwise(source_path, target_path, FLAGS.sent_len,
245 | attention_path=attention_path, train_size=FLAGS.train_size)
246 | train(train_data, test_data)
247 |
248 |
249 | if __name__ == '__main__':
250 | tf.app.run()
251 |
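
A minimal sketch of loading the context-wise data outside the training script. It assumes `util.py` has already produced the `ids.left`/`ids.middle`/`ids.right` files under `data/mlmi`, and that `sent_len` (an arbitrary value here) is at least as long as the longest padded context:

```python
import util

# read_data_contextwise() appends '.left', '.middle' and '.right' to the ids prefix,
# pads every context to sent_len, and returns shuffled (train, test) splits whose
# rows are (left, middle, right, label, attention) tuples.
train_data, test_data = util.read_data_contextwise(
    './data/mlmi/ids', './data/mlmi/target.txt', sent_len=150,
    attention_path='./data/mlmi/source.att', train_size=10000)

left, middle, right, labels, attention = zip(*train_data)
print '%d training examples, %d test examples' % (len(train_data), len(test_data))
```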
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##########################################################
4 | #
5 | # Helper functions to load data
6 | #
7 | ###########################################################
8 |
9 | import os
10 | import re
11 | from codecs import open as codecs_open
12 | import cPickle as pickle
13 | import numpy as np
14 |
15 |
16 | # Special vocabulary symbols.
17 | PAD_TOKEN = '<pad>' # pad symbol
18 | UNK_TOKEN = '<unk>' # unknown word
19 | BOS_TOKEN = '<bos>' # begin-of-sentence symbol
20 | EOS_TOKEN = '<eos>' # end-of-sentence symbol
21 | NUM_TOKEN = '<num>' # numbers
22 |
23 | # we always put them at the start.
24 | _START_VOCAB = [PAD_TOKEN, UNK_TOKEN]
25 | PAD_ID = 0
26 | UNK_ID = 1
27 |
28 | # Regular expressions used to tokenize.
29 | _DIGIT_RE = re.compile(br"^\d+$")
30 |
31 |
32 | THIS_DIR = os.path.abspath(os.path.dirname(__file__))
33 | RANDOM_SEED = 1234
34 |
35 |
36 | def basic_tokenizer(sequence, bos=True, eos=True):
37 | sequence = re.sub(r'\s{2}', ' ' + EOS_TOKEN + ' ' + BOS_TOKEN + ' ', sequence)
38 | if bos:
39 | sequence = BOS_TOKEN + ' ' + sequence.strip()
40 | if eos:
41 | sequence = sequence.strip() + ' ' + EOS_TOKEN
42 | return sequence.lower().split()
43 |
44 |
45 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size=40000, tokenizer=None, bos=True, eos=True):
46 | """Create vocabulary file (if it does not exist yet) from data file.
47 |
48 | Original taken from
49 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py
50 | """
51 | if not os.path.exists(vocabulary_path):
52 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
53 | vocab = {}
54 | with codecs_open(data_path, "rb", encoding="utf-8") as f:
55 | for line in f.readlines():
56 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line, bos, eos)
57 | for w in tokens:
58 | word = re.sub(_DIGIT_RE, NUM_TOKEN, w)
59 | if word in vocab:
60 | vocab[word] += 1
61 | else:
62 | vocab[word] = 1
63 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
64 | if len(vocab_list) > max_vocabulary_size:
65 | print(" %d words found. Truncate to %d." % (len(vocab_list), max_vocabulary_size))
66 | vocab_list = vocab_list[:max_vocabulary_size]
67 | with codecs_open(vocabulary_path, "wb", encoding="utf-8") as vocab_file:
68 | for w in vocab_list:
69 | vocab_file.write(w + b"\n")
70 |
71 |
72 | def initialize_vocabulary(vocabulary_path):
73 | """Initialize vocabulary from file.
74 |
75 | Original taken from
76 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py
77 | """
78 | if os.path.exists(vocabulary_path):
79 | rev_vocab = []
80 | with codecs_open(vocabulary_path, "rb", encoding="utf-8") as f:
81 | rev_vocab.extend(f.readlines())
82 | rev_vocab = [line.strip() for line in rev_vocab]
83 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
84 | return vocab, rev_vocab
85 | else:
86 | raise ValueError("Vocabulary file %s not found." % vocabulary_path)
87 |
88 |
89 | def sentence_to_token_ids(sentence, vocabulary, tokenizer=None, bos=True, eos=True):
90 | """Convert a string to list of integers representing token-ids.
91 |
92 | Original taken from
93 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py
94 | """
95 | words = tokenizer(sentence) if tokenizer else basic_tokenizer(sentence, bos, eos)
96 | return [vocabulary.get(re.sub(_DIGIT_RE, NUM_TOKEN, w), UNK_ID) for w in words]
97 |
98 |
99 | def data_to_token_ids(data_path, target_path, vocabulary_path, tokenizer=None, bos=True, eos=True):
100 | """Tokenize data file and turn into token-ids using given vocabulary file.
101 |
102 | Original taken from
103 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py
104 | """
105 | if not os.path.exists(target_path):
106 | print("Vectorizing data in %s" % data_path)
107 | vocab, _ = initialize_vocabulary(vocabulary_path)
108 | with codecs_open(data_path, "rb", encoding="utf-8") as data_file:
109 | with codecs_open(target_path, "wb", encoding="utf-8") as tokens_file:
110 | for line in data_file:
111 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, bos, eos)
112 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
113 |
114 |
115 | def shuffle_split(X, y, a=None, train_size=10000, shuffle=True):
116 | """Shuffle and split data into train and test subset"""
117 | _X = np.array(X)
118 | _y = np.array(y)
119 | assert _X.shape[0] == _y.shape[0]
120 |
121 | _a = [None] * _y.shape[0]
122 | if a is not None and len(a) == len(y):
123 | _a = np.array(a)
124 | # compute softmax
125 | _a = np.reshape(np.exp(_a) / np.sum(np.exp(_a)), (_y.shape[0], 1))
126 | assert _a.shape[0] == _y.shape[0]
127 |
128 | print "Splitting data...",
129 | # split train-test
130 | data = np.array(zip(_X, _y, _a))
131 | data_size = _y.shape[0]
132 | if train_size > data_size:
133 | train_size = int(data_size * 0.9)
134 | if shuffle:
135 | np.random.seed(RANDOM_SEED)
136 | shuffle_indices = np.random.permutation(np.arange(data_size))
137 | shuffled_data = data[shuffle_indices]
138 | else:
139 | shuffled_data = data
140 | print "\t%d for train, %d for test" % (train_size, data_size - train_size)
141 | return shuffled_data[:train_size], shuffled_data[train_size:]
142 |
143 |
144 | def read_data(source_path, target_path, sent_len, attention_path=None, train_size=10000, shuffle=True):
145 | """Read source(x), target(y) and attention if given.
146 |
147 | Original taken from
148 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/translate.py
149 | """
150 | _X = []
151 | _y = []
152 | with codecs_open(source_path, mode="r", encoding="utf-8") as source_file:
153 | with codecs_open(target_path, mode="r", encoding="utf-8") as target_file:
154 | source, target = source_file.readline(), target_file.readline()
155 | #counter = 0
156 | print "Loading data...",
157 | while source and target:
158 | #counter += 1
159 | #if counter % 1000 == 0:
160 | # print(" reading data line %d" % counter)
161 | # sys.stdout.flush()
162 | source_ids = [np.int64(x.strip()) for x in source.split()]
163 | if sent_len > len(source_ids):
164 | source_ids += [PAD_ID] * (sent_len - len(source_ids))
165 | assert len(source_ids) == sent_len
166 |
167 | #target = target.split('\t')[0].strip()
168 | target_ids = [np.float32(y.strip()) for y in target.split()]
169 |
170 | _X.append(source_ids)
171 | _y.append(target_ids)
172 | source, target = source_file.readline(), target_file.readline()
173 |
174 | assert len(_X) == len(_y)
175 | print "\t%d examples found." % len(_y)
176 |
177 | _a = None
178 | if attention_path is not None:
179 | with codecs_open(attention_path, mode="r", encoding="utf-8") as att_file:
180 | _a = [np.float32(att.strip()) for att in att_file.readlines()]
181 | assert len(_a) == len(_y)
182 |
183 | return shuffle_split(_X, _y, a=_a, train_size=train_size, shuffle=shuffle)
184 |
185 |
186 | def shuffle_split_contextwise(X, y, a=None, train_size=10000, shuffle=True):
187 | """Shuffle and split data into train and test subset"""
188 |
189 | _left = np.array(X['left'])
190 | _middle = np.array(X['middle'])
191 | _right = np.array(X['right'])
192 | _y = np.array(y)
193 |
194 | _a = [None] * _y.shape[0]
195 | if a is not None and len(a) == len(y):
196 | _a = np.array(a)
197 | # compute softmax
198 | _a = np.reshape(np.exp(_a) / np.sum(np.exp(_a)), (_y.shape[0], 1))
199 | assert _a.shape[0] == _y.shape[0]
200 |
201 | print "Splitting data...",
202 | # split train-test
203 | data = np.array(zip(_left, _middle, _right, _y, _a))
204 | data_size = _y.shape[0]
205 | if train_size > data_size:
206 | train_size = int(data_size * 0.9)
207 | if shuffle:
208 | np.random.seed(RANDOM_SEED)
209 | shuffle_indices = np.random.permutation(np.arange(data_size))
210 | shuffled_data = data[shuffle_indices]
211 | else:
212 | shuffled_data = data
213 | print "\t%d for train, %d for test" % (train_size, data_size - train_size)
214 | return shuffled_data[:train_size], shuffled_data[train_size:]
215 |
216 |
217 | def read_data_contextwise(source_path, target_path, sent_len, attention_path=None, train_size=10000, shuffle=True):
218 | """Read source file and pad the sequence to sent_len,
219 | combine them with target (and attention if given).
220 |
221 | Original taken from
222 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/translate.py
223 | """
224 | print "Loading data...",
225 | _X = {'left': [], 'middle': [], 'right': []}
226 | for context in _X.keys():
227 | path = '%s.%s' % (source_path, context)
228 | with codecs_open(path, mode="r", encoding="utf-8") as source_file:
229 | for source in source_file.readlines():
230 | source_ids = [np.int64(x.strip()) for x in source.split()]
231 | if sent_len > len(source_ids):
232 | source_ids += [PAD_ID] * (sent_len - len(source_ids))
233 | assert len(source_ids) == sent_len
234 | _X[context].append(source_ids)
235 | assert len(_X['left']) == len(_X['middle'])
236 | assert len(_X['right']) == len(_X['middle'])
237 |
238 | _y = []
239 | with codecs_open(target_path, mode="r", encoding="utf-8") as target_file:
240 | for target in target_file.readlines():
241 | target_ids = [np.float32(y.strip()) for y in target.split()]
242 | _y.append(target_ids)
243 | assert len(_X['left']) == len(_y)
244 | print "\t%d examples found." % len(_y)
245 |
246 | _a = None
247 | if attention_path is not None:
248 | with codecs_open(attention_path, mode="r", encoding="utf-8") as att_file:
249 | _a = [np.float32(att.strip()) for att in att_file.readlines()]
250 | assert len(_a) == len(_y)
251 |
252 | return shuffle_split_contextwise(_X, _y, a=_a, train_size=train_size, shuffle=shuffle)
253 |
254 |
255 | def batch_iter(data, batch_size, num_epochs, shuffle=True):
256 | """Generates a batch iterator.
257 |
258 | Original taken from
259 | https://github.com/dennybritz/cnn-text-classification-tf/blob/master/data_helpers.py
260 | """
261 | data = np.array(data)
262 | data_size = len(data)
263 | num_batches_per_epoch = int(np.ceil(float(data_size)/batch_size))
264 | for epoch in range(num_epochs):
265 | # Shuffle data at each epoch
266 | if shuffle:
267 | #np.random.seed(RANDOM_SEED)
268 | shuffle_indices = np.random.permutation(np.arange(data_size))
269 | shuffled_data = data[shuffle_indices]
270 | else:
271 | shuffled_data = data
272 | for batch_num in range(num_batches_per_epoch):
273 | start_index = batch_num * batch_size
274 | end_index = min((batch_num + 1) * batch_size, data_size)
275 | yield shuffled_data[start_index:end_index]
276 |
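# NOTE (illustrative sketch, not part of the original script): batch_iter() is
# typically fed the train split returned by the shuffle_split helpers above,
# where each row packs the inputs together with the target (and attention).
# The hypothetical helper below shows one way a caller might unpack the
# yielded batches back into separate per-field arrays.
def _example_batch_usage(train_data, batch_size=64, num_epochs=1):
    for batch in batch_iter(train_data, batch_size, num_epochs):
        columns = zip(*batch)                     # transpose rows back into per-field columns
        yield [np.array(col) for col in columns]  # e.g. inputs, targets, attention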
277 |
278 | def dump_to_file(filename, obj):
279 | with open(filename, 'wb') as outfile:
280 | pickle.dump(obj, file=outfile)
281 | return
282 |
283 |
284 | def load_from_dump(filename):
285 | with open(filename, 'rb') as infile:
286 | obj = pickle.load(infile)
287 | return obj
288 |
289 |
290 | def _load_bin_vec(fname, vocab):
291 | """
292 | Loads 300-dimensional word vectors from the Google (Mikolov) word2vec binary file.
293 |
294 | Original taken from
295 | https://github.com/yuhaozhang/sentence-convnet/blob/master/text_input.py
296 | """
297 | word_vecs = {}
298 | with open(fname, "rb") as f:
299 | header = f.readline()
300 | vocab_size, layer1_size = map(int, header.split())
301 | binary_len = np.dtype('float32').itemsize * layer1_size
302 | for line in xrange(vocab_size):
303 | word = []
304 | while True:
305 | ch = f.read(1)
306 | if ch == ' ':
307 | word = ''.join(word)
308 | break
309 | if ch != '\n':
310 | word.append(ch)
311 | if word in vocab:
312 | word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32')
313 | else:
314 | f.read(binary_len)
315 | return (word_vecs, layer1_size)
316 |
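# NOTE (descriptive comment, not part of the original script): the word2vec
# .bin format begins with an ASCII header line "<vocab_size> <dim>", and each
# entry that follows is the word's characters terminated by a space, then
# <dim> raw float32 values (binary_len bytes). _load_bin_vec() therefore reads
# one byte at a time until it hits the space, then either parses binary_len
# bytes into a vector (if the word is in vocab) or skips them.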
317 |
318 | def _add_random_vec(word_vecs, vocab, emb_size=300):
319 | for word in vocab:
320 | if word not in word_vecs:
321 | word_vecs[word] = np.random.uniform(-0.25,0.25,emb_size)
322 | return word_vecs
323 |
324 |
325 | def prepare_pretrained_embedding(fname, word2id):
326 | print 'Reading pretrained word vectors from file ...'
327 | word_vecs, emb_size = _load_bin_vec(fname, word2id)
328 | word_vecs = _add_random_vec(word_vecs, word2id, emb_size)
329 | embedding = np.zeros([len(word2id), emb_size])
330 | for w,idx in word2id.iteritems():
331 | embedding[idx,:] = word_vecs[w]
332 | print 'Generated embeddings with shape ' + str(embedding.shape)
333 | return embedding
334 |
335 |
336 | def offset(array, pre, post):
337 | ret = np.array(array)
338 | ret = np.insert(ret, 0, pre)
339 | ret = np.append(ret, post)
340 | return ret
341 |
342 |
343 | def calc_auc_pr(precision, recall):
344 | assert len(precision) == len(recall)
345 | return np.trapz(offset(precision, 1, 0), x=offset(recall, 0, 1), dx=5)
346 |
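# NOTE (explanatory sketch, not part of the original script): offset() anchors
# the curve at its endpoints (precision 1.0 at recall 0.0, and precision 0.0 at
# recall 1.0), so np.trapz() integrates the precision/recall curve over the
# full [0, 1] recall range. np.trapz ignores the dx argument whenever x is
# given. A tiny worked example with made-up precision/recall values:
def _example_auc_pr():
    precision = np.array([0.9, 0.8, 0.6])  # hypothetical precision values
    recall = np.array([0.2, 0.5, 0.9])     # hypothetical recall values
    return calc_auc_pr(precision, recall)  # trapezoidal area under the anchored curve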
347 |
348 | def prepare_ids(data_dir, vocab_path):
349 | for context in ['left', 'middle', 'right', 'txt']:
350 | data_path = os.path.join(data_dir, 'mlmi', 'source.%s' % context)
351 | target_path = os.path.join(data_dir, 'mlmi', 'ids.%s' % context)
352 | if context == 'left':
353 | bos, eos = True, False
354 | elif context == 'middle':
355 | bos, eos = False, False
356 | elif context == 'right':
357 | bos, eos = False, True
358 | else:
359 | bos, eos = True, True
360 | data_to_token_ids(data_path, target_path, vocab_path, bos=bos, eos=eos)
361 |
362 |
363 | def main():
364 | data_dir = os.path.join(THIS_DIR, 'data')
365 |
366 | # multi-label multi-instance (MLMI-CNN) dataset
367 | vocab_path = os.path.join(data_dir, 'mlmi', 'vocab.txt')
368 | data_path = os.path.join(data_dir, 'mlmi', 'source.txt')
369 | max_vocab_size = 36500
370 | create_vocabulary(vocab_path, data_path, max_vocab_size)
371 | prepare_ids(data_dir, vocab_path)
372 |
373 | # pretrained embeddings
374 | embedding_path = os.path.join(THIS_DIR, 'word2vec', 'GoogleNews-vectors-negative300.bin')
375 | if os.path.exists(embedding_path):
376 | word2id, _ = initialize_vocabulary(vocab_path)
377 | embedding = prepare_pretrained_embedding(embedding_path, word2id)
378 | np.save(os.path.join(data_dir, 'mlmi', 'emb.npy'), embedding)
379 | else:
380 | print "Pretrained embeddings file %s not found." % embedding_path
381 |
382 | # single-label single-instance (ER-CNN) dataset
383 | vocab_er = os.path.join(data_dir, 'er', 'vocab.txt')
384 | data_er = os.path.join(data_dir, 'er', 'source.txt')
385 | target_er = os.path.join(data_dir, 'er', 'ids.txt')
386 | max_vocab_size = 11500
387 | tokenizer = lambda x: x.split()
388 | create_vocabulary(vocab_er, data_er, max_vocab_size, tokenizer=tokenizer)
389 | data_to_token_ids(data_er, target_er, vocab_er, tokenizer=tokenizer)
390 |
391 |
392 | if __name__ == '__main__':
393 | main()
394 |
--------------------------------------------------------------------------------
/word2vec/.gitignore:
--------------------------------------------------------------------------------
1 | GoogleNews-vectors-negative300.bin
2 |
--------------------------------------------------------------------------------