├── .gitignore ├── README.md ├── cnn.py ├── cnn_context.py ├── data ├── er │ ├── source.txt │ └── target.txt ├── gold_patterns.tsv ├── mlmi │ ├── source.att │ ├── source.left │ ├── source.middle │ ├── source.right │ ├── source.txt │ └── target.txt ├── negative_candidates.tsv ├── negative_relations.tsv ├── positive_candidates.tsv └── positive_relations.tsv ├── distant_supervision.py ├── eval.py ├── img ├── auc.png ├── cnn.png ├── emb_er.png ├── emb_left.png ├── emb_mlmi.png ├── emb_right.png ├── f1.png ├── loss.png └── pr_curve.png ├── train.py ├── train_context.py ├── util.py ├── visualize.ipynb └── word2vec └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | 4 | *.pyc 5 | .ipynb_checkpoints/ 6 | 7 | data/*.cPickle 8 | data/candidates*.tsv 9 | data/*/ids.* 10 | data/*/vocab.* 11 | data/*/emb.npy 12 | 13 | models/ 14 | train/ 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Neural Network for Relation Extraction 2 | 3 | **Note:** This project is mostly based on https://github.com/yuhaozhang/sentence-convnet 4 | 5 | --- 6 | 7 | 8 | ## Requirements 9 | 10 | - [Python 2.7](https://www.python.org/) 11 | - [TensorFlow](https://www.tensorflow.org/) (tested with versions 0.10.0rc0 to 1.0.1) 12 | - [NumPy](http://www.numpy.org/) 13 | 14 | To download Wikipedia articles (`distant_supervision.py`): 15 | 16 | - [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/bs4/doc/) 17 | - [Pandas](http://pandas.pydata.org/) 18 | - [Stanford NER](http://nlp.stanford.edu/software/CRF-NER.shtml) 19 | *The path to Stanford NER is set by the `ner_path` variable in `distant_supervision.py`. 20 | 21 | To visualize the results (`visualize.ipynb`): 22 | 23 | - [Matplotlib](https://matplotlib.org/) 24 | - [Scikit-learn](http://scikit-learn.org/) 25 | 26 | 27 | ## Data 28 | - The `data` directory contains the preprocessed data: 29 | ``` 30 | cnn-re-tf 31 | ├── ... 32 | ├── word2vec 33 | └── data 34 | ├── er # binary-classification dataset 35 | │   ├── source.txt # source sentences 36 | │   └── target.txt # target labels 37 | └── mlmi # multi-label multi-instance dataset 38 | ├── source.att # attention 39 | ├── source.left # left context 40 | ├── source.middle # middle context 41 | ├── source.right # right context 42 | ├── source.txt # source sentences 43 | └── target.txt # target labels 44 | ``` 45 | To reproduce it: 46 | ``` 47 | python ./distant_supervision.py 48 | ``` 49 | 50 | - The `word2vec` directory is empty. Please download the pretrained Google News word vectors from 51 | [this Google Drive link](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit), 52 | and unzip them into that directory; the result is a single `.bin` file. 53 | 54 | 55 | 56 | ## Usage 57 | ### Preprocess 58 | 59 | ```sh 60 | python ./util.py 61 | ``` 62 | This creates the `vocab.txt`, `ids.txt` and `emb.npy` files.
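As a quick sanity check after preprocessing (and to find the right values for the data-dependent training flags below), you can inspect the generated files. The exact on-disk formats are defined in `util.py`; the sketch below assumes `vocab.txt` stores one token per line and `ids.txt` one space-separated id sequence per line, written into the dataset directory (e.g. `./data/mlmi`):

```python
import numpy as np

data_dir = './data/mlmi'   # or './data/er'

# vocabulary (assumed format: one token per line)
with open(data_dir + '/vocab.txt') as f:
    vocab = [line.rstrip('\n') for line in f]

# embedding matrix written by util.py
emb = np.load(data_dir + '/emb.npy')

# id sequences (assumed format: one space-separated sequence per line)
with open(data_dir + '/ids.txt') as f:
    ids = [line.split() for line in f]

print('vocab size : %d' % len(vocab))                 # compare with --vocab_size
print('emb shape  : %s' % (emb.shape,))
print('max length : %d' % max(len(s) for s in ids))   # compare with --sent_len
```

The printed vocabulary size and maximum sequence length are what the `--vocab_size` and `--sent_len` flags in the training commands are meant to match (padding may shift `sent_len`, so treat this as a rough check rather than an exact rule).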
63 | 64 | ### Training 65 | 66 | - Binary classification (ER-CNN): 67 | ```sh 68 | python ./train.py --sent_len=3 --vocab_size=11208 --num_classes=2 --train_size=15000 \ 69 | --data_dir=./data/er --attention=False --multi_label=False --use_pretrain=False 70 | ``` 71 | 72 | - Multi-label multi-instance learning (MLMI-CNN): 73 | ```sh 74 | python ./train.py --sent_len=255 --vocab_size=36112 --num_classes=23 --train_size=10000 \ 75 | --data_dir=./data/mlmi --attention=True --multi_label=True --use_pretrain=True 76 | ``` 77 | 78 | - Multi-label multi-instance Context-wise learning (MLMI-CONT): 79 | ```sh 80 | python ./train_context.py --sent_len=102 --vocab_size=36112 --num_classes=23 --train_size=10000 \ 81 | --data_dir=./data/mlmi --attention=True --multi_label=True --use_pretrain=True 82 | ``` 83 | 84 | **Caution:** A wrong value for input-data-dependent options (`sent_len`, `vocab_size` and `num_class`) 85 | may cause an error. If you want to train the model on another dataset, please check these values. 86 | 87 | 88 | ### Evaluation 89 | 90 | ```sh 91 | python ./eval.py --train_dir=./train/1473898241 92 | ``` 93 | Replace the `--train_dir` with the output from the training. 94 | 95 | 96 | ### Run TensorBoard 97 | 98 | ```sh 99 | tensorboard --logdir=./train/1473898241 100 | ``` 101 | 102 | 103 | ## Architecture 104 | 105 | ![CNN Architecture](img/cnn.png) 106 | 107 | 108 | ## Results 109 | 110 | | | P | R | F | AUC |init_lr|l2_reg| 111 | |--------:|:----:|:----:|:----:|:----:|------:|-----:| 112 | | ER-CNN |0.9410|0.8630|0.9003|0.9303| 0.005| 0.05| 113 | | MLMI-CNN|0.8205|0.6406|0.7195|0.7424| 1e-3| 1e-4| 114 | |MLMI-CONT|0.8819|0.7158|0.7902|0.8156| 1e-3| 1e-4| 115 | 116 | ![F1](img/f1.png) 117 | ![AUC](img/auc.png) 118 | ![Loss](img/loss.png) 119 | ![PR_Curve](img/pr_curve.png) 120 | ![ER-CNN Embeddings](img/emb_er.png) 121 | ![MLMI-CNN Embeddings](img/emb_mlmi.png) 122 | ![MLMI-CONT Left Embeddings](img/emb_left.png) 123 | ![MLMI-CONT Right Embeddings](img/emb_right.png) 124 | 125 | *As you see above, these models somewhat suffer from overfitting ... 126 | 127 | 128 | ## References 129 | 130 | * http://github.com/yuhaozhang/sentence-convnet 131 | * http://github.com/dennybritz/cnn-text-classification-tf 132 | * http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/ 133 | * http://tkengo.github.io/blog/2016/03/14/text-classification-by-cnn/ 134 | * Adel et al. [Comparing Convolutional Neural Networks to Traditional Models for Slot Filling](http://arxiv.org/abs/1603.05157) NAACL 2016 135 | * Nguyen and Grishman. [Relation Extraction: Perspective from Convolutional Neural Networks](http://www.cs.nyu.edu/~thien/pubs/vector15.pdf) NAACL 2015 136 | * Lin et al. 
[Neural Relation Extraction with Selective Attention over Instances](http://www.aclweb.org/anthology/P/P16/P16-1200.pdf) ACL 2016 137 | -------------------------------------------------------------------------------- /cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ########################################################## 4 | # 5 | # Attention-based Convolutional Neural Network 6 | # for Multi-label Multi-instance Learning 7 | # 8 | # 9 | # Note: this implementation is mostly based on 10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/model.py 11 | # 12 | ########################################################## 13 | 14 | import tensorflow as tf 15 | 16 | # model parameters 17 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Training batch size') 18 | tf.app.flags.DEFINE_integer('emb_size', 300, 'Size of word embeddings') 19 | tf.app.flags.DEFINE_integer('num_kernel', 100, 'Number of filters for each window size') 20 | tf.app.flags.DEFINE_integer('min_window', 3, 'Minimum size of filter window') 21 | tf.app.flags.DEFINE_integer('max_window', 5, 'Maximum size of filter window') 22 | tf.app.flags.DEFINE_integer('vocab_size', 40000, 'Vocabulary size') 23 | tf.app.flags.DEFINE_integer('num_classes', 10, 'Number of class to consider') 24 | tf.app.flags.DEFINE_integer('sent_len', 400, 'Input sentence length.') 25 | tf.app.flags.DEFINE_float('l2_reg', 1e-4, 'l2 regularization weight') 26 | tf.app.flags.DEFINE_boolean('attention', False, 'Whether use attention or not') 27 | tf.app.flags.DEFINE_boolean('multi_label', False, 'Multilabel or not') 28 | 29 | 30 | def _variable_on_cpu(name, shape, initializer): 31 | with tf.device('/cpu:0'): 32 | var = tf.get_variable(name, shape, initializer=initializer) 33 | return var 34 | 35 | def _variable_with_weight_decay(name, shape, initializer, wd): 36 | var = _variable_on_cpu(name, shape, initializer) 37 | if wd is not None and wd != 0.: 38 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 39 | else: 40 | weight_decay = tf.constant(0.0, dtype=tf.float32) 41 | return var, weight_decay 42 | 43 | 44 | def _auc_pr(true, prob, threshold): 45 | pred = tf.where(prob > threshold, tf.ones_like(prob), tf.zeros_like(prob)) 46 | tp = tf.logical_and(tf.cast(pred, tf.bool), tf.cast(true, tf.bool)) 47 | fp = tf.logical_and(tf.cast(pred, tf.bool), tf.logical_not(tf.cast(true, tf.bool))) 48 | fn = tf.logical_and(tf.logical_not(tf.cast(pred, tf.bool)), tf.cast(true, tf.bool)) 49 | pre = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), tf.reduce_sum(tf.cast(tf.logical_or(tp, fp), tf.int32))) 50 | rec = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), tf.reduce_sum(tf.cast(tf.logical_or(tp, fn), tf.int32))) 51 | return pre, rec 52 | 53 | 54 | class Model(object): 55 | 56 | def __init__(self, config, is_train=True): 57 | self.is_train = is_train 58 | self.emb_size = config['emb_size'] 59 | self.batch_size = config['batch_size'] 60 | self.num_kernel = config['num_kernel'] 61 | self.min_window = config['min_window'] 62 | self.max_window = config['max_window'] 63 | self.vocab_size = config['vocab_size'] 64 | self.num_classes = config['num_classes'] 65 | self.sent_len = config['sent_len'] 66 | self.l2_reg = config['l2_reg'] 67 | self.multi_instance = config['attention'] 68 | self.multi_label = config['multi_label'] 69 | if is_train: 70 | self.optimizer = config['optimizer'] 71 | self.dropout = config['dropout'] 72 | self.build_graph() 73 | 74 | def build_graph(self): 
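        # Graph outline, for reference:
        #   ids -> embedding lookup -> (conv + ReLU + max-pool) for each window
        #   size in [min_window, max_window] -> concat -> dropout -> fully-connected
        #   logits -> sigmoid/softmax cross-entropy (optionally weighted by the
        #   per-instance attention) + L2 weight-decay terms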
75 | """ Build the computation graph. """ 76 | self._inputs = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_x') 77 | self._labels = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes], name='input_y') 78 | self._attention = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='attention') 79 | losses = [] 80 | 81 | # lookup layer 82 | with tf.variable_scope('embedding') as scope: 83 | self._W_emb = _variable_on_cpu(name='embedding', shape=[self.vocab_size, self.emb_size], 84 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0)) 85 | # sent_batch is of shape: (batch_size, sent_len, emb_size, 1), in order to use conv2d 86 | sent_batch = tf.nn.embedding_lookup(params=self._W_emb, ids=self._inputs) 87 | sent_batch = tf.expand_dims(sent_batch, -1) 88 | 89 | # conv + pooling layer 90 | pool_tensors = [] 91 | for k_size in range(self.min_window, self.max_window+1): 92 | with tf.variable_scope('conv-%d' % k_size) as scope: 93 | kernel, wd = _variable_with_weight_decay( 94 | name='kernel-%d' % k_size, 95 | shape=[k_size, self.emb_size, 1, self.num_kernel], 96 | initializer=tf.truncated_normal_initializer(stddev=0.01), 97 | wd=self.l2_reg) 98 | losses.append(wd) 99 | conv = tf.nn.conv2d(input=sent_batch, filter=kernel, strides=[1,1,1,1], padding='VALID') 100 | biases = _variable_on_cpu(name='bias-%d' % k_size, 101 | shape=[self.num_kernel], 102 | initializer=tf.constant_initializer(0.0)) 103 | bias = tf.nn.bias_add(conv, biases) 104 | activation = tf.nn.relu(bias, name=scope.name) 105 | # shape of activation: [batch_size, conv_len, 1, num_kernel] 106 | conv_len = activation.get_shape()[1] 107 | pool = tf.nn.max_pool(activation, ksize=[1,conv_len,1,1], strides=[1,1,1,1], padding='VALID') 108 | # shape of pool: [batch_size, 1, 1, num_kernel] 109 | pool_tensors.append(pool) 110 | 111 | # Combine all pooled tensors 112 | num_filters = self.max_window - self.min_window + 1 113 | pool_size = num_filters * self.num_kernel 114 | pool_layer = tf.concat(pool_tensors, num_filters, name='pool') 115 | pool_flat = tf.reshape(pool_layer, [-1, pool_size]) 116 | 117 | # drop out layer 118 | if self.is_train and self.dropout > 0: 119 | pool_dropout = tf.nn.dropout(pool_flat, 1 - self.dropout) 120 | else: 121 | pool_dropout = pool_flat 122 | 123 | # fully-connected layer 124 | with tf.variable_scope('output') as scope: 125 | W, wd = _variable_with_weight_decay('W', shape=[pool_size, self.num_classes], 126 | initializer=tf.truncated_normal_initializer(stddev=0.05), 127 | wd=self.l2_reg) 128 | losses.append(wd) 129 | biases = _variable_on_cpu('bias', shape=[self.num_classes], 130 | initializer=tf.constant_initializer(0.01)) 131 | self.logits = tf.nn.bias_add(tf.matmul(pool_dropout, W), biases, name='logits') 132 | 133 | # loss 134 | with tf.variable_scope('loss') as scope: 135 | if self.multi_label: 136 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self._labels, 137 | name='cross_entropy_per_example') 138 | else: 139 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self._labels, 140 | name='cross_entropy_per_example') 141 | 142 | if self.is_train and self.multi_instance: # apply attention 143 | cross_entropy_loss = tf.reduce_sum(tf.multiply(cross_entropy, self._attention), 144 | name='cross_entropy_loss') 145 | else: 146 | cross_entropy_loss = tf.reduce_mean(cross_entropy, name='cross_entropy_loss') 147 | 148 | losses.append(cross_entropy_loss) 149 | self._total_loss = tf.add_n(losses, 
name='total_loss') 150 | 151 | # eval with precision-recall 152 | with tf.variable_scope('evaluation') as scope: 153 | precision = [] 154 | recall = [] 155 | for threshold in range(10, -1, -1): 156 | pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1) 157 | precision.append(pre) 158 | recall.append(rec) 159 | self._eval_op = zip(precision, recall) 160 | 161 | # f1 score on threshold=0.5 162 | #self._f1_score = tf.truediv(tf.mul(tf.constant(2.0, dtype=tf.float64), 163 | # tf.mul(precision[5], recall[5])), tf.add(precision, recall)) 164 | 165 | # train on a batch 166 | self._lr = tf.Variable(0.0, trainable=False) 167 | if self.is_train: 168 | if self.optimizer == 'adadelta': 169 | opt = tf.train.AdadeltaOptimizer(self._lr) 170 | elif self.optimizer == 'adagrad': 171 | opt = tf.train.AdagradOptimizer(self._lr) 172 | elif self.optimizer == 'adam': 173 | opt = tf.train.AdamOptimizer(self._lr) 174 | elif self.optimizer == 'sgd': 175 | opt = tf.train.GradientDescentOptimizer(self._lr) 176 | else: 177 | raise ValueError("Optimizer not supported.") 178 | grads = opt.compute_gradients(self._total_loss) 179 | self._train_op = opt.apply_gradients(grads) 180 | 181 | for var in tf.trainable_variables(): 182 | tf.summary.histogram(var.op.name, var) 183 | else: 184 | self._train_op = tf.no_op() 185 | 186 | return 187 | 188 | @property 189 | def inputs(self): 190 | return self._inputs 191 | 192 | @property 193 | def labels(self): 194 | return self._labels 195 | 196 | @property 197 | def attention(self): 198 | return self._attention 199 | 200 | @property 201 | def lr(self): 202 | return self._lr 203 | 204 | @property 205 | def train_op(self): 206 | return self._train_op 207 | 208 | @property 209 | def total_loss(self): 210 | return self._total_loss 211 | 212 | @property 213 | def eval_op(self): 214 | return self._eval_op 215 | 216 | @property 217 | def scores(self): 218 | return tf.sigmoid(self.logits) 219 | 220 | @property 221 | def W_emb(self): 222 | return self._W_emb 223 | 224 | def assign_lr(self, session, lr_value): 225 | session.run(tf.assign(self.lr, lr_value)) 226 | 227 | def assign_embedding(self, session, pretrained): 228 | session.run(tf.assign(self.W_emb, pretrained)) 229 | -------------------------------------------------------------------------------- /cnn_context.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ########################################################## 4 | # 5 | # Attention-based Convolutional Neural Network 6 | # for Context-wise Learning 7 | # 8 | # 9 | # Note: this implementation is mostly based on 10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/model.py 11 | # 12 | ########################################################## 13 | 14 | import tensorflow as tf 15 | 16 | # model parameters 17 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Training batch size') 18 | tf.app.flags.DEFINE_integer('emb_size', 300, 'Size of word embeddings') 19 | tf.app.flags.DEFINE_integer('num_kernel', 100, 'Number of filters for each window size') 20 | tf.app.flags.DEFINE_integer('min_window', 3, 'Minimum size of filter window') 21 | tf.app.flags.DEFINE_integer('max_window', 5, 'Maximum size of filter window') 22 | tf.app.flags.DEFINE_integer('vocab_size', 40000, 'Vocabulary size') 23 | tf.app.flags.DEFINE_integer('num_classes', 10, 'Number of class to consider') 24 | tf.app.flags.DEFINE_integer('sent_len', 400, 'Input sentence length.') 25 | tf.app.flags.DEFINE_float('l2_reg', 
1e-4, 'l2 regularization weight') 26 | tf.app.flags.DEFINE_boolean('attention', False, 'Whether use attention or not') 27 | tf.app.flags.DEFINE_boolean('multi_label', False, 'Multilabel or not') 28 | 29 | 30 | def _variable_on_cpu(name, shape, initializer): 31 | with tf.device('/cpu:0'): 32 | var = tf.get_variable(name, shape, initializer=initializer) 33 | return var 34 | 35 | 36 | def _variable_with_weight_decay(name, shape, initializer, wd): 37 | var = _variable_on_cpu(name, shape, initializer) 38 | if wd is not None and wd != 0.: 39 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 40 | else: 41 | weight_decay = tf.constant(0.0, dtype=tf.float32) 42 | return var, weight_decay 43 | 44 | 45 | def _auc_pr(true, prob, threshold): 46 | pred = tf.where(prob > threshold, tf.ones_like(prob), tf.zeros_like(prob)) 47 | tp = tf.logical_and(tf.cast(pred, tf.bool), tf.cast(true, tf.bool)) 48 | fp = tf.logical_and(tf.cast(pred, tf.bool), tf.logical_not(tf.cast(true, tf.bool))) 49 | fn = tf.logical_and(tf.logical_not(tf.cast(pred, tf.bool)), tf.cast(true, tf.bool)) 50 | pre = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), 51 | tf.reduce_sum(tf.cast(tf.logical_or(tp, fp), tf.int32))) 52 | rec = tf.truediv(tf.reduce_sum(tf.cast(tp, tf.int32)), 53 | tf.reduce_sum(tf.cast(tf.logical_or(tp, fn), tf.int32))) 54 | return pre, rec 55 | 56 | 57 | class Model(object): 58 | 59 | def __init__(self, config, is_train=True): 60 | self.is_train = is_train 61 | self.emb_size = config['emb_size'] 62 | self.batch_size = config['batch_size'] 63 | self.num_kernel = config['num_kernel'] 64 | self.min_window = config['min_window'] 65 | self.max_window = config['max_window'] 66 | self.vocab_size = config['vocab_size'] 67 | self.num_classes = config['num_classes'] 68 | self.sent_len = config['sent_len'] 69 | self.l2_reg = config['l2_reg'] 70 | self.multi_instance = config['attention'] 71 | self.multi_label = config['multi_label'] 72 | if is_train: 73 | self.optimizer = config['optimizer'] 74 | self.dropout = config['dropout'] 75 | self.build_graph() 76 | 77 | def conv_layer(self, input, context): 78 | pool_tensors = [] 79 | losses = [] 80 | for k_size in range(self.min_window, self.max_window+1): 81 | with tf.variable_scope('conv-%d-%s' % (k_size, context)) as scope: 82 | kernel, wd = _variable_with_weight_decay( 83 | name='kernel-%d-%s' % (k_size, context), 84 | shape=[k_size, self.emb_size, 1, self.num_kernel], 85 | initializer=tf.truncated_normal_initializer(stddev=0.01), 86 | wd=self.l2_reg) 87 | losses.append(wd) 88 | conv = tf.nn.conv2d(input=input, filter=kernel, strides=[1,1,1,1], padding='VALID') 89 | biases = _variable_on_cpu('bias-%d-%s' % (k_size, context), 90 | [self.num_kernel], tf.constant_initializer(0.0)) 91 | bias = tf.nn.bias_add(conv, biases) 92 | activation = tf.nn.relu(bias, name=scope.name) 93 | # shape of activation: [batch_size, conv_len, 1, num_kernel] 94 | conv_len = activation.get_shape()[1] 95 | pool = tf.nn.max_pool(activation, ksize=[1,conv_len,1,1], strides=[1,1,1,1], padding='VALID') 96 | # shape of pool: [batch_size, 1, 1, num_kernel] 97 | pool_tensors.append(pool) 98 | 99 | # Combine pooled tensors 100 | num_filters = self.max_window - self.min_window + 1 101 | pool_size = num_filters * self.num_kernel # 300 102 | pool_layer = tf.concat(pool_tensors, num_filters, name='pool-%s' % context) 103 | pool_flat = tf.reshape(pool_layer, [-1, pool_size]) 104 | 105 | return losses, pool_flat 106 | 107 | def build_graph(self): 108 | """ Build the computation graph. 
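        Context-wise variant: the left, middle and right contexts are embedded
        with separate embedding matrices, passed through their own conv + max-pool
        stacks, concatenated, and fed to a single fully-connected output layer.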
""" 109 | self._left = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_left') 110 | self._middle = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_middle') 111 | self._right = tf.placeholder(dtype=tf.int64, shape=[None, self.sent_len], name='input_right') 112 | self._labels = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes], name='input_y') 113 | self._attention = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='attention') 114 | losses = [] 115 | 116 | with tf.variable_scope('embedding-left') as scope: 117 | self._W_emb_left = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size], 118 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0)) 119 | sent_batch_left = tf.nn.embedding_lookup(params=self._W_emb_left, ids=self._left) 120 | input_left = tf.expand_dims(sent_batch_left, -1) 121 | 122 | with tf.variable_scope('embedding-middle') as scope: 123 | self._W_emb_middle = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size], 124 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0)) 125 | sent_batch_middle = tf.nn.embedding_lookup(params=self._W_emb_middle, ids=self._middle) 126 | input_middle = tf.expand_dims(sent_batch_middle, -1) 127 | 128 | with tf.variable_scope('embedding-right') as scope: 129 | self._W_emb_right = _variable_on_cpu(name=scope.name, shape=[self.vocab_size, self.emb_size], 130 | initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0)) 131 | sent_batch_right = tf.nn.embedding_lookup(params=self._W_emb_right, ids=self._right) 132 | input_right = tf.expand_dims(sent_batch_right, -1) 133 | 134 | # conv + pooling layer 135 | contexts = [] 136 | for contextwise_input, context in zip([input_left, input_middle, input_right], 137 | ['left', 'middle', 'right']): 138 | conv_losses, pool_flat = self.conv_layer(contextwise_input, context) 139 | losses.extend(conv_losses) 140 | contexts.append(pool_flat) 141 | # Combine context tensors 142 | num_filters = self.max_window - self.min_window + 1 143 | pool_size = num_filters * self.num_kernel # 300 144 | concat_context = tf.concat(contexts, 1, name='concat') 145 | flat_context = tf.reshape(concat_context, [-1, pool_size*3]) 146 | 147 | # drop out layer 148 | if self.is_train and self.dropout > 0: 149 | pool_dropout = tf.nn.dropout(flat_context, 1 - self.dropout) 150 | else: 151 | pool_dropout = flat_context 152 | 153 | # fully-connected layer 154 | with tf.variable_scope('output') as scope: 155 | W, wd = _variable_with_weight_decay('W', shape=[pool_size*3, self.num_classes], 156 | initializer=tf.truncated_normal_initializer(stddev=0.05), wd=self.l2_reg) 157 | losses.append(wd) 158 | biases = _variable_on_cpu('bias', shape=[self.num_classes], 159 | initializer=tf.constant_initializer(0.01)) 160 | self.logits = tf.nn.bias_add(tf.matmul(pool_dropout, W), biases, name='logits') 161 | 162 | # loss 163 | with tf.variable_scope('loss') as scope: 164 | if self.multi_label: 165 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self._labels, 166 | name='cross_entropy_per_example') 167 | else: 168 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self._labels, 169 | name='cross_entropy_per_example') 170 | 171 | if self.is_train and self.multi_instance: # apply attention 172 | cross_entropy_loss = tf.reduce_sum(tf.multiply(cross_entropy, self._attention), 173 | name='cross_entropy_loss') 174 | else: 175 | cross_entropy_loss = 
tf.reduce_mean(cross_entropy, name='cross_entropy_loss') 176 | 177 | losses.append(cross_entropy_loss) 178 | self._total_loss = tf.add_n(losses, name='total_loss') 179 | 180 | # eval with auc-pr metric 181 | with tf.variable_scope('evaluation') as scope: 182 | precision = [] 183 | recall = [] 184 | for threshold in range(10, -1, -1): 185 | pre, rec = _auc_pr(self._labels, tf.sigmoid(self.logits), threshold * 0.1) 186 | precision.append(pre) 187 | recall.append(rec) 188 | self._eval_op = zip(precision, recall) 189 | 190 | # train on a batch 191 | self._lr = tf.Variable(0.0, trainable=False) 192 | if self.is_train: 193 | if self.optimizer == 'adadelta': 194 | opt = tf.train.AdadeltaOptimizer(self._lr) 195 | elif self.optimizer == 'adagrad': 196 | opt = tf.train.AdagradOptimizer(self._lr) 197 | elif self.optimizer == 'adam': 198 | opt = tf.train.AdamOptimizer(self._lr) 199 | elif self.optimizer == 'sgd': 200 | opt = tf.train.GradientDescentOptimizer(self._lr) 201 | else: 202 | raise ValueError("Optimizer not supported.") 203 | grads = opt.compute_gradients(self._total_loss) 204 | self._train_op = opt.apply_gradients(grads) 205 | 206 | for var in tf.trainable_variables(): 207 | tf.summary.histogram(var.op.name, var) 208 | else: 209 | self._train_op = tf.no_op() 210 | 211 | return 212 | 213 | @property 214 | def left(self): 215 | return self._left 216 | 217 | @property 218 | def middle(self): 219 | return self._middle 220 | 221 | @property 222 | def right(self): 223 | return self._right 224 | 225 | @property 226 | def labels(self): 227 | return self._labels 228 | 229 | @property 230 | def attention(self): 231 | return self._attention 232 | 233 | @property 234 | def lr(self): 235 | return self._lr 236 | 237 | @property 238 | def train_op(self): 239 | return self._train_op 240 | 241 | @property 242 | def total_loss(self): 243 | return self._total_loss 244 | 245 | @property 246 | def eval_op(self): 247 | return self._eval_op 248 | 249 | @property 250 | def scores(self): 251 | return tf.sigmoid(self.logits) 252 | 253 | @property 254 | def W_emb_left(self): 255 | return self._W_emb_left 256 | 257 | @property 258 | def W_emb_middle(self): 259 | return self._W_emb_middle 260 | 261 | @property 262 | def W_emb_right(self): 263 | return self._W_emb_right 264 | 265 | def assign_lr(self, session, lr_value): 266 | session.run(tf.assign(self.lr, lr_value)) 267 | 268 | def assign_embedding(self, session, pretrained): 269 | session.run(tf.assign(self.W_emb_left, pretrained)) 270 | session.run(tf.assign(self.W_emb_middle, pretrained)) 271 | session.run(tf.assign(self.W_emb_right, pretrained)) 272 | -------------------------------------------------------------------------------- /data/gold_patterns.tsv: -------------------------------------------------------------------------------- 1 | # Original taken from 2 | # https://github.com/beroth/relationfactory/blob/master/resources/manual_annotation/context_patterns2012.txt 3 | 4 | 5 | #per:alternate_names $ARG1 formerly known as $ARG2 6 | #per:alternate_names $ARG1 aka $ARG2 7 | #per:alternate_names $ARG1 a.k.a. $ARG2 8 | #per:alternate_names $ARG1 a. k. a. $ARG2 9 | #per:alternate_names $ARG1 is also known as $ARG2 10 | #per:alternate_names $ARG1 is known as $ARG2 11 | #per:alternate_names $ARG1 , also known as $ARG2 12 | #per:alternate_names $ARG1 , better known as $ARG2 13 | #per:alternate_names $ARG1 , best known as $ARG2 14 | #per:alternate_names $ARG1 , aka $ARG2 15 | #per:alternate_names $ARG1 , a.k.a. $ARG2 16 | #per:alternate_names $ARG1 , a. k. a. 
$ARG2 17 | #per:alternate_names $ARG1 ( also known as $ARG2 ) 18 | #per:alternate_names $ARG1 ( better known as $ARG2 ) 19 | # This overgeneralizes too much (e.g. actors playing roles). 20 | #per:alternate_names $ARG1 ( $ARG2 ) 21 | #per:alternate_names $ARG1 ( aka $ARG2 ) 22 | #per:alternate_names $ARG1 ( a.k.a. $ARG2 ) 23 | #per:alternate_names $ARG1 ( a. k. a. $ARG2 ) 24 | 25 | #org:alternate_names $ARG1 formerly known as $ARG2 26 | #org:alternate_names $ARG1 aka $ARG2 27 | #org:alternate_names $ARG1 a.k.a. $ARG2 28 | #org:alternate_names $ARG1 is also known as $ARG2 29 | #org:alternate_names $ARG1 is known as $ARG2 30 | #org:alternate_names $ARG1 , also known as $ARG2 31 | #org:alternate_names $ARG1 , better known as $ARG2 32 | #org:alternate_names $ARG1 , best known as $ARG2 33 | #org:alternate_names $ARG1 ( also known as $ARG2 ) 34 | #org:alternate_names $ARG1 ( better known as $ARG2 ) 35 | #org:alternate_names $ARG1 formerly known as `` $ARG2 36 | #org:alternate_names $ARG1 aka `` $ARG2 37 | #org:alternate_names $ARG1 a.k.a. `` $ARG2 38 | #org:alternate_names $ARG1 is also known as `` $ARG2 39 | #org:alternate_names $ARG1 is known as `` $ARG2 40 | #org:alternate_names $ARG1 , also known as `` $ARG2 41 | #org:alternate_names $ARG1 , better known as `` $ARG2 42 | #org:alternate_names $ARG1 , best known as `` $ARG2 43 | #org:alternate_names $ARG1 ( also known as `` $ARG2 '' ) 44 | #org:alternate_names $ARG1 ( better known as `` $ARG2 '' ) 45 | 46 | # per:country_of_birth -> P19 47 | P19 $ARG1 , the $ARG2-born 48 | P19 $ARG1 , $ARG2-born 49 | P19 $ARG1 was born in $ARG2 50 | P19 $ARG1 is born in $ARG2 51 | P19 $ARG1 , born in $ARG2 52 | P19 $ARG1 , being born in $ARG2 53 | P19 $ARG1 is a native $ARG2 54 | P19 $ARG1 , a native $ARG2 55 | P19 $ARG1 was born * in $ARG2 56 | P19 $ARG1 was born * in * , $ARG2 57 | P19 $ARG1 , born in * , $ARG2 58 | P19 $ARG1 , born * in * , $ARG2 59 | P19 $ARG1 is a native of $ARG2 60 | P19 $ARG1 , a native of $ARG2 61 | P19 $ARG2 born $ARG1 62 | P19 $ARG2 -born $ARG1 63 | P19 $ARG2 - born $ARG1 64 | 65 | # per:origin -> P27 country of citizenship 66 | P27 $ARG1 , a $ARG2 citizen 67 | P27 $ARG1 , an $ARG2 citizen 68 | P27 $ARG1 , who is a $ARG2 citizen 69 | P27 $ARG1 , who is an $ARG2 citizen 70 | P27 $ARG1 is a $ARG2 citizen 71 | P27 $ARG1 , originally from $ARG2 72 | P27 $ARG1 from $ARG2 73 | P27 $ARG2 citizen $ARG1 74 | P27 $ARG1 , born in $ARG2 75 | P27 $ARG1 was born in $ARG2 76 | 77 | # per:country_of_death -> P20 78 | P20 $ARG1 died in $ARG2 79 | P20 $ARG1 was killed in $ARG2 80 | P20 $ARG1 succumbed to * in $ARG2 81 | P20 $ARG1 passed away $ARG2 82 | P20 $ARG1 who died in $ARG2 83 | P20 $ARG1 who was killed in $ARG2 84 | P20 $ARG1 who succumbed to * in $ARG2 85 | P20 $ARG1 who passed away $ARG2 86 | P20 $ARG1 , who died in $ARG2 87 | P20 $ARG1 , who was killed in $ARG2 88 | P20 $ARG1 , who succumbed to * in $ARG2 89 | P20 $ARG1 , who passed away $ARG2 90 | P20 $ARG1 died * in $ARG2 91 | P20 $ARG1 , who died * in $ARG2 92 | P20 $ARG1 died * in * , $ARG2 93 | P20 $ARG1 , who died * in * , $ARG2 94 | 95 | 96 | # per:countries_of_residence -> P551 97 | # TODO: include frequent 'born-in' patterns 98 | P551 $ARG1 grew up in $ARG2 99 | P551 $ARG1 grew up in * , $ARG2 100 | P551 $ARG1 lives in $ARG2 101 | P551 $ARG1 lived in $ARG2 102 | P551 $ARG1 lives in * $ARG2 103 | P551 $ARG1 lived in * $ARG2 104 | P551 $ARG1 lives in * , $ARG2 105 | P551 $ARG1 lived in * , $ARG2 106 | P551 $ARG1 moves to $ARG2 107 | P551 $ARG1 moved to $ARG2 108 | P551 $ARG1 moves to * 
$ARG2 109 | P551 $ARG1 moved to * $ARG2 110 | P551 $ARG1 moves to * , $ARG2 111 | P551 $ARG1 moved to * , $ARG2 112 | P551 $ARG1 was raised in $ARG2 113 | P551 $ARG1 has been living in $ARG2 114 | P551 $ARG1 resides in $ARG2 115 | P551 $ARG1 resided in $ARG2 116 | P551 $ARG1 's house in $ARG2 117 | P551 $ARG1 's home in $ARG2 118 | P551 $ARG1 has a house in $ARG2 119 | P551 $ARG1 immigrated to $ARG2 120 | P551 $ARG1 spent his childhood in $ARG2 121 | P551 $ARG1 , born in $ARG2 122 | P551 $ARG1 was born in $ARG2 123 | 124 | P551 $ARG2 citizen $ARG1 125 | P551 $ARG2 , home of $ARG1 126 | P551 $ARG2 , hometown of $ARG1 127 | P551 $ARG2 is the hometown of $ARG1 128 | P551 $ARG2 citizen , $ARG1 129 | P551 $ARG1 is a $ARG2 citizen 130 | P551 $ARG1 is an $ARG2 citizen 131 | P551 $ARG1, a $ARG2 citizen 132 | P551 $ARG1, an $ARG2 citizen 133 | 134 | # capital of -> P1376 135 | P1376 $ARG1 , the capital of $ARG2 136 | P1376 $ARG1 is the capital of $ARG2 137 | P1376 $ARG1 was the capital of $ARG2 138 | P1376 $ARG1 , the * capital of $ARG2 139 | P1376 $ARG1 is the * capital of $ARG2 140 | P1376 $ARG1 was the * capital of $ARG2 141 | P1376 $ARG1 , the capital of * $ARG2 142 | P1376 $ARG1 is the capital of * $ARG2 143 | P1376 $ARG1 was the capital of $ARG2 144 | P1376 $ARG1 , the * capital of * $ARG2 145 | P1376 $ARG1 is the * capital of * $ARG2 146 | P1376 $ARG1 was the * capital of *$ARG2 147 | 148 | # country -> P17 149 | P17 $ARG1 * the state of $ARG2 150 | P17 $ARG1 * the state of * , $ARG2 151 | P17 $ARG1 * the state of * , * $ARG2 152 | P17 $ARG1 * the * state of * , $ARG2 153 | P17 $ARG1 * the * state of * , $ARG2 154 | P17 $ARG1 * the * state of * , * $ARG2 155 | P17 $ARG1 * the province of $ARG2 156 | P17 $ARG1 * the province of * , $ARG2 157 | P17 $ARG1 * the province of * , * $ARG2 158 | P17 $ARG1 * the * province of * , $ARG2 159 | P17 $ARG1 * the * province of * , $ARG2 160 | P17 $ARG1 * the * province of * , * $ARG2 161 | P17 $ARG1 * the prefecture of $ARG2 162 | P17 $ARG1 * the prefecture of * , $ARG2 163 | P17 $ARG1 * the prefecture of * , * $ARG2 164 | P17 $ARG1 * the * prefecture of * , $ARG2 165 | P17 $ARG1 * the * prefecture of * , $ARG2 166 | P17 $ARG1 * the * prefecture of * , * $ARG2 167 | 168 | 169 | # educated at P69 170 | P69 $ARG1 graduated from $ARG2 171 | P69 $ARG1 graduated in * from $ARG2 172 | P69 $ARG1 attended * at $ARG2 173 | P69 $ARG1 attended $ARG2 174 | P69 $ARG1 earned his * at $ARG2 175 | P69 $ARG1 holds a * from $ARG2 176 | P69 $ARG1 holds a * in * from $ARG2 177 | P69 $ARG1 , who holds a * in * from $ARG2 178 | P69 $ARG1 , who holds a * from $ARG2 179 | P69 $ARG1 has a * from $ARG2 180 | P69 $ARG1 has a * in * from $ARG2 181 | P69 $ARG1 , who has a * in * from $ARG2 182 | P69 $ARG1 , who has a * from $ARG2 183 | # This should match the sentence: 184 | #Erraguntla , who has a masters degree in computer engineering from the University of Lousiana 185 | P69 $ARG1 studied at $ARG2 186 | P69 $ARG1 is a student at $ARG2 187 | P69 $ARG1 was a student at $ARG2 188 | P69 $ARG1 is a $ARG2 student 189 | P69 $ARG1, a $ARG2 student 190 | P69 $ARG1 is an $ARG2 student 191 | P69 $ARG1, an $ARG2 student 192 | P69 $ARG1 is a $ARG2 graduate 193 | P69 $ARG1, a $ARG2 graduate 194 | P69 $ARG1 is an $ARG2 graduate 195 | P69 $ARG1, an $ARG2 graduate 196 | P69 $ARG1 is a $ARG2 alumnus 197 | P69 $ARG1, a $ARG2 alumnus 198 | P69 $ARG1 is an $ARG2 alumnus 199 | P69 $ARG1, an $ARG2 alumnus 200 | P69 $ARG2 gradute $ARG1 201 | P69 $ARG2 student $ARG1 202 | P69 $ARG2 alumnus $ARG1 203 | 204 
| 205 | # job title 206 | #per:title $ARG1 was appointed $ARG2 207 | #per:title $ARG1 was nominated $ARG2 208 | #per:title $ARG1 was elected as $ARG2 209 | #per:title $ARG1 , a $ARG2 210 | #per:title $ARG1 , an $ARG2 211 | #per:title $ARG1 , the $ARG2 212 | #per:title $ARG1 , a former $ARG2 213 | #per:title $ARG1 , an a former $ARG2 214 | #per:title $ARG1 is a $ARG2 215 | #per:title $ARG1 is an $ARG2 216 | #per:title $ARG1 is the $ARG2 217 | #per:title $ARG2 works as a $ARG1 218 | #per:title $ARG2 works as an $ARG1 219 | #per:title $ARG1 was a $ARG2 220 | #per:title $ARG1 was an $ARG2 221 | #per:title $ARG2 worked as a $ARG1 222 | #per:title $ARG2 worked as an $ARG1 223 | # TODO: high-recall vs. high-precision patterns 224 | #per:title $ARG2 $ARG1 225 | 226 | 227 | 228 | # per:member_of -> P463, P54 229 | P463 $ARG1 is a member of $ARG2 230 | P463 $ARG2 member $ARG1 231 | P463 $ARG1 is a fellow of $ARG2 232 | P463 $ARG2 fellow $ARG1 233 | P463 $ARG1 became a member of $ARG2 234 | P463 $ARG1 became $ARG2 member 235 | P463 $ARG1 became a fellow of $ARG2 236 | P463 $ARG1 became $ARG2 fellow 237 | P463 $ARG1 joined $ARG2 238 | P463 $ARG1 is leaving $ARG2 239 | P463 $ARG1 left $ARG2 240 | 241 | P463 $ARG1 worked for $ARG2 242 | P463 $ARG1 works for $ARG2 243 | P463 $ARG1 has been working for $ARG2 244 | P463 $ARG1 was working for $ARG2 245 | P463 $ARG1 had working for $ARG2 246 | P463 $ARG1 was employed by $ARG2 247 | P463 $ARG1 is employed by $ARG2 248 | P463 $ARG1 , an $ARG2 employee 249 | P463 $ARG1 , a $ARG2 employee 250 | P463 $ARG1 was hired by $ARG2 251 | P463 $ARG1 has been hired by $ARG2 252 | P463 $ARG1 had been hired by $ARG2 253 | P463 $ARG2 manager $ARG1 254 | P463 $ARG2 is a manager of $ARG1 255 | P463 $ARG2 coach $ARG1 256 | P463 $ARG2 is the coach of $ARG1 257 | P463 $ARG1 served $ARG2 258 | 259 | # sports team P54 260 | P54 $ARG1 is a member of $ARG2 261 | P54 $ARG2 member $ARG1 262 | P54 $ARG1 is a fellow of $ARG2 263 | P54 $ARG2 fellow $ARG1 264 | P54 $ARG1 became a member of $ARG2 265 | P54 $ARG1 became $ARG2 member 266 | P54 $ARG1 became a fellow of $ARG2 267 | P54 $ARG1 became $ARG2 fellow 268 | P54 $ARG1 joined $ARG2 269 | P54 $ARG1 is leaving $ARG2 270 | P54 $ARG1 left $ARG2 271 | 272 | P54 $ARG1 worked for $ARG2 273 | P54 $ARG1 works for $ARG2 274 | P54 $ARG1 has been working for $ARG2 275 | P54 $ARG1 was working for $ARG2 276 | P54 $ARG1 had working for $ARG2 277 | P54 $ARG1 was employed by $ARG2 278 | P54 $ARG1 is employed by $ARG2 279 | P54 $ARG1 , an $ARG2 employee 280 | P54 $ARG1 , a $ARG2 employee 281 | P54 $ARG1 was hired by $ARG2 282 | P54 $ARG1 has been hired by $ARG2 283 | P54 $ARG1 had been hired by $ARG2 284 | P54 $ARG2 manager $ARG1 285 | P54 $ARG2 is a manager of $ARG1 286 | P54 $ARG2 coach $ARG1 287 | P54 $ARG2 is the coach of $ARG1 288 | P54 $ARG1 served $ARG2 289 | 290 | 291 | # per:employee_of -> P108 292 | P108 $ARG1 worked for $ARG2 293 | P108 $ARG1 works for $ARG2 294 | P108 $ARG1 has been working for $ARG2 295 | P108 $ARG1 was working for $ARG2 296 | P108 $ARG1 had working for $ARG2 297 | P108 $ARG1 was employed by $ARG2 298 | P108 $ARG1 is employed by $ARG2 299 | P108 $ARG1 , an $ARG2 employee 300 | P108 $ARG1 , a $ARG2 employee 301 | P108 $ARG1 was hired by $ARG2 302 | P108 $ARG1 has been hired by $ARG2 303 | P108 $ARG1 had been hired by $ARG2 304 | P108 $ARG2 manager $ARG1 305 | P108 $ARG2 is a manager of $ARG1 306 | P108 $ARG2 coach $ARG1 307 | P108 $ARG2 is the coach of $ARG1 308 | P108 $ARG1 joined $ARG2 309 | P108 $ARG1 is leaving $ARG2 310 
| P108 $ARG1 left $ARG2 311 | P108 $ARG1 served $ARG2 312 | 313 | P108 $ARG1 is a member of $ARG2 314 | P108 $ARG2 member $ARG1 315 | P108 $ARG1 is a fellow of $ARG2 316 | P108 $ARG2 fellow $ARG1 317 | P108 $ARG1 became a member of $ARG2 318 | P108 $ARG1 became $ARG2 member 319 | P108 $ARG1 became a fellow of $ARG2 320 | P108 $ARG1 became $ARG2 fellow 321 | P108 $ARG1 joined $ARG2 322 | P108 $ARG1 is leaving $ARG2 323 | P108 $ARG1 left $ARG2 324 | P108 $ARG1 worked for $ARG2 325 | P108 $ARG1 works for $ARG2 326 | P108 $ARG1 has been working for $ARG2 327 | P108 $ARG1 was working for $ARG2 328 | P108 $ARG1 had working for $ARG2 329 | P108 $ARG1 was employed by $ARG2 330 | P108 $ARG1 is employed by $ARG2 331 | P108 $ARG1 , an $ARG2 employee 332 | P108 $ARG1 , a $ARG2 employee 333 | P108 $ARG1 was hired by $ARG2 334 | P108 $ARG1 has been hired by $ARG2 335 | P108 $ARG1 had been hired by $ARG2 336 | P108 $ARG2 manager $ARG1 337 | P108 $ARG2 is a manager of $ARG1 338 | P108 $ARG2 coach $ARG1 339 | P108 $ARG2 is the coach of $ARG1 340 | P108 $ARG1 joined $ARG2 341 | P108 $ARG1 is leaving $ARG2 342 | P108 $ARG1 left $ARG2 343 | P108 $ARG1 served $ARG2 344 | 345 | 346 | # per:spouse -> P26 347 | P26 $ARG0 and $ARG0 are married 348 | P26 $ARG0 is married to $ARG0 349 | P26 $ARG0 is the wife of $ARG0 350 | P26 $ARG0 's wife , $ARG0 351 | P26 $ARG0' wife , $ARG0 352 | P26 $ARG0 , $ARG0 's wife 353 | P26 $ARG0 , $ARG0' wife 354 | P26 $ARG0 , the wife of $ARG0 355 | P26 $ARG0 and his wife $ARG0 356 | P26 $ARG0 is the husband of $ARG0 357 | P26 $ARG0 's husband , $ARG0 358 | P26 $ARG0' husband , $ARG0 359 | P26 $ARG0 , $ARG0 's husband 360 | P26 $ARG0 , $ARG0' husband 361 | P26 $ARG0 , the husband of $ARG0 362 | P26 $ARG0 and her husband $ARG0 363 | 364 | P40 $ARG1 's child $ARG2 365 | P40 $ARG1 ' child $ARG2 366 | P40 $ARG2 is the child of $ARG1 367 | P40 $ARG1 has a child named $ARG2 368 | P40 $ARG1 's oldest child $ARG2 369 | P40 $ARG1 ' oldest child $ARG2 370 | P40 $ARG2 is the oldest child of $ARG1 371 | P40 $ARG1 's youngest child $ARG2 372 | P40 $ARG1 ' youngest child $ARG2 373 | P40 $ARG2 is the youngest child of $ARG1 374 | P40 $ARG1 's daughter $ARG2 375 | P40 $ARG1 ' daughter $ARG2 376 | P40 $ARG2 is the daughter of $ARG1 377 | P40 $ARG1 has a doughter named $ARG2 378 | P40 $ARG1 's son $ARG2 379 | P40 $ARG1 ' son $ARG2 380 | P40 $ARG2 is the son of $ARG1 381 | P40 $ARG1 has a son named $ARG2 382 | P40 $ARG1 is the father of $ARG2 383 | P40 $ARG2 's father , $ARG1 384 | P40 $ARG2 ' father , $ARG1 385 | P40 $ARG2 's father $ARG1 386 | P40 $ARG2 ' father $ARG1 387 | P40 $ARG1 is the mother of $ARG2 388 | P40 $ARG2 's mother , $ARG1 389 | P40 $ARG2 ' mother , $ARG1 390 | P40 $ARG2 's mother $ARG1 391 | P40 $ARG2 ' mother $ARG1 392 | 393 | P22 $ARG2 's child $ARG1 394 | P22 $ARG2 ' child $ARG1 395 | P22 $ARG1 is the child of $ARG2 396 | P22 $ARG2 has a child named $ARG1 397 | P22 $ARG2 's oldest child $ARG1 398 | P22 $ARG2 ' oldest child $ARG1 399 | P22 $ARG1 is the oldest child of $ARG2 400 | P22 $ARG2 ' youngest child $ARG1 401 | P22 $ARG1 is the youngest child of $ARG2 402 | P22 $ARG2 's daughter $ARG1 403 | P22 $ARG2 ' daughter $ARG1 404 | P22 $ARG1 is the daughter of $ARG2 405 | P22 $ARG2 has a daughter named $ARG1 406 | P22 $ARG2 's son $ARG1 407 | P22 $ARG2 ' son $ARG1 408 | P22 $ARG1 is the son of $ARG2 409 | P22 $ARG2 has a son named $ARG1 410 | P22 $ARG2 is the father of $ARG1 411 | P22 $ARG1 's father , $ARG2 412 | P22 $ARG1 ' father , $ARG2 413 | P22 $ARG1 's father $ARG2 414 | P22 
$ARG1 ' father $ARG2 415 | 416 | P25 $ARG2 's child $ARG1 417 | P25 $ARG2 ' child $ARG1 418 | P25 $ARG1 is the child of $ARG2 419 | P25 $ARG2 has a child named $ARG1 420 | P25 $ARG2 's oldest child $ARG1 421 | P25 $ARG2 ' oldest child $ARG1 422 | P25 $ARG1 is the oldest child of $ARG2 423 | P25 $ARG2 ' youngest child $ARG1 424 | P25 $ARG1 is the youngest child of $ARG2 425 | P25 $ARG2 's daughter $ARG1 426 | P25 $ARG2 ' daughter $ARG1 427 | P25 $ARG1 is the daughter of $ARG2 428 | P25 $ARG2 has a daughter named $ARG1 429 | P25 $ARG2 's son $ARG1 430 | P25 $ARG2 ' son $ARG1 431 | P25 $ARG1 is the son of $ARG2 432 | P25 $ARG2 has a son named $ARG1 433 | P25 $ARG2 is the mother of $ARG1 434 | P25 $ARG1 's mother , $ARG2 435 | P25 $ARG1 ' mother , $ARG2 436 | P25 $ARG1 's mother $ARG2 437 | P25 $ARG1 ' mother $ARG2 438 | 439 | P7 $ARG0 is the brother of $ARG0 440 | P7 $ARG0 is a brother of $ARG0 441 | P7 $ARG0 's brother , $ARG0 442 | P7 $ARG0' brother , $ARG0 443 | P7 $ARG0 's brother $ARG0 444 | P7 $ARG0' brother $ARG0 445 | P7 $ARG0 is the younger brother of $ARG0 446 | P7 $ARG0 's younger brother , $ARG0 447 | P7 $ARG0' younger brother , $ARG0 448 | P7 $ARG0 is the older brother of $ARG0 449 | P7 $ARG0 's older brother , $ARG0 450 | P7 $ARG0' older brother , $ARG0 451 | P9 $ARG0 is the sister of $ARG0 452 | P9 $ARG0 is a sister of $ARG0 453 | P9 $ARG0 's sister $ARG0 454 | P9 $ARG0' sister $ARG0 455 | P9 $ARG0 's sister , $ARG0 456 | P9 $ARG0' sister , $ARG0 457 | P9 $ARG0 is the younger sister of $ARG0 458 | P9 $ARG0 's younger sister , $ARG0 459 | P9 $ARG0' younger sister , $ARG0 460 | P9 $ARG0 is the older sister of $ARG0 461 | P9 $ARG0 's older sister , $ARG0 462 | P9 $ARG0' older sister , $ARG0 463 | 464 | P1038 $ARG0 's grandson $ARG0 465 | P1038 $ARG0 's granddaughter $ARG0 466 | P1038 $ARG0 's grandfather $ARG0 467 | P1038 $ARG0 's grandmother $ARG0 468 | P1038 $ARG0 's uncle $ARG0 469 | P1038 $ARG0 's aunt $ARG0 470 | P1038 $ARG0 's cousin $ARG0 471 | P1038 $ARG0 's nephew $ARG0 472 | P1038 $ARG0 's relative $ARG0 473 | P1038 $ARG0 's grandson , $ARG0 474 | P1038 $ARG0 's granddaughter , $ARG0 475 | P1038 $ARG0 's grandfather , $ARG0 476 | P1038 $ARG0 's grandmother , $ARG0 477 | P1038 $ARG0 's uncle , $ARG0 478 | P1038 $ARG0 's aunt , $ARG0 479 | P1038 $ARG0 's cousin , $ARG0 480 | P1038 $ARG0 's nephew , $ARG0 481 | P1038 $ARG0 's relative , $ARG0 482 | P1038 $ARG0 is a grandson of $ARG0 483 | P1038 $ARG0 is the grandson of $ARG0 484 | P1038 $ARG0 is a granddaughter of $ARG0 485 | P1038 $ARG0 is the granddaughter of $ARG0 486 | P1038 $ARG0 is the grandfather of $ARG0 487 | P1038 $ARG0 is the grandmother of $ARG0 488 | P1038 $ARG0 is the uncle of $ARG0 489 | P1038 $ARG0 is the aunt of $ARG0 490 | P1038 $ARG0 is the cousin of $ARG0 491 | P1038 $ARG0 is a cousin of $ARG0 492 | P1038 $ARG0 is the nephew of $ARG0 493 | P1038 $ARG0 is a nephew of $ARG0 494 | P1038 $ARG0 is a relative of $ARG0 495 | P1038 $ARG0 is related to $ARG0 496 | 497 | 498 | # org:political_religious_affiliation -> P102 499 | P102 $ARG1 is affiliated with $ARG2 500 | P102 $ARG1 , affiliated with $ARG2 501 | P102 $ARG1 is connected to $ARG2 502 | P102 $ARG1 , connected to $ARG2 503 | P102 $ARG1 is strongly connected to $ARG2 504 | P102 $ARG1 , strongly connected to $ARG2 505 | P102 $ARG1 is tightly connected to $ARG2 506 | P102 $ARG1 , tightly connected to $ARG2 507 | P102 $ARG1 is a $ARG2 organization 508 | P102 $ARG1 is an $ARG2 organization 509 | P102 $ARG1 is a $ARG2 group 510 | P102 $ARG1 is an $ARG2 group 
511 | P102 $ARG1 as a $ARG2 organization 512 | P102 $ARG1 as an $ARG2 organization 513 | P102 $ARG1 as a $ARG2 group 514 | P102 $ARG1 as an $ARG2 group 515 | P102 $ARG1 , a $ARG2 organization 516 | P102 $ARG1 , an $ARG2 organization 517 | P102 $ARG1 , a $ARG2 group 518 | P102 $ARG1 , an $ARG2 group 519 | P102 $ARG1 and other $ARG2 organizations 520 | P102 $ARG1 and other $ARG2 groups 521 | P102 $ARG1 is an affiliate of $ARG2 522 | 523 | 524 | P108 $ARG1 leads $ARG2 525 | P108 $ARG1 commands $ARG2 526 | P108 $ARG1 oversees $ARG2 527 | P108 $ARG1 commanded $ARG2 528 | P108 $ARG2 board member $ARG1 529 | P108 $ARG2 CFO $ARG1 530 | P108 $ARG2 CEO $ARG1 531 | P108 $ARG2 chief executive officer $ARG1 532 | P108 $ARG2 managing director $ARG1 533 | P108 $ARG2 MD $ARG1 534 | P108 $ARG2 executive director $ARG1 535 | P108 $ARG2 ED $ARG1 536 | P108 $ARG2 president $ARG1 537 | P108 $ARG2 vice president $ARG1 538 | P108 $ARG2 director $ARG1 539 | P108 $ARG2 chairman $ARG1 540 | P108 $ARG2 executive vice president $ARG1 541 | P108 $ARG2 provost $ARG1 542 | P108 $ARG2 dean $ARG1 543 | P108 $ARG1 is a board member of $ARG2 544 | P108 $ARG1 is the CFO of $ARG2 545 | P108 $ARG1 is the CEO of $ARG2 546 | P108 $ARG1 is the president of $ARG2 547 | P108 $ARG1 is the vice president of $ARG2 548 | P108 $ARG1 is the director of $ARG2 549 | P108 $ARG1 is the chairman of $ARG2 550 | P108 $ARG1 is the executive vice president of $ARG2 551 | P108 $ARG1 is the provost of $ARG2 552 | P108 $ARG1 is the dean of $ARG2 553 | P108 $ARG1 , board member of $ARG2 554 | P108 $ARG1 , CFO of $ARG2 555 | P108 $ARG1 , CEO of $ARG2 556 | P108 $ARG1 , president of $ARG2 557 | P108 $ARG1 , vice president of $ARG2 558 | P108 $ARG1 , director of $ARG2 559 | P108 $ARG1 , chairman of $ARG2 560 | P108 $ARG1 , executive vice president of $ARG2 561 | P108 $ARG1 , provost of $ARG2 562 | P108 $ARG1 , dean of $ARG2 563 | P108 $ARG1 , a board member of $ARG2 564 | P108 $ARG1 , the CFO of $ARG2 565 | P108 $ARG1 , the CEO of $ARG2 566 | P108 $ARG1 , the president of $ARG2 567 | P108 $ARG1 , the vice president of $ARG2 568 | P108 $ARG1 , the director of $ARG2 569 | P108 $ARG1 , the chairman of $ARG2 570 | P108 $ARG1 , the executive vice president of $ARG2 571 | P108 $ARG1 , the provost of $ARG2 572 | P108 $ARG1 , the dean of $ARG2 573 | P108 $ARG1 , former board member of $ARG2 574 | P108 $ARG1 , former CFO of $ARG2 575 | P108 $ARG1 , former CEO of $ARG2 576 | P108 $ARG1 , former president of $ARG2 577 | P108 $ARG1 , former vice president of $ARG2 578 | P108 $ARG1 , former director of $ARG2 579 | P108 $ARG1 , former chairman of $ARG2 580 | P108 $ARG1 , former executive vice president of $ARG2 581 | P108 $ARG1 , former provost of $ARG2 582 | P108 $ARG1 , former dean of $ARG2 583 | 584 | 585 | # org:member_of -> P463 586 | P463 $ARG1 is a member organization of $ARG2 587 | P463 $ARG1 as a member organization of $ARG2 588 | P463 $ARG1 , as a member organization of $ARG2 589 | P463 $ARG1 , a member organization of $ARG2 590 | P463 $ARG1 is a member of $ARG2 591 | P463 $ARG1 are member of $ARG2 592 | P463 $ARG1 as a member of $ARG2 593 | P463 $ARG1 , as a member of $ARG2 594 | P463 $ARG1 , a member of $ARG2 595 | P463 $ARG1 is part of $ARG2 596 | P463 $ARG1 are part of $ARG2 597 | P463 $ARG1 plays in $ARG2 598 | P463 $ARG1 play in $ARG2 599 | P463 $ARG1 is relegated to $ARG2 600 | P463 $ARG1 are relegated to $ARG2 601 | P463 $ARG1 is promoted to $ARG2 602 | P463 $ARG1 are promoted to $ARG2 603 | P463 $ARG1 joined $ARG2 604 | P463 $ARG1 and other members 
of $ARG2 605 | 606 | 607 | # org:subsidiaries -> P355 608 | P355 $ARG2 , a subsidiary of $ARG1 609 | P355 $ARG2 is a subsidiary of $ARG1 610 | P355 $ARG2 as a subsidiary of $ARG1 611 | P355 $ARG2 operates as a subsidiary of $ARG1 612 | P355 $ARG2 , a branch of $ARG1 613 | P355 $ARG2 is a branch of $ARG1 614 | P355 $ARG2 as a branch of $ARG1 615 | P355 $ARG2 , a regional branch of $ARG1 616 | P355 $ARG2 is a regional branch of $ARG1 617 | P355 $ARG2 as a regional branch of $ARG1 618 | P355 $ARG2 , owned by $ARG1 619 | P355 $ARG2 is owned by $ARG1 620 | P355 $ARG1 owns $ARG2 621 | P355 $ARG1 owned $ARG2 622 | P355 $ARG2 belongs to $ARG1 623 | P355 $ARG2 , belonging to $ARG1 624 | P355 $ARG1 's $ARG2 subsidiary 625 | # Too low precision. 626 | #org:subsidiaries $ARG1 's $ARG2 627 | #org:subsidiaries $ARG2 of $ARG1 628 | P355 $ARG1 's subsidiary $ARG2 629 | P355 $ARG1 's unit $ARG2 630 | P355 $ARG1 's arm $ARG2 631 | P355 $ARG1 's * subsidiary $ARG2 632 | P355 $ARG1 's * branch $ARG2 633 | P355 $ARG1 's * unit $ARG2 634 | P355 $ARG1 's * arm $ARG2 635 | P355 $ARG2 subsidiary of $ARG1 636 | P355 $ARG2 is the * arm of $ARG1 637 | P355 $ARG2 is the * unit of $ARG1 638 | P355 $ARG2 is the * branch of $ARG1 639 | P355 $ARG2 is the * subsidiary of $ARG1 640 | 641 | 642 | # org:parent -> P749 643 | P749 $ARG1 , a subsidiary of $ARG2 644 | P749 $ARG1 is a subsidiary of $ARG2 645 | P749 $ARG1 as a subsidiary of $ARG2 646 | P749 $ARG1 operates as a subsidiary of $ARG2 647 | P749 $ARG1 , a branch of $ARG2 648 | P749 $ARG1 is a branch of $ARG2 649 | P749 $ARG1 as a branch of $ARG2 650 | P749 $ARG1 , a regional branch of $ARG2 651 | P749 $ARG1 is a regional branch of $ARG2 652 | P749 $ARG1 as a regional branch of $ARG2 653 | P749 $ARG1 , owned by $ARG2 654 | P749 $ARG1 is owned by $ARG2 655 | P749 $ARG2 owns $ARG1 656 | P749 $ARG2 owned $ARG1 657 | P749 $ARG1 belongs to $ARG2 658 | P749 $ARG1 , belonging to $ARG2 659 | P749 $ARG2 's $ARG1 subsidiary 660 | P749 $ARG2 's subsidiary $ARG1 661 | P749 $ARG2 's unit $ARG1 662 | P749 $ARG2 's arm $ARG1 663 | P749 $ARG2 's * subsidiary $ARG1 664 | P749 $ARG2 's * branch $ARG1 665 | P749 $ARG2 's * unit $ARG1 666 | P749 $ARG2 's * arm $ARG1 667 | P749 $ARG1 subsidiary of $ARG2 668 | P749 $ARG1 is the * arm of $ARG2 669 | P749 $ARG1 is the * unit of $ARG2 670 | P749 $ARG1 is the * branch of $ARG2 671 | P749 $ARG1 is the * subsidiary of $ARG2 672 | 673 | P112 $ARG2 was founded by $ARG1 674 | P112 $ARG2 , founded by $ARG1 675 | P112 $ARG1 founded $ARG2 676 | P112 $ARG1 , who founded $ARG2 677 | P112 $ARG1 , the founder $ARG2 678 | P112 $ARG2 was formed by $ARG1 679 | P112 $ARG2 , formed by $ARG1 680 | P112 $ARG1 formed $ARG2 681 | P112 $ARG1 , who formed $ARG2 682 | P112 $ARG2 was established by $ARG1 683 | P112 $ARG2 , established by $ARG1 684 | P112 $ARG1 established $ARG2 685 | P112 $ARG1 , who established $ARG2 686 | #org:founded $ARG1 , established $ARG2 687 | #org:founded $ARG1 , founded $ARG2 688 | #org:founded $ARG1 , incorporated $ARG2 689 | #org:founded $ARG1 was established $ARG2 690 | #org:founded $ARG1 was founded $ARG2 691 | #org:founded $ARG1 was incorporated $ARG2 692 | #org:founded established $ARG1 $ARG2 693 | #org:founded formed $ARG1 $ARG2 694 | #org:founded founded $ARG1 $ARG2 695 | 696 | 697 | # org:country_of_headquarters -> P131 698 | P131 $ARG1 's main complex in $ARG2 699 | P131 $ARG1 's main campus in $ARG2 700 | P131 $ARG1 's head office in $ARG2 701 | P131 $ARG1 's main office in $ARG2 702 | P131 $ARG1 's main offices in $ARG2 703 | P131 
$ARG1 's headquarters in $ARG2 704 | P131 $ARG1 's main complex is in $ARG2 705 | P131 $ARG1 's main campus is in $ARG2 706 | P131 $ARG1 's head office is in $ARG2 707 | P131 $ARG1 's main office is in $ARG2 708 | P131 $ARG1 's main offices are in $ARG2 709 | P131 $ARG1 's headquarters are in $ARG2 710 | P131 main complex of $ARG1 is in $ARG2 711 | P131 main campus of $ARG1 is in $ARG2 712 | P131 head office of $ARG1 is in $ARG2 713 | P131 main office of $ARG1 is in $ARG2 714 | P131 main offices of $ARG1 are in $ARG2 715 | P131 headquarters of $ARG1 are in $ARG2 716 | P131 main complex of $ARG1 in $ARG2 717 | P131 main campus of $ARG1 in $ARG2 718 | P131 head office of $ARG1 in $ARG2 719 | P131 main office of $ARG1 in $ARG2 720 | P131 main offices of $ARG1 in $ARG2 721 | P131 headquarters of $ARG1 in $ARG2 722 | P131 $ARG1 is headquartered in $ARG2 723 | P131 $ARG1 is based in $ARG2 724 | P131 $ARG1 is located in $ARG2 725 | P131 $ARG1 , headquartered in $ARG2 726 | P131 $ARG1 , based in $ARG2 727 | P131 $ARG1 , located in $ARG2 728 | P131 $ARG2 main complex of $ARG1 729 | P131 $ARG2 main campus of $ARG1 730 | P131 $ARG2 head office of $ARG1 731 | P131 $ARG2 main office of $ARG1 732 | P131 $ARG2 main offices of $ARG1 733 | P131 $ARG2 headquarters of $ARG1 734 | P131 $ARG1 has its main campus in $ARG2 735 | P131 $ARG1 has its head office in $ARG2 736 | P131 $ARG1 has its main office in $ARG2 737 | P131 $ARG1 has its main offices in $ARG2 738 | P131 $ARG1 has its headquarters in $ARG2 739 | P131 $ARG1 , a $ARG2-based 740 | P131 $ARG1 , an $ARG2-based 741 | P131 $ARG1 is a $ARG2-based 742 | P131 $ARG1 is an $ARG2-based 743 | P131 $ARG1 is a * based in $ARG2 744 | P131 $ARG1 is based in * , $ARG2 745 | P131 $ARG1 , based in * , $ARG2 746 | 747 | 748 | 749 | 750 | -------------------------------------------------------------------------------- /distant_supervision.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ########################################################## 4 | # 5 | # Distant Supervision 6 | # 7 | ########################################################### 8 | 9 | import pandas as pd 10 | import numpy as np 11 | from bs4 import BeautifulSoup as bs 12 | 13 | 14 | import os 15 | import re 16 | import time 17 | import requests 18 | import urllib 19 | import glob 20 | from codecs import open 21 | from itertools import combinations 22 | from collections import Counter 23 | 24 | import util 25 | 26 | import sys 27 | reload(sys) 28 | sys.setdefaultencoding("utf-8") 29 | 30 | 31 | # global variables 32 | data_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') 33 | orig_dir = os.path.join(data_dir, 'orig') 34 | ner_dir = os.path.join(data_dir, 'ner') 35 | 36 | 37 | ner_path = "/usr/local/Cellar/stanford-ner/3.5.2/libexec/" 38 | stanford_classifier = os.path.join(ner_path, 'classifiers', 'english.all.3class.distsim.crf.ser.gz') 39 | stanford_ner = os.path.join(ner_path, 'stanford-ner.jar') 40 | 41 | tag_map = { 42 | 'ORGANIZATION': 'Q43229', # https://www.wikidata.org/wiki/Q43229 43 | 'LOCATION': 'Q17334923', # https://www.wikidata.org/wiki/Q17334923 44 | 'PERSON': 'Q5' # https://www.wikidata.org/wiki/Q5 45 | } 46 | 47 | # column names in DataFrame 48 | col = ['doc_id', 'sent_id', 'sent', 'subj', 'subj_begin', 'subj_end', 'subj_tag', 49 | 'rel', 'obj', 'obj_begin', 'obj_end', 'obj_tag'] 50 | 51 | 52 | def sanitize(string): 53 | """clean wikipedia article""" 54 | string = re.sub(r"\[\d{1,3}\]", " ", 
string) 55 | string = re.sub(r"\[edit\]", " ", string) 56 | string = re.sub(r" {2,}", " ", string) 57 | return string.strip() 58 | 59 | 60 | def download_wiki_articles(doc_id, limit=100, retry=False): 61 | """download wikipedia article via Mediawiki API""" 62 | base_path = "http://en.wikipedia.org/w/api.php?format=xml&action=query" 63 | query = base_path + "&list=random&rnnamespace=0&rnlimit=%d" % limit 64 | r = None 65 | try: 66 | r = urllib.urlopen(query).read() 67 | except Exception as e: 68 | if not retry: 69 | download_wiki_articles(doc_id, limit, retry=True) 70 | else: 71 | print e.message 72 | return None 73 | pages = bs(r, "html.parser").findAll('page') 74 | if len(pages) < 1: 75 | return None 76 | docs = [] 77 | for page in pages: 78 | if int(page['id']) in doc_id: 79 | continue 80 | 81 | link = base_path + "&prop=revisions&pageids=%s&rvprop=content&rvparse" % page['id'] 82 | content = urllib.urlopen(link).read() 83 | content = bs(content, "html.parser").find('rev').stripped_strings 84 | 85 | # extract paragraph elements only 86 | text = '' 87 | for p in bs(' '.join(content), "html.parser").findAll('p'): 88 | text += ' '.join(p.stripped_strings) + '\n' 89 | #text = text.encode('utf8') 90 | text = sanitize(text) 91 | 92 | # save 93 | if len(text) > 0: 94 | title = re.sub(r"[ /]", "_", page['title']) 95 | filename = page['id'] + '-' + title + '.txt' 96 | docs.append(filename) 97 | with open(os.path.join(orig_dir, filename), 'w', encoding='utf-8') as f: 98 | f.write(text) 99 | return docs 100 | 101 | 102 | def exec_ner(filenames): 103 | """execute Stanford NER""" 104 | for filename in filenames: 105 | in_path = os.path.join(orig_dir, filename) 106 | out_path = os.path.join(ner_dir, filename) 107 | cmd = 'java -mx700m -cp "%s:" edu.stanford.nlp.ie.crf.CRFClassifier' % stanford_ner 108 | cmd += ' -loadClassifier %s -outputFormat tabbedEntities' % stanford_classifier 109 | cmd += ' -textFile %s > %s' % (in_path, out_path) 110 | os.system(cmd) 111 | 112 | 113 | def read_ner_output(filenames): 114 | """read NER output files and store them in a pandas DataFrame""" 115 | rows = [] 116 | for filename in filenames: 117 | path = os.path.join(ner_dir, filename) 118 | if not os.path.exists(path): 119 | continue 120 | with open(path, 'r', encoding='utf-8') as f: 121 | doc_id = filename.split('/')[-1].split('-', 1)[0] 122 | counter = 0 123 | tmp = [] 124 | for line in f.readlines(): 125 | if len(line.strip()) < 1 and len(tmp) > 2: 126 | ent = [i for i, t in enumerate(tmp) if t[1] in tag_map.keys()] 127 | for c in combinations(ent, 2): 128 | dic = {'sent': u''} 129 | dic['doc_id'] = doc_id 130 | dic['sent_id'] = counter 131 | for j, t in enumerate(tmp): 132 | if j == c[0]: 133 | if len(dic['sent']) > 0: 134 | dic['subj_begin'] = len(dic['sent']) + 1 135 | else: 136 | dic['subj_begin'] = 0 137 | if len(dic['sent']) > 0: 138 | dic['subj_end'] = len(dic['sent']) + len(t[0].strip()) + 1 139 | else: 140 | dic['subj_end'] = len(t[0].strip()) 141 | dic['subj'] = t[0].strip() 142 | dic['subj_tag'] = t[1].strip() 143 | elif j == c[1]: 144 | dic['obj_begin'] = len(dic['sent']) + 1 145 | dic['obj_end'] = len(dic['sent']) + len(t[0].strip()) + 1 146 | dic['obj'] = t[0].strip() 147 | dic['obj_tag'] = t[1].strip() 148 | 149 | if len(dic['sent']) > 0: 150 | dic['sent'] += ' ' + t[0].strip() 151 | else: 152 | dic['sent'] += t[0].strip() 153 | if len(dic['sent']) > 0: 154 | dic['sent'] += ' ' + t[2].strip() 155 | else: 156 | dic['sent'] += t[2].strip() 157 | #print '"'+dic['sent']+'"', len(dic['sent']) 158 | 
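                        # store one candidate row per (subject, object) entity pair found in this sentence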
rows.append(dic) 159 | #print dic 160 | counter += 1 161 | tmp = [] 162 | elif len(line.strip()) < 1 and len(tmp) > 0 and len(tmp) <= 2: 163 | continue 164 | elif len(line.strip()) > 0: 165 | e = line.split('\t') 166 | if len(e) == 1: 167 | e.insert(0, '') 168 | e.insert(0, '') 169 | if len(e) == 2 and e[1].strip() in tag_map.keys(): 170 | e.append('') 171 | if len(e) != 3: 172 | print e 173 | raise Exception 174 | tmp.append(e) 175 | else: 176 | continue 177 | 178 | return pd.DataFrame(rows) 179 | 180 | 181 | def name2qid(name, tag, alias=False, retry=False): 182 | """find QID (and Freebase ID if given) by name 183 | 184 | >>> name2qid('Barack Obama', 'PERSON') # perfect match 185 | (u'Q76', u'/m/02mjmr') 186 | >>> name2qid('Obama', 'PERSON', alias=True) # alias match 187 | (u'Q76', u'/m/02mjmr') 188 | """ 189 | 190 | label = 'rdfs:label' 191 | if alias: 192 | label = 'skos:altLabel' 193 | 194 | hpCharURL = 'https://query.wikidata.org/sparql?query=\ 195 | SELECT DISTINCT ?item ?fid \ 196 | WHERE {\ 197 | ?item '+label+' "'+name+'"@en.\ 198 | ?item wdt:P31 ?_instanceOf.\ 199 | ?_instanceOf wdt:P279* wd:'+tag_map[tag]+'.\ 200 | OPTIONAL { ?item wdt:P646 ?fid. }\ 201 | }\ 202 | LIMIT 10' 203 | headers = {"Accept": "application/json"} 204 | 205 | # check response 206 | r = None 207 | try: 208 | r = requests.get(hpCharURL, headers=headers) 209 | except requests.exceptions.ConnectionError: 210 | if not retry: 211 | time.sleep(60) 212 | name2qid(name, tag, alias, retry=True) 213 | else: 214 | return None 215 | except Exception as e: 216 | print e.message 217 | return None 218 | 219 | # check json format 220 | try: 221 | response = r.json() 222 | except ValueError: # includes JSONDecodeError 223 | return None 224 | 225 | # parse results 226 | results = [] 227 | for elm in response['results']['bindings']: 228 | fid = '' 229 | if elm.has_key('fid'): 230 | fid = elm['fid']['value'] 231 | results.append((elm['item']['value'].split('/')[-1], fid)) 232 | 233 | if len(results) < 1: 234 | return None 235 | else: 236 | return results[0] 237 | 238 | 239 | def search_property(qid1, qid2, retry=False): 240 | """find property (and schema.org relation if given) 241 | 242 | >>> search_property('Q76', 'Q30') # Q76: Barack Obama, Q30: United States 243 | [(u'P27', u'country of citizenship', u'nationality')] 244 | """ 245 | 246 | hpCharURL = 'https://query.wikidata.org/sparql?query= \ 247 | SELECT DISTINCT ?p ?l ?s \ 248 | WHERE {\ 249 | wd:'+qid1+' ?p wd:'+qid2+' .\ 250 | ?property ?ref ?p .\ 251 | ?property a wikibase:Property .\ 252 | ?property rdfs:label ?l FILTER (lang(?l) = "en")\ 253 | OPTIONAL { ?property wdt:P1628 ?s FILTER (SUBSTR(str(?s), 1, 18) = "http://schema.org/"). 
}\ 254 | }\ 255 | LIMIT 10' 256 | headers = {"Accept": "application/json"} 257 | 258 | # check response 259 | r = None 260 | try: 261 | r = requests.get(hpCharURL, headers=headers) 262 | except requests.exceptions.ConnectionError: 263 | if not retry: 264 | time.sleep(60) 265 | search_property(qid1, qid2, retry=True) 266 | else: 267 | return None 268 | except Exception as e: 269 | print e.message 270 | return None 271 | 272 | # check json format 273 | try: 274 | response = r.json() 275 | except ValueError: 276 | return None 277 | 278 | # parse results 279 | results = [] 280 | for elm in response['results']['bindings']: 281 | schema = '' 282 | if elm.has_key('s'): 283 | schema = elm['s']['value'].split('/')[-1] 284 | results.append((elm['p']['value'].split('/')[-1], elm['l']['value'], schema)) 285 | 286 | return results 287 | 288 | 289 | def slot_filling(qid, pid, tag, retry=False): 290 | """find slotfiller 291 | 292 | >>> slot_filling('Q76', 'P27', 'LOCATION') # Q76: Barack Obama, P27: country of citizenship 293 | [(u'United States', u'Q30', u'/m/09c7w0')] 294 | """ 295 | 296 | hpCharURL = 'https://query.wikidata.org/sparql?query=\ 297 | SELECT DISTINCT ?item ?itemLabel ?fid \ 298 | WHERE {\ 299 | wd:'+qid+' wdt:'+pid+' ?item.\ 300 | ?item wdt:P31 ?_instanceOf.\ 301 | ?_instanceOf wdt:P279* wd:'+tag_map[tag]+'.\ 302 | SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }\ 303 | OPTIONAL { ?item wdt:P646 ?fid. }\ 304 | }\ 305 | LIMIT 100' 306 | headers = {"Accept": "application/json"} 307 | 308 | # check response 309 | r = None 310 | try: 311 | r = requests.get(hpCharURL, headers=headers) 312 | except requests.exceptions.ConnectionError: 313 | if not retry: 314 | time.sleep(60) 315 | slot_filling(qid, pid, tag, retry=True) 316 | else: 317 | return None 318 | except Exception as e: 319 | print e.message 320 | return None 321 | 322 | # check json format 323 | try: 324 | response = r.json() 325 | except ValueError: 326 | return None 327 | 328 | # parse results 329 | results = [] 330 | for elm in response['results']['bindings']: 331 | fid = '' 332 | if elm.has_key('fid'): 333 | fid = elm['fid']['value'] 334 | results.append((elm['itemLabel']['value'], elm['item']['value'].split('/')[-1], fid)) 335 | 336 | return results 337 | 338 | 339 | def loop(step, doc_id, limit, entities, relations, counter): 340 | """Distant Supervision Loop""" 341 | # Download wiki articles 342 | print '[1/4] Downloading wiki articles ...' 343 | docs = download_wiki_articles(doc_id, limit) 344 | if docs is None: 345 | return None 346 | 347 | # Named Entity Recognition 348 | print '[2/4] Performing named entity recognition ...' 349 | exec_ner(docs) 350 | wiki_data = read_ner_output(docs) 351 | path = os.path.join(data_dir, 'candidates%d.tsv' % step) 352 | wiki_data.to_csv(path, sep='\t', encoding='utf-8') 353 | doc_id.extend([int(s) for s in wiki_data.doc_id.unique()]) 354 | 355 | # Prepare Containers 356 | unique_entities = set([]) 357 | unique_entity_pairs = set([]) 358 | for idx, row in wiki_data.iterrows(): 359 | unique_entities.add((row['subj'], row['subj_tag'])) 360 | unique_entities.add((row['obj'], row['obj_tag'])) 361 | unique_entity_pairs.add((row['subj'], row['obj'])) 362 | 363 | # Entity Linkage 364 | print '[3/4] Linking entities ...' 
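# resolve each surface form against Wikidata: exact rdfs:label match first, then skos:altLabel as fallback;
# unresolvable names are cached as None so they are not queried again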
365 | for name, tag in unique_entities: 366 | if not entities.has_key(name) and tag in tag_map.keys(): 367 | e = name2qid(name, tag, alias=False) 368 | if e is None: 369 | e = name2qid(name, tag, alias=True) 370 | entities[name] = e 371 | util.dump_to_file(os.path.join(data_dir, "entities.cPickle"), entities) 372 | 373 | # Predicate Linkage 374 | print '[4/4] Linking predicates ...' 375 | for subj, obj in unique_entity_pairs: 376 | if not relations.has_key((subj, obj)): 377 | if entities[subj] is not None and entities[obj] is not None: 378 | if (entities[subj][0] != entities[obj][0]) or (subj != obj): 379 | arg1 = entities[subj][0] 380 | arg2 = entities[obj][0] 381 | relations[(subj, obj)] = search_property(arg1, arg2) 382 | #TODO: alternative name relation 383 | #elif (entities[subj][0] == entities[obj][0]) and (subj != obj): 384 | # relations[(subj, obj)] = 'P' 385 | util.dump_to_file(os.path.join(data_dir, "relations.cPickle"), relations) 386 | 387 | # Assign relation 388 | wiki_data['rel'] = pd.Series(index=wiki_data.index, dtype=str) 389 | for idx, row in wiki_data.iterrows(): 390 | entity_pair = (row['subj'], row['obj']) 391 | 392 | if relations.has_key(entity_pair): 393 | rel = relations[entity_pair] 394 | if rel is not None and len(rel) > 0: 395 | counter += 1 396 | wiki_data.set_value(idx, 'rel', ', '.join(set([s[0] for s in rel]))) 397 | # Save 398 | path = os.path.join(data_dir, 'candidates%d.tsv' % step) 399 | wiki_data.to_csv(path, sep='\t', encoding='utf-8') 400 | 401 | # Cleanup 402 | for f in glob.glob(os.path.join(orig_dir, '*')): 403 | os.remove(f) 404 | for f in glob.glob(os.path.join(ner_dir, '*')): 405 | os.remove(f) 406 | 407 | return doc_id, entities, relations, counter 408 | 409 | 410 | def extract_relations(entities, relations): 411 | """extract relations""" 412 | rows = [] 413 | for k, v in relations.iteritems(): 414 | if v is not None and len(v) > 0: 415 | for r in v: 416 | dic = {} 417 | dic['subj_qid'] = entities[k[0]][0] 418 | dic['subj_fid'] = entities[k[0]][1] 419 | dic['subj'] = k[0] 420 | dic['obj_qid'] = entities[k[1]][0] 421 | dic['obj_fid'] = entities[k[1]][1] 422 | dic['obj'] = k[1] 423 | dic['rel_id'] = r[0] 424 | dic['rel'] = r[1] 425 | dic['rel_schema'] = r[2] 426 | #TODO: add number of mentions 427 | #dic['wikidata_idx'] = entity_pairs[k] 428 | rows.append(dic) 429 | return pd.DataFrame(rows) 430 | 431 | 432 | def positive_examples(): 433 | entities = {} 434 | relations = {} 435 | counter = 0 436 | limit = 1000 437 | doc_id = [] 438 | step = 1 439 | 440 | if not os.path.exists(orig_dir): 441 | os.mkdir(orig_dir) 442 | if not os.path.exists(ner_dir): 443 | os.mkdir(ner_dir) 444 | 445 | #for j in range(1, step): 446 | # wiki_data = pd.read_csv(os.path.join(data_dir, "candidates%d.tsv" % j), sep='\t', index_col=0) 447 | # doc_id.extend([int(s) for s in wiki_data.doc_id.unique()]) 448 | # counter += int(wiki_data.rel.count()) 449 | 450 | while counter < 10000 and step < 100: 451 | print '===== step %d =====' % step 452 | ret = loop(step, doc_id, limit, entities, relations, counter) 453 | if ret is not None: 454 | doc_id, entities, relations, counter = ret 455 | 456 | step += 1 457 | 458 | # positive candidates 459 | positive_data = [] 460 | for f in glob.glob(os.path.join(data_dir, 'candidates*.tsv')): 461 | pos = pd.read_csv(f, sep='\t', encoding='utf-8', index_col=0) 462 | positive_data.append(pos[pd.notnull(pos.rel)]) 463 | positive_df = pd.concat(positive_data, axis=0, ignore_index=True) 464 | positive_df[col].to_csv(os.path.join(data_dir, 
'positive_candidates.tsv'), sep='\t', encoding='utf-8') 465 | 466 | # save relations 467 | pos_rel = extract_relations(entities, relations) 468 | pos_rel.to_csv(os.path.join(data_dir, 'positive_relations.tsv'), sep='\t', encoding='utf-8') 469 | 470 | 471 | def negative_examples(): 472 | negative = {} 473 | 474 | unique_pair = set([]) 475 | neg_candidates = [] 476 | 477 | #TODO: replace with positive_relations.tsv 478 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle")) 479 | relations = util.load_from_dump(os.path.join(data_dir, "relations.cPickle")) 480 | 481 | rel_counter = Counter([u[0] for r in relations.values() if r is not None and len(r) > 0 for u in r]) 482 | most_common_rel = [r[0] for r in rel_counter.most_common(10)] 483 | 484 | 485 | for data_path in glob.glob(os.path.join(data_dir, 'candidates*.tsv')): 486 | neg = pd.read_csv(data_path, sep='\t', encoding='utf-8', index_col=0) 487 | negative_df = neg[pd.isnull(neg.rel)] 488 | 489 | # Assign relation 490 | for idx, row in negative_df.iterrows(): 491 | if (entities.has_key(row['subj']) and entities[row['subj']] is not None \ 492 | and entities.has_key(row['obj']) and entities[row['obj']] is not None): 493 | qid = entities[row['subj']][0] 494 | target = entities[row['obj']][0] 495 | candidates = [] 496 | for pid in most_common_rel: 497 | if (qid, pid) not in unique_pair: 498 | unique_pair.add((qid, pid)) 499 | items = slot_filling(qid, pid, row['obj_tag']) 500 | if items is not None and len(items) > 1: 501 | qids = [q[1] for q in items] 502 | if target not in qids: 503 | candidates.append(pid) 504 | 505 | if len(candidates) > 0: 506 | row['rel'] = ', '.join(candidates) 507 | neg_candidates.append(row) 508 | 509 | 510 | neg_examples = pd.DataFrame(neg_candidates) 511 | neg_examples[col].to_csv(os.path.join(data_dir, 'negative_candidates.tsv'), sep='\t', encoding='utf-8') 512 | 513 | 514 | # save relations 515 | #pos_rel = extract_relations(entities, negative) 516 | #pos_rel.to_csv(os.path.join(data_dir, 'positive_relations.tsv'), sep='\t', encoding='utf-8') 517 | 518 | 519 | def load_gold_patterns(): 520 | def clean_str(string): 521 | string = re.sub(r", ", " , ", string) 522 | string = re.sub(r"' ", " ' ", string) 523 | string = re.sub(r" \* ", " .* ", string) 524 | string = re.sub(r"\(", "-LRB-", string) 525 | string = re.sub(r"\)", "-RRB-", string) 526 | string = re.sub(r" {2,}", " ", string) 527 | return string.strip() 528 | 529 | g_patterns = [] 530 | g_labels = [] 531 | with open(os.path.join(data_dir, 'gold_patterns.tsv'), 'r') as f: 532 | for line in f.readlines(): 533 | line = line.strip() 534 | if len(line) > 0 and not line.startswith('#'): 535 | e = line.split('\t', 1) 536 | if len(e) > 1: 537 | g_patterns.append(clean_str(e[1])) 538 | g_labels.append(e[0]) 539 | else: 540 | print e 541 | raise Exception('Process Error: %s' % os.path.join(data_dir, 'gold_patterns.tsv')) 542 | 543 | return pd.DataFrame({'pattern': g_patterns, 'label': g_labels}) 544 | 545 | 546 | def score_reliability(gold_patterns, sent, rel, subj, obj): 547 | for name, group in gold_patterns.groupby('label'): 548 | if name in [r.strip() for r in rel.split(',')]: 549 | for i, g in group.iterrows(): 550 | pattern = g['pattern'] 551 | pattern = re.sub(r'\$ARG(0|1)', subj, pattern, count=1) 552 | pattern = re.sub(r'\$ARG(0|2)', obj, pattern, count=1) 553 | match = re.search(pattern, sent) 554 | if match: 555 | return 1.0 556 | return 0.0 557 | 558 | 559 | def extract_positive(): 560 | if not os.path.exists(os.path.join(data_dir, 
'mlmi')): 561 | os.mkdir(os.path.join(data_dir, 'mlmi')) 562 | if not os.path.exists(os.path.join(data_dir, 'er')): 563 | os.mkdir(os.path.join(data_dir, 'er')) 564 | 565 | # read gold patterns to extract attention 566 | gold_patterns = load_gold_patterns() 567 | 568 | 569 | #TODO: replace with negative_relations.tsv 570 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle")) 571 | relations = util.load_from_dump(os.path.join(data_dir, "relations.cPickle")) 572 | 573 | # filter out the relations which occur less than 50 times 574 | rel_c = Counter([u[0] for r in relations.values() if r is not None and len(r) > 0 for u in r]) 575 | rel_c_top = [k for k, v in rel_c.most_common(50) if v >= 50] 576 | 577 | # positive examples 578 | positive_df = pd.read_csv(os.path.join(data_dir, 'positive_candidates.tsv'), 579 | sep='\t', encoding='utf-8', index_col=0) 580 | 581 | positive_df['right'] = pd.Series(index=positive_df.index, dtype=str) 582 | positive_df['middle'] = pd.Series(index=positive_df.index, dtype=str) 583 | positive_df['left'] = pd.Series(index=positive_df.index, dtype=str) 584 | positive_df['clean'] = pd.Series(index=positive_df.index, dtype=str) 585 | positive_df['label'] = pd.Series(index=positive_df.index, dtype=str) 586 | positive_df['attention'] = pd.Series([0.0]*len(positive_df), index=positive_df.index, dtype=np.float32) 587 | 588 | num_er = 0 589 | with open(os.path.join(data_dir, 'er', 'source.txt'), 'w', encoding='utf-8') as f: 590 | for idx, row in positive_df.iterrows(): 591 | 592 | # restore relation 593 | rel = ['<' + l.strip() + '>' for l in row['rel'].split(',') if l.strip() in rel_c_top] 594 | if len(rel) > 0: 595 | 596 | s = row['sent'] 597 | subj = '<' + entities[row['subj'].encode('utf-8')][0] + '>' 598 | obj = '<' + entities[row['obj'].encode('utf-8')][0] + '>' 599 | left = s[:row['subj_begin']] + subj 600 | middle = s[row['subj_end']:row['obj_begin']] 601 | right = obj + s[row['obj_end']:] 602 | text = left.strip() + ' ' + middle.strip() + ' ' + right.strip() 603 | 604 | # check if begin-end position is correct 605 | assert s[row['subj_begin']:row['subj_end']] == row['subj'] 606 | assert s[row['obj_begin']:row['obj_end']] == row['obj'] 607 | 608 | # MLMI dataset 609 | # filter out too long sentences 610 | if len(left.split()) < 100 and len(middle.split()) < 100 and len(right.split()) < 100: 611 | 612 | positive_df.set_value(idx, 'right', right.strip()) 613 | positive_df.set_value(idx, 'middle', middle.strip()) 614 | positive_df.set_value(idx, 'left', left.strip()) 615 | positive_df.set_value(idx, 'clean', text.strip()) 616 | 617 | # binarize label 618 | label = ['0'] * len(rel_c_top) 619 | for u in row['rel'].split(','): 620 | if u.strip() in rel_c_top: 621 | label[rel_c_top.index(u.strip())] = '1' 622 | positive_df.set_value(idx, 'label', ' '.join(label)) 623 | 624 | # score reliability if positive 625 | reliability = score_reliability(gold_patterns, s, row['rel'], row['subj'], row['obj']) 626 | positive_df.set_value(idx, 'attention', reliability) 627 | 628 | # ER dataset 629 | for r in rel: 630 | num_er += 1 631 | f.write(subj + ' ' + r + ' ' + obj + '\n') 632 | 633 | with open(os.path.join(data_dir, 'er', 'target.txt'), 'w', encoding='utf-8') as f: 634 | for _ in range(num_er): 635 | f.write('1 0\n') 636 | 637 | positive_df_valid = positive_df[pd.notnull(positive_df.clean)] 638 | assert len(positive_df_valid['clean']) == len(positive_df_valid['label']) 639 | 640 | positive_df_valid['right'].to_csv(os.path.join(data_dir, 'mlmi', 
'source.right'), 641 | sep='\t', index=False, header=False, encoding='utf-8') 642 | positive_df_valid['middle'].to_csv(os.path.join(data_dir, 'mlmi', 'source.middle'), 643 | sep='\t', index=False, header=False, encoding='utf-8') 644 | positive_df_valid['left'].to_csv(os.path.join(data_dir, 'mlmi', 'source.left'), 645 | sep='\t', index=False, header=False, encoding='utf-8') 646 | positive_df_valid['clean'].to_csv(os.path.join(data_dir, 'mlmi', 'source.txt'), 647 | sep='\t', index=False, header=False, encoding='utf-8') 648 | positive_df_valid['label'].to_csv(os.path.join(data_dir, 'mlmi', 'target.txt'), 649 | sep='\t', index=False, header=False, encoding='utf-8') 650 | positive_df_valid['attention'].to_csv(os.path.join(data_dir, 'mlmi', 'source.att'), 651 | sep='\t', index=False, header=False, encoding='utf-8') 652 | 653 | 654 | def extract_negative(): 655 | entities = util.load_from_dump(os.path.join(data_dir, "entities.cPickle")) 656 | 657 | # negative examples 658 | negative_df = pd.read_csv(os.path.join(data_dir, 'negative_candidates.tsv'), 659 | sep='\t', encoding='utf-8', index_col=0) 660 | 661 | with open(os.path.join(data_dir, 'er', 'source.txt'), 'a', encoding='utf-8') as source_file: 662 | with open(os.path.join(data_dir, 'er', 'target.txt'), 'a', encoding='utf-8') as target_file: 663 | for idx, row in negative_df.iterrows(): 664 | s = row['sent'] 665 | 666 | subj = '<' + entities[row['subj'].encode('utf-8')][0] + '>' 667 | obj = '<' + entities[row['obj'].encode('utf-8')][0] + '>' 668 | rel = ['<' + l.strip() + '>' for l in row['rel'].split(',')] 669 | 670 | assert s[row['subj_begin']:row['subj_end']] == row['subj'] 671 | assert s[row['obj_begin']:row['obj_end']] == row['obj'] 672 | 673 | if len(rel) > 0: 674 | for r in rel: 675 | source_file.write(subj + ' ' + r + ' ' + obj + '\n') 676 | target_file.write('0 1\n') 677 | 678 | def main(): 679 | # gather positive examples 680 | if not os.path.exists(os.path.join(data_dir, 'positive_candidates.tsv')): 681 | positive_examples() 682 | extract_positive() 683 | 684 | # gather negative examples 685 | if not os.path.exists(os.path.join(data_dir, 'negative_candidates.tsv')): 686 | negative_examples() 687 | extract_negative() 688 | 689 | 690 | 691 | 692 | if __name__ == '__main__': 693 | main() 694 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ########################################################## 4 | # 5 | # Attention-based Convolutional Neural Network 6 | # for Context-wise Learning 7 | # 8 | # 9 | # Note: this implementation is mostly based on 10 | # https://github.com/yuhaozhang/sentence-convnet/blob/master/eval.py 11 | # 12 | ########################################################## 13 | 14 | from datetime import datetime 15 | import os 16 | import tensorflow as tf 17 | import numpy as np 18 | import util 19 | 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | this_dir = os.path.abspath(os.path.dirname(__file__)) 23 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'models', 'er-cnn'), 'Directory of the checkpoint files') 24 | 25 | 26 | def evaluate(eval_data, config): 27 | """ Build evaluation graph and run. 
""" 28 | 29 | with tf.Graph().as_default(): 30 | with tf.variable_scope('cnn'): 31 | if config.has_key('contextwise') and config['contextwise']: 32 | import cnn_context 33 | m = cnn_context.Model(config, is_train=False) 34 | else: 35 | import cnn 36 | m = cnn.Model(config, is_train=False) 37 | saver = tf.train.Saver(tf.global_variables()) 38 | 39 | with tf.Session() as sess: 40 | ckpt = tf.train.get_checkpoint_state(config['train_dir']) 41 | if ckpt and ckpt.model_checkpoint_path: 42 | saver.restore(sess, ckpt.model_checkpoint_path) 43 | else: 44 | raise IOError("Loading checkpoint file failed!") 45 | 46 | print "\nStart evaluation on test set ...\n" 47 | if config.has_key('contextwise') and config['contextwise']: 48 | left_batch, middle_batch, right_batch, y_batch, _ = zip(*eval_data) 49 | feed = {m.left: np.array(left_batch), 50 | m.middle: np.array(middle_batch), 51 | m.right: np.array(right_batch), 52 | m.labels: np.array(y_batch)} 53 | else: 54 | x_batch, y_batch, _ = zip(*eval_data) 55 | feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)} 56 | loss, eval = sess.run([m.total_loss, m.eval_op], feed_dict=feed) 57 | pre, rec = zip(*eval) 58 | 59 | auc = util.calc_auc_pr(pre, rec) 60 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5]) 61 | print '%s: loss = %.6f, p = %.4f, r = %4.4f, f1 = %.4f, auc = %.4f' % (datetime.now(), loss, 62 | pre[5], rec[5], f1, auc) 63 | return pre, rec 64 | 65 | 66 | def main(argv=None): 67 | restore_param = util.load_from_dump(os.path.join(FLAGS.train_dir, 'flags.cPickle')) 68 | restore_param['train_dir'] = FLAGS.train_dir 69 | 70 | if restore_param.has_key('contextwise') and restore_param['contextwise']: 71 | source_path = os.path.join(restore_param['data_dir'], "ids") 72 | target_path = os.path.join(restore_param['data_dir'], "target.txt") 73 | _, data = util.read_data_contextwise(source_path, target_path, restore_param['sent_len'], 74 | train_size=restore_param['train_size']) 75 | else: 76 | source_path = os.path.join(restore_param['data_dir'], "ids.txt") 77 | target_path = os.path.join(restore_param['data_dir'], "target.txt") 78 | _, data = util.read_data(source_path, target_path, restore_param['sent_len'], 79 | train_size=restore_param['train_size']) 80 | 81 | pre, rec = evaluate(data, restore_param) 82 | util.dump_to_file(os.path.join(FLAGS.train_dir, 'results.cPickle'), {'precision': pre, 'recall': rec}) 83 | 84 | 85 | if __name__ == '__main__': 86 | tf.app.run() 87 | -------------------------------------------------------------------------------- /img/auc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/auc.png -------------------------------------------------------------------------------- /img/cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/cnn.png -------------------------------------------------------------------------------- /img/emb_er.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_er.png -------------------------------------------------------------------------------- /img/emb_left.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_left.png -------------------------------------------------------------------------------- /img/emb_mlmi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_mlmi.png -------------------------------------------------------------------------------- /img/emb_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/emb_right.png -------------------------------------------------------------------------------- /img/f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/f1.png -------------------------------------------------------------------------------- /img/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/loss.png -------------------------------------------------------------------------------- /img/pr_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/may-/cnn-re-tf/b18c9b5a71861658ecf932851edc334fdb616f87/img/pr_curve.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from datetime import datetime 4 | import time 5 | import os 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | import cnn 10 | import util 11 | 12 | 13 | FLAGS = tf.app.flags.FLAGS 14 | 15 | # train parameters 16 | this_dir = os.path.abspath(os.path.dirname(__file__)) 17 | tf.app.flags.DEFINE_string('data_dir', os.path.join(this_dir, 'data'), 'Directory of the data') 18 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'train'), 19 | 'Directory to save training checkpoint files') 20 | tf.app.flags.DEFINE_integer('train_size', 100000, 'Number of training examples') 21 | tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run') 22 | tf.app.flags.DEFINE_boolean('use_pretrain', False, 'Use word2vec pretrained embeddings or not') 23 | 24 | tf.app.flags.DEFINE_string('optimizer', 'adam', 25 | 'Optimizer to use. Must be one of "sgd", "adagrad", "adadelta" and "adam"') 26 | tf.app.flags.DEFINE_float('init_lr', 1e-3, 'Initial learning rate') 27 | tf.app.flags.DEFINE_float('lr_decay', 0.95, 'LR decay rate') 28 | tf.app.flags.DEFINE_integer('tolerance_step', 500, 29 | 'Decay the lr after loss remains unchanged for this number of steps') 30 | tf.app.flags.DEFINE_float('dropout', 0.5, 'Dropout rate. 
0 is no dropout.') 31 | 32 | # logging 33 | tf.app.flags.DEFINE_integer('log_step', 10, 'Display log to stdout after this step') 34 | tf.app.flags.DEFINE_integer('summary_step', 50, 35 | 'Write summary (evaluate model on dev set) after this step') 36 | tf.app.flags.DEFINE_integer('checkpoint_step', 100, 'Save model after this step') 37 | 38 | 39 | def train(train_data, test_data): 40 | # train_dir 41 | timestamp = str(int(time.time())) 42 | out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, timestamp)) 43 | 44 | # save flags 45 | if not os.path.exists(out_dir): 46 | os.mkdir(out_dir) 47 | FLAGS._parse_flags() 48 | config = dict(FLAGS.__flags.items()) 49 | 50 | # Window_size must not be larger than the sent_len 51 | if config['sent_len'] < config['max_window']: 52 | config['max_window'] = config['sent_len'] 53 | 54 | util.dump_to_file(os.path.join(out_dir, 'flags.cPickle'), config) 55 | print "Parameters:" 56 | for k, v in config.iteritems(): 57 | print '%20s %r' % (k, v) 58 | 59 | num_batches_per_epoch = int(np.ceil(float(len(train_data))/FLAGS.batch_size)) 60 | max_steps = num_batches_per_epoch * FLAGS.num_epochs 61 | 62 | with tf.Graph().as_default(): 63 | with tf.variable_scope('cnn', reuse=None): 64 | m = cnn.Model(config, is_train=True) 65 | with tf.variable_scope('cnn', reuse=True): 66 | mtest = cnn.Model(config, is_train=False) 67 | 68 | # checkpoint 69 | saver = tf.train.Saver(tf.global_variables()) 70 | save_path = os.path.join(out_dir, 'model.ckpt') 71 | summary_op = tf.summary.merge_all() 72 | 73 | # session 74 | with tf.Session().as_default() as sess: 75 | proj_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig() 76 | embedding = proj_config.embeddings.add() 77 | embedding.tensor_name = m.W_emb.name 78 | embedding.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt') 79 | 80 | train_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "train"), graph=sess.graph) 81 | dev_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "dev"), graph=sess.graph) 82 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_summary_writer, proj_config) 83 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(dev_summary_writer, proj_config) 84 | 85 | sess.run(tf.global_variables_initializer()) 86 | 87 | # assign pretrained embeddings 88 | if FLAGS.use_pretrain: 89 | print "Initialize model with pretrained embeddings..." 
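# emb.npy is the vocabulary-aligned embedding matrix written by util.py (row i = vector of vocab id i)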
90 | pretrained_embedding = np.load(os.path.join(FLAGS.data_dir, 'emb.npy')) 91 | m.assign_embedding(sess, pretrained_embedding) 92 | 93 | # initialize parameters 94 | current_lr = FLAGS.init_lr 95 | lowest_loss_value = float("inf") 96 | decay_step_counter = 0 97 | global_step = 0 98 | 99 | # evaluate on dev set 100 | def dev_step(mtest, sess): 101 | dev_loss = [] 102 | dev_auc = [] 103 | dev_f1_score = [] 104 | 105 | # create batch 106 | test_batches = util.batch_iter(test_data, batch_size=FLAGS.batch_size, num_epochs=1, shuffle=False) 107 | for batch in test_batches: 108 | x_batch, y_batch, _ = zip(*batch) 109 | loss_value, eval_value = sess.run([mtest.total_loss, mtest.eval_op], 110 | feed_dict={mtest.inputs: np.array(x_batch), mtest.labels: np.array(y_batch)}) 111 | dev_loss.append(loss_value) 112 | pre, rec = zip(*eval_value) 113 | dev_auc.append(util.calc_auc_pr(pre, rec)) 114 | dev_f1_score.append((2.0 * pre[5] * rec[5]) / (pre[5] + rec[5])) # threshold = 0.5 115 | 116 | return np.mean(dev_loss), np.mean(dev_auc), np.mean(dev_f1_score) 117 | 118 | # train loop 119 | print "\nStart training (save checkpoints in %s)\n" % out_dir 120 | train_loss = [] 121 | train_auc = [] 122 | train_f1_score = [] 123 | train_batches = util.batch_iter(train_data, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs) 124 | for batch in train_batches: 125 | batch_size = len(batch) 126 | 127 | m.assign_lr(sess, current_lr) 128 | global_step += 1 129 | 130 | x_batch, y_batch, a_batch = zip(*batch) 131 | feed = {m.inputs: np.array(x_batch), m.labels: np.array(y_batch)} 132 | if FLAGS.attention: 133 | feed[m.attention] = np.array(a_batch) 134 | start_time = time.time() 135 | _, loss_value, eval_value = sess.run([m.train_op, m.total_loss, m.eval_op], feed_dict=feed) 136 | proc_duration = time.time() - start_time 137 | train_loss.append(loss_value) 138 | pre, rec = zip(*eval_value) 139 | auc = util.calc_auc_pr(pre, rec) 140 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5]) # threshold = 0.5 141 | train_auc.append(auc) 142 | train_f1_score.append(f1) 143 | 144 | assert not np.isnan(loss_value), "Model loss is NaN." 
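# bookkeeping below: periodic console logging, TensorBoard summaries with a full dev-set pass,
# loss-plateau learning-rate decay, and checkpointing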
145 | 146 | # print log 147 | if global_step % FLAGS.log_step == 0: 148 | examples_per_sec = batch_size / proc_duration 149 | format_str = '%s: step %d/%d, f1 = %.4f, auc = %.4f, loss = %.4f ' + \ 150 | '(%.1f examples/sec; %.3f sec/batch), lr: %.6f' 151 | print format_str % (datetime.now(), global_step, max_steps, f1, auc, loss_value, 152 | examples_per_sec, proc_duration, current_lr) 153 | 154 | # write summary 155 | if global_step % FLAGS.summary_step == 0: 156 | summary_str = sess.run(summary_op) 157 | train_summary_writer.add_summary(summary_str, global_step) 158 | dev_summary_writer.add_summary(summary_str, global_step) 159 | 160 | # summary loss, f1 161 | train_summary_writer.add_summary( 162 | _summary_for_scalar('loss', np.mean(train_loss)), global_step=global_step) 163 | train_summary_writer.add_summary( 164 | _summary_for_scalar('auc', np.mean(train_auc)), global_step=global_step) 165 | train_summary_writer.add_summary( 166 | _summary_for_scalar('f1', np.mean(train_f1_score)), global_step=global_step) 167 | 168 | dev_loss, dev_auc, dev_f1 = dev_step(mtest, sess) 169 | dev_summary_writer.add_summary( 170 | _summary_for_scalar('loss', dev_loss), global_step=global_step) 171 | dev_summary_writer.add_summary( 172 | _summary_for_scalar('auc', dev_auc), global_step=global_step) 173 | dev_summary_writer.add_summary( 174 | _summary_for_scalar('f1', dev_f1), global_step=global_step) 175 | 176 | print "\n===== write summary =====" 177 | print "%s: step %d/%d: train_loss = %.6f, train_auc = %.4f, train_f1 = %.4f" \ 178 | % (datetime.now(), global_step, max_steps, 179 | np.mean(train_loss), np.mean(train_auc), np.mean(train_f1_score)) 180 | print "%s: step %d/%d: dev_loss = %.6f, dev_auc = %.4f, dev_f1 = %.4f\n" \ 181 | % (datetime.now(), global_step, max_steps, dev_loss, dev_auc, dev_f1) 182 | 183 | # reset container 184 | train_loss = [] 185 | train_auc = [] 186 | train_f1_score = [] 187 | 188 | # decay learning rate if necessary 189 | if loss_value < lowest_loss_value: 190 | lowest_loss_value = loss_value 191 | decay_step_counter = 0 192 | else: 193 | decay_step_counter += 1 194 | if decay_step_counter >= FLAGS.tolerance_step: 195 | current_lr *= FLAGS.lr_decay 196 | print '%s: step %d/%d, Learning rate decays to %.5f' % \ 197 | (datetime.now(), global_step, max_steps, current_lr) 198 | decay_step_counter = 0 199 | 200 | # stop learning if learning rate is too low 201 | if current_lr < 1e-5: 202 | break 203 | 204 | # save checkpoint 205 | if global_step % FLAGS.checkpoint_step == 0: 206 | saver.save(sess, save_path, global_step=global_step) 207 | saver.save(sess, save_path, global_step=global_step) 208 | 209 | 210 | def _summary_for_scalar(name, value): 211 | return tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=float(value))]) 212 | 213 | 214 | def main(argv=None): 215 | if not os.path.exists(FLAGS.train_dir): 216 | os.mkdir(FLAGS.train_dir) 217 | 218 | # load dataset 219 | source_path = os.path.join(FLAGS.data_dir, 'ids.txt') 220 | target_path = os.path.join(FLAGS.data_dir, 'target.txt') 221 | attention_path = None 222 | if FLAGS.attention: 223 | if os.path.exists(os.path.join(FLAGS.data_dir, 'source.att')): 224 | attention_path = os.path.join(FLAGS.data_dir, 'source.att') 225 | else: 226 | raise ValueError("Attention file %s not found.", os.path.join(FLAGS.data_dir, 'source.att')) 227 | train_data, test_data = util.read_data(source_path, target_path, FLAGS.sent_len, 228 | attention_path=attention_path, train_size=FLAGS.train_size) 229 | train(train_data, test_data) 230 | 
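# NOTE: model-specific flags (batch_size, sent_len, vocab_size, num_classes, attention, max_window, ...)
# are not declared in this file; they are assumed to be defined in cnn.py, which is imported above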
231 | 232 | if __name__ == '__main__': 233 | tf.app.run() 234 | -------------------------------------------------------------------------------- /train_context.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from datetime import datetime 4 | import time 5 | import os 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | import cnn_context 10 | import util 11 | 12 | 13 | FLAGS = tf.app.flags.FLAGS 14 | 15 | # train parameters 16 | this_dir = os.path.abspath(os.path.dirname(__file__)) 17 | tf.app.flags.DEFINE_string('data_dir', os.path.join(this_dir, 'data', 'mlmi'), 'Directory of the data') 18 | tf.app.flags.DEFINE_string('train_dir', os.path.join(this_dir, 'train'), 19 | 'Directory to save training checkpoint files') 20 | tf.app.flags.DEFINE_integer('train_size', 100000, 'Number of training examples') 21 | tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run') 22 | tf.app.flags.DEFINE_boolean('use_pretrain', False, 'Use word2vec pretrained embeddings or not') 23 | 24 | tf.app.flags.DEFINE_string('optimizer', 'adam', 25 | 'Optimizer to use. Must be one of "sgd", "adagrad", "adadelta" and "adam"') 26 | tf.app.flags.DEFINE_float('init_lr', 1e-3, 'Initial learning rate') 27 | tf.app.flags.DEFINE_float('lr_decay', 0.95, 'LR decay rate') 28 | tf.app.flags.DEFINE_integer('tolerance_step', 500, 29 | 'Decay the lr after loss remains unchanged for this number of steps') 30 | tf.app.flags.DEFINE_float('dropout', 0.5, 'Dropout rate. 0 is no dropout.') 31 | 32 | # logging 33 | tf.app.flags.DEFINE_integer('log_step', 10, 'Display log to stdout after this step') 34 | tf.app.flags.DEFINE_integer('summary_step', 50, 35 | 'Write summary (evaluate model on dev set) after this step') 36 | tf.app.flags.DEFINE_integer('checkpoint_step', 100, 'Save model after this step') 37 | 38 | 39 | def train(train_data, test_data): 40 | # train_dir 41 | timestamp = str(int(time.time())) 42 | out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, timestamp)) 43 | 44 | # save flags 45 | if not os.path.exists(out_dir): 46 | os.mkdir(out_dir) 47 | FLAGS._parse_flags() 48 | config = dict(FLAGS.__flags.items()) 49 | 50 | # Window_size must not be larger than the sent_len 51 | if config['sent_len'] < config['max_window']: 52 | config['max_window'] = config['sent_len'] 53 | 54 | # flag to restore the contextwise model 55 | config['contextwise'] = True 56 | 57 | # save flags 58 | util.dump_to_file(os.path.join(out_dir, 'flags.cPickle'), config) 59 | print "Parameters:" 60 | for k, v in config.iteritems(): 61 | print '%20s %r' % (k, v) 62 | 63 | # max number of steps 64 | num_batches_per_epoch = int(np.ceil(float(len(train_data))/FLAGS.batch_size)) 65 | max_steps = num_batches_per_epoch * FLAGS.num_epochs 66 | 67 | with tf.Graph().as_default(): 68 | with tf.variable_scope('cnn', reuse=None): 69 | m = cnn_context.Model(config, is_train=True) 70 | with tf.variable_scope('cnn', reuse=True): 71 | mtest = cnn_context.Model(config, is_train=False) 72 | 73 | # checkpoint 74 | saver = tf.train.Saver(tf.global_variables()) 75 | save_path = os.path.join(out_dir, 'model.ckpt') 76 | summary_op = tf.summary.merge_all() 77 | 78 | # session 79 | with tf.Session().as_default() as sess: 80 | proj_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig() 81 | embedding_left = proj_config.embeddings.add() 82 | embedding_middle = proj_config.embeddings.add() 83 | embedding_right = proj_config.embeddings.add() 84 | embedding_left.tensor_name = 
m.W_emb_left.name 85 | embedding_middle.tensor_name = m.W_emb_middle.name 86 | embedding_right.tensor_name = m.W_emb_right.name 87 | embedding_left.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt') 88 | embedding_middle.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt') 89 | embedding_right.metadata_path = os.path.join(FLAGS.data_dir, 'vocab.txt') 90 | 91 | train_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "train"), graph=sess.graph) 92 | dev_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "dev"), graph=sess.graph) 93 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_summary_writer, proj_config) 94 | tf.contrib.tensorboard.plugins.projector.visualize_embeddings(dev_summary_writer, proj_config) 95 | 96 | sess.run(tf.global_variables_initializer()) 97 | 98 | # assign pretrained embeddings 99 | if FLAGS.use_pretrain: 100 | print "Initialize model with pretrained embeddings..." 101 | pretrained_embedding = np.load(os.path.join(FLAGS.data_dir, 'emb.npy')) 102 | m.assign_embedding(sess, pretrained_embedding) 103 | 104 | # initialize parameters 105 | current_lr = FLAGS.init_lr 106 | lowest_loss_value = float("inf") 107 | decay_step_counter = 0 108 | global_step = 0 109 | 110 | # evaluate on dev set 111 | def dev_step(mtest, sess): 112 | dev_loss = [] 113 | dev_auc = [] 114 | dev_f1_score = [] 115 | 116 | # create batch 117 | test_batches = util.batch_iter(test_data, batch_size=FLAGS.batch_size, num_epochs=1, shuffle=False) 118 | for batch in test_batches: 119 | left_batch, middle_batch, right_batch, y_batch, _ = zip(*batch) 120 | feed = {mtest.left: np.array(left_batch), 121 | mtest.middle: np.array(middle_batch), 122 | mtest.right: np.array(right_batch), 123 | mtest.labels: np.array(y_batch)} 124 | loss_value, eval_value = sess.run([mtest.total_loss, mtest.eval_op], feed_dict=feed) 125 | dev_loss.append(loss_value) 126 | pre, rec = zip(*eval_value) 127 | dev_auc.append(util.calc_auc_pr(pre, rec)) 128 | dev_f1_score.append((2.0 * pre[5] * rec[5]) / (pre[5] + rec[5])) # threshold = 0.5 129 | 130 | return np.mean(dev_loss), np.mean(dev_auc), np.mean(dev_f1_score) 131 | 132 | # train loop 133 | print "\nStart training (save checkpoints in %s)\n" % out_dir 134 | train_loss = [] 135 | train_auc = [] 136 | train_f1_score = [] 137 | train_batches = util.batch_iter(train_data, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs) 138 | for batch in train_batches: 139 | batch_size = len(batch) 140 | 141 | m.assign_lr(sess, current_lr) 142 | global_step += 1 143 | 144 | left_batch, middle_batch, right_batch, y_batch, a_batch = zip(*batch) 145 | feed = {m.left: np.array(left_batch), 146 | m.middle: np.array(middle_batch), 147 | m.right: np.array(right_batch), 148 | m.labels: np.array(y_batch)} 149 | if FLAGS.attention: 150 | feed[m.attention] = np.array(a_batch) 151 | start_time = time.time() 152 | _, loss_value, eval_value = sess.run([m.train_op, m.total_loss, m.eval_op], feed_dict=feed) 153 | proc_duration = time.time() - start_time 154 | train_loss.append(loss_value) 155 | pre, rec = zip(*eval_value) 156 | auc = util.calc_auc_pr(pre, rec) 157 | f1 = (2.0 * pre[5] * rec[5]) / (pre[5] + rec[5]) # threshold = 0.5 158 | train_auc.append(auc) 159 | train_f1_score.append(f1) 160 | 161 | assert not np.isnan(loss_value), "Model loss is NaN." 
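# same bookkeeping as in train.py; the difference is that each example is fed as three
# token-id sequences (left, middle, right) into the context-wise model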
162 | 163 | # print log 164 | if global_step % FLAGS.log_step == 0: 165 | examples_per_sec = batch_size / proc_duration 166 | format_str = '%s: step %d/%d, f1 = %.4f, auc = %.4f, loss = %.4f ' + \ 167 | '(%.1f examples/sec; %.3f sec/batch), lr: %.6f' 168 | print format_str % (datetime.now(), global_step, max_steps, f1, auc, loss_value, 169 | examples_per_sec, proc_duration, current_lr) 170 | 171 | # write summary 172 | if global_step % FLAGS.summary_step == 0: 173 | summary_str = sess.run(summary_op) 174 | train_summary_writer.add_summary(summary_str, global_step) 175 | dev_summary_writer.add_summary(summary_str, global_step) 176 | 177 | # summary loss, f1 178 | train_summary_writer.add_summary( 179 | _summary_for_scalar('loss', np.mean(train_loss)), global_step=global_step) 180 | train_summary_writer.add_summary( 181 | _summary_for_scalar('auc', np.mean(train_auc)), global_step=global_step) 182 | train_summary_writer.add_summary( 183 | _summary_for_scalar('f1', np.mean(train_f1_score)), global_step=global_step) 184 | 185 | dev_loss, dev_auc, dev_f1 = dev_step(mtest, sess) 186 | dev_summary_writer.add_summary( 187 | _summary_for_scalar('loss', dev_loss), global_step=global_step) 188 | dev_summary_writer.add_summary( 189 | _summary_for_scalar('auc', dev_auc), global_step=global_step) 190 | dev_summary_writer.add_summary( 191 | _summary_for_scalar('f1', dev_f1), global_step=global_step) 192 | 193 | print "\n===== write summary =====" 194 | print "%s: step %d/%d: train_loss = %.6f, train_auc = %.4f train_f1 = %.4f" \ 195 | % (datetime.now(), global_step, max_steps, 196 | np.mean(train_loss), np.mean(train_auc), np.mean(train_f1_score)) 197 | print "%s: step %d/%d: dev_loss = %.6f, dev_auc = %.4f dev_f1 = %.4f\n" \ 198 | % (datetime.now(), global_step, max_steps, dev_loss, dev_auc, dev_f1) 199 | 200 | # reset container 201 | train_loss = [] 202 | train_auc = [] 203 | train_f1_score = [] 204 | 205 | # decay learning rate if necessary 206 | if loss_value < lowest_loss_value: 207 | lowest_loss_value = loss_value 208 | decay_step_counter = 0 209 | else: 210 | decay_step_counter += 1 211 | if decay_step_counter >= FLAGS.tolerance_step: 212 | current_lr *= FLAGS.lr_decay 213 | print '%s: step %d/%d, Learning rate decays to %.5f' % \ 214 | (datetime.now(), global_step, max_steps, current_lr) 215 | decay_step_counter = 0 216 | 217 | # stop learning if learning rate is too low 218 | if current_lr < 1e-5: 219 | break 220 | 221 | # save checkpoint 222 | if global_step % FLAGS.checkpoint_step == 0: 223 | saver.save(sess, save_path, global_step=global_step) 224 | saver.save(sess, save_path, global_step=global_step) 225 | 226 | 227 | def _summary_for_scalar(name, value): 228 | return tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=float(value))]) 229 | 230 | 231 | def main(argv=None): 232 | if not os.path.exists(FLAGS.train_dir): 233 | os.mkdir(FLAGS.train_dir) 234 | 235 | # load contextwise dataset 236 | source_path = os.path.join(FLAGS.data_dir, 'ids') 237 | target_path = os.path.join(FLAGS.data_dir, 'target.txt') 238 | attention_path = None 239 | if FLAGS.attention: 240 | if os.path.exists(os.path.join(FLAGS.data_dir, 'source.att')): 241 | attention_path = os.path.join(FLAGS.data_dir, 'source.att') 242 | else: 243 | raise ValueError("Attention file %s not found.", os.path.join(FLAGS.data_dir, 'source.att')) 244 | train_data, test_data = util.read_data_contextwise(source_path, target_path, FLAGS.sent_len, 245 | attention_path=attention_path, train_size=FLAGS.train_size) 246 | train(train_data, 
test_data) 247 | 248 | 249 | if __name__ == '__main__': 250 | tf.app.run() 251 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ########################################################## 4 | # 5 | # Helper functions to load data 6 | # 7 | ########################################################### 8 | 9 | import os 10 | import re 11 | from codecs import open as codecs_open 12 | import cPickle as pickle 13 | import numpy as np 14 | 15 | 16 | # Special vocabulary symbols. 17 | PAD_TOKEN = '' # pad symbol 18 | UNK_TOKEN = '' # unknown word 19 | BOS_TOKEN = '' # begin-of-sentence symbol 20 | EOS_TOKEN = '' # end-of-sentence symbol 21 | NUM_TOKEN = '' # numbers 22 | 23 | # we always put them at the start. 24 | _START_VOCAB = [PAD_TOKEN, UNK_TOKEN] 25 | PAD_ID = 0 26 | UNK_ID = 1 27 | 28 | # Regular expressions used to tokenize. 29 | _DIGIT_RE = re.compile(br"^\d+$") 30 | 31 | 32 | THIS_DIR = os.path.abspath(os.path.dirname(__file__)) 33 | RANDOM_SEED = 1234 34 | 35 | 36 | def basic_tokenizer(sequence, bos=True, eos=True): 37 | sequence = re.sub(r'\s{2}', ' ' + EOS_TOKEN + ' ' + BOS_TOKEN + ' ', sequence) 38 | if bos: 39 | sequence = BOS_TOKEN + ' ' + sequence.strip() 40 | if eos: 41 | sequence = sequence.strip() + ' ' + EOS_TOKEN 42 | return sequence.lower().split() 43 | 44 | 45 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size=40000, tokenizer=None, bos=True, eos=True): 46 | """Create vocabulary file (if it does not exist yet) from data file. 47 | 48 | Original taken from 49 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py 50 | """ 51 | if not os.path.exists(vocabulary_path): 52 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 53 | vocab = {} 54 | with codecs_open(data_path, "rb", encoding="utf-8") as f: 55 | for line in f.readlines(): 56 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line, bos, eos) 57 | for w in tokens: 58 | word = re.sub(_DIGIT_RE, NUM_TOKEN, w) 59 | if word in vocab: 60 | vocab[word] += 1 61 | else: 62 | vocab[word] = 1 63 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 64 | if len(vocab_list) > max_vocabulary_size: 65 | print(" %d words found. Truncate to %d." % (len(vocab_list), max_vocabulary_size)) 66 | vocab_list = vocab_list[:max_vocabulary_size] 67 | with codecs_open(vocabulary_path, "wb", encoding="utf-8") as vocab_file: 68 | for w in vocab_list: 69 | vocab_file.write(w + b"\n") 70 | 71 | 72 | def initialize_vocabulary(vocabulary_path): 73 | """Initialize vocabulary from file. 74 | 75 | Original taken from 76 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py 77 | """ 78 | if os.path.exists(vocabulary_path): 79 | rev_vocab = [] 80 | with codecs_open(vocabulary_path, "rb", encoding="utf-8") as f: 81 | rev_vocab.extend(f.readlines()) 82 | rev_vocab = [line.strip() for line in rev_vocab] 83 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 84 | return vocab, rev_vocab 85 | else: 86 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 87 | 88 | 89 | def sentence_to_token_ids(sentence, vocabulary, tokenizer=None, bos=True, eos=True): 90 | """Convert a string to list of integers representing token-ids. 
91 | 92 | Original taken from 93 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py 94 | """ 95 | words = tokenizer(sentence) if tokenizer else basic_tokenizer(sentence, bos, eos) 96 | return [vocabulary.get(re.sub(_DIGIT_RE, NUM_TOKEN, w), UNK_ID) for w in words] 97 | 98 | 99 | def data_to_token_ids(data_path, target_path, vocabulary_path, tokenizer=None, bos=True, eos=True): 100 | """Tokenize data file and turn into token-ids using given vocabulary file. 101 | 102 | Original taken from 103 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/data_utils.py 104 | """ 105 | if not os.path.exists(target_path): 106 | print("Vectorizing data in %s" % data_path) 107 | vocab, _ = initialize_vocabulary(vocabulary_path) 108 | with codecs_open(data_path, "rb", encoding="utf-8") as data_file: 109 | with codecs_open(target_path, "wb", encoding="utf-8") as tokens_file: 110 | for line in data_file: 111 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, bos, eos) 112 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 113 | 114 | 115 | def shuffle_split(X, y, a=None, train_size=10000, shuffle=True): 116 | """Shuffle and split data into train and test subset""" 117 | _X = np.array(X) 118 | _y = np.array(y) 119 | assert _X.shape[0] == _y.shape[0] 120 | 121 | _a = [None] * _y.shape[0] 122 | if a is not None and len(a) == len(y): 123 | _a = np.array(a) 124 | # compute softmax 125 | _a = np.reshape(np.exp(_a) / np.sum(np.exp(_a)), (_y.shape[0], 1)) 126 | assert _a.shape[0] == _y.shape[0] 127 | 128 | print "Splitting data...", 129 | # split train-test 130 | data = np.array(zip(_X, _y, _a)) 131 | data_size = _y.shape[0] 132 | if train_size > data_size: 133 | train_size = int(data_size * 0.9) 134 | if shuffle: 135 | np.random.seed(RANDOM_SEED) 136 | shuffle_indices = np.random.permutation(np.arange(data_size)) 137 | shuffled_data = data[shuffle_indices] 138 | else: 139 | shuffled_data = data 140 | print "\t%d for train, %d for test" % (train_size, data_size - train_size) 141 | return shuffled_data[:train_size], shuffled_data[train_size:] 142 | 143 | 144 | def read_data(source_path, target_path, sent_len, attention_path=None, train_size=10000, shuffle=True): 145 | """Read source(x), target(y) and attention if given. 146 | 147 | Original taken from 148 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/translate.py 149 | """ 150 | _X = [] 151 | _y = [] 152 | with codecs_open(source_path, mode="r", encoding="utf-8") as source_file: 153 | with codecs_open(target_path, mode="r", encoding="utf-8") as target_file: 154 | source, target = source_file.readline(), target_file.readline() 155 | #counter = 0 156 | print "Loading data...", 157 | while source and target: 158 | #counter += 1 159 | #if counter % 1000 == 0: 160 | # print(" reading data line %d" % counter) 161 | # sys.stdout.flush() 162 | source_ids = [np.int64(x.strip()) for x in source.split()] 163 | if sent_len > len(source_ids): 164 | source_ids += [PAD_ID] * (sent_len - len(source_ids)) 165 | assert len(source_ids) == sent_len 166 | 167 | #target = target.split('\t')[0].strip() 168 | target_ids = [np.float32(y.strip()) for y in target.split()] 169 | 170 | _X.append(source_ids) 171 | _y.append(target_ids) 172 | source, target = source_file.readline(), target_file.readline() 173 | 174 | assert len(_X) == len(_y) 175 | print "\t%d examples found." 
% len(_y) 176 | 177 | _a = None 178 | if attention_path is not None: 179 | with codecs_open(attention_path, mode="r", encoding="utf-8") as att_file: 180 | _a = [np.float32(att.strip()) for att in att_file.readlines()] 181 | assert len(_a) == len(_y) 182 | 183 | return shuffle_split(_X, _y, a=_a, train_size=train_size, shuffle=shuffle) 184 | 185 | 186 | def shuffle_split_contextwise(X, y, a=None, train_size=10000, shuffle=True): 187 | """Shuffle and split data into train and test subset""" 188 | 189 | _left = np.array(X['left']) 190 | _middle = np.array(X['middle']) 191 | _right = np.array(X['right']) 192 | _y = np.array(y) 193 | 194 | _a = [None] * _y.shape[0] 195 | if a is not None and len(a) == len(y): 196 | _a = np.array(a) 197 | # compute softmax 198 | _a = np.reshape(np.exp(_a) / np.sum(np.exp(_a)), (_y.shape[0], 1)) 199 | assert _a.shape[0] == _y.shape[0] 200 | 201 | print "Splitting data...", 202 | # split train-test 203 | data = np.array(zip(_left, _middle, _right, _y, _a)) 204 | data_size = _y.shape[0] 205 | if train_size > data_size: 206 | train_size = int(data_size * 0.9) 207 | if shuffle: 208 | np.random.seed(RANDOM_SEED) 209 | shuffle_indices = np.random.permutation(np.arange(data_size)) 210 | shuffled_data = data[shuffle_indices] 211 | else: 212 | shuffled_data = data 213 | print "\t%d for train, %d for test" % (train_size, data_size - train_size) 214 | return shuffled_data[:train_size], shuffled_data[train_size:] 215 | 216 | 217 | def read_data_contextwise(source_path, target_path, sent_len, attention_path=None, train_size=10000, shuffle=True): 218 | """Read source file and pad the sequence to sent_len, 219 | combine them with target (and attention if given). 220 | 221 | Original taken from 222 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/translate.py 223 | """ 224 | print "Loading data...", 225 | _X = {'left': [], 'middle': [], 'right': []} 226 | for context in _X.keys(): 227 | path = '%s.%s' % (source_path, context) 228 | with codecs_open(path, mode="r", encoding="utf-8") as source_file: 229 | for source in source_file.readlines(): 230 | source_ids = [np.int64(x.strip()) for x in source.split()] 231 | if sent_len > len(source_ids): 232 | source_ids += [PAD_ID] * (sent_len - len(source_ids)) 233 | assert len(source_ids) == sent_len 234 | _X[context].append(source_ids) 235 | assert len(_X['left']) == len(_X['middle']) 236 | assert len(_X['right']) == len(_X['middle']) 237 | 238 | _y = [] 239 | with codecs_open(target_path, mode="r", encoding="utf-8") as target_file: 240 | for target in target_file.readlines(): 241 | target_ids = [np.float32(y.strip()) for y in target.split()] 242 | _y.append(target_ids) 243 | assert len(_X['left']) == len(_y) 244 | print "\t%d examples found." % len(_y) 245 | 246 | _a = None 247 | if attention_path is not None: 248 | with codecs_open(attention_path, mode="r", encoding="utf-8") as att_file: 249 | _a = [np.float32(att.strip()) for att in att_file.readlines()] 250 | assert len(_a) == len(_y) 251 | 252 | return shuffle_split_contextwise(_X, _y, a=_a, train_size=train_size, shuffle=shuffle) 253 | 254 | 255 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 256 | """Generates a batch iterator. 
257 | 258 | Original taken from 259 | https://github.com/dennybritz/cnn-text-classification-tf/blob/master/data_helpers.py 260 | """ 261 | data = np.array(data) 262 | data_size = len(data) 263 | num_batches_per_epoch = int(np.ceil(float(data_size)/batch_size)) 264 | for epoch in range(num_epochs): 265 | # Shuffle data at each epoch 266 | if shuffle: 267 | #np.random.seed(RANDOM_SEED) 268 | shuffle_indices = np.random.permutation(np.arange(data_size)) 269 | shuffled_data = data[shuffle_indices] 270 | else: 271 | shuffled_data = data 272 | for batch_num in range(num_batches_per_epoch): 273 | start_index = batch_num * batch_size 274 | end_index = min((batch_num + 1) * batch_size, data_size) 275 | yield shuffled_data[start_index:end_index] 276 | 277 | 278 | def dump_to_file(filename, obj): 279 | with open(filename, 'wb') as outfile: 280 | pickle.dump(obj, file=outfile) 281 | return 282 | 283 | 284 | def load_from_dump(filename): 285 | with open(filename, 'rb') as infile: 286 | obj = pickle.load(infile) 287 | return obj 288 | 289 | 290 | def _load_bin_vec(fname, vocab): 291 | """ 292 | Loads 300x1 word vecs from Google (Mikolov) word2vec 293 | 294 | Original taken from 295 | https://github.com/yuhaozhang/sentence-convnet/blob/master/text_input.py 296 | """ 297 | word_vecs = {} 298 | with open(fname, "rb") as f: 299 | header = f.readline() 300 | vocab_size, layer1_size = map(int, header.split()) 301 | binary_len = np.dtype('float32').itemsize * layer1_size 302 | for line in xrange(vocab_size): 303 | word = [] 304 | while True: 305 | ch = f.read(1) 306 | if ch == ' ': 307 | word = ''.join(word) 308 | break 309 | if ch != '\n': 310 | word.append(ch) 311 | if word in vocab: 312 | word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32') 313 | else: 314 | f.read(binary_len) 315 | return (word_vecs, layer1_size) 316 | 317 | 318 | def _add_random_vec(word_vecs, vocab, emb_size=300): 319 | for word in vocab: 320 | if word not in word_vecs: 321 | word_vecs[word] = np.random.uniform(-0.25,0.25,emb_size) 322 | return word_vecs 323 | 324 | 325 | def prepare_pretrained_embedding(fname, word2id): 326 | print 'Reading pretrained word vectors from file ...' 
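# words found in the word2vec binary keep their pretrained vectors;
# the remaining vocabulary entries get uniform random vectors in [-0.25, 0.25]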
327 | word_vecs, emb_size = _load_bin_vec(fname, word2id) 328 | word_vecs = _add_random_vec(word_vecs, word2id, emb_size) 329 | embedding = np.zeros([len(word2id), emb_size]) 330 | for w,idx in word2id.iteritems(): 331 | embedding[idx,:] = word_vecs[w] 332 | print 'Generated embeddings with shape ' + str(embedding.shape) 333 | return embedding 334 | 335 | 336 | def offset(array, pre, post): 337 | ret = np.array(array) 338 | ret = np.insert(ret, 0, pre) 339 | ret = np.append(ret, post) 340 | return ret 341 | 342 | 343 | def calc_auc_pr(precision, recall): 344 | assert len(precision) == len(recall) 345 | return np.trapz(offset(precision, 1, 0), x=offset(recall, 0, 1), dx=5) 346 | 347 | 348 | def prepare_ids(data_dir, vocab_path): 349 | for context in ['left', 'middle', 'right', 'txt']: 350 | data_path = os.path.join(data_dir, 'mlmi', 'source.%s' % context) 351 | target_path = os.path.join(data_dir, 'mlmi', 'ids.%s' % context) 352 | if context == 'left': 353 | bos, eos = True, False 354 | elif context == 'middle': 355 | bos, eos = False, False 356 | elif context == 'right': 357 | bos, eos = False, True 358 | else: 359 | bos, eos = True, True 360 | data_to_token_ids(data_path, target_path, vocab_path, bos=bos, eos=eos) 361 | 362 | 363 | def main(): 364 | data_dir = os.path.join(THIS_DIR, 'data') 365 | 366 | # multi-label multi-instance (MLMI-CNN) dataset 367 | vocab_path = os.path.join(data_dir, 'mlmi', 'vocab.txt') 368 | data_path = os.path.join(data_dir, 'mlmi', 'source.txt') 369 | max_vocab_size = 36500 370 | create_vocabulary(vocab_path, data_path, max_vocab_size) 371 | prepare_ids(data_dir, vocab_path) 372 | 373 | # pretrained embeddings 374 | embedding_path = os.path.join(THIS_DIR, 'word2vec', 'GoogleNews-vectors-negative300.bin') 375 | if os.path.exists(embedding_path): 376 | word2id, _ = initialize_vocabulary(vocab_path) 377 | embedding = prepare_pretrained_embedding(embedding_path, word2id) 378 | np.save(os.path.join(data_dir, 'mlmi', 'emb.npy'), embedding) 379 | else: 380 | print "Pretrained embeddings file %s not found." % embedding_path 381 | 382 | # single-label single-instance (ER-CNN) dataset 383 | vocab_er = os.path.join(data_dir, 'er', 'vocab.txt') 384 | data_er = os.path.join(data_dir, 'er', 'source.txt') 385 | target_er = os.path.join(data_dir, 'er', 'ids.txt') 386 | max_vocab_size = 11500 387 | tokenizer = lambda x: x.split() 388 | create_vocabulary(vocab_er, data_er, max_vocab_size, tokenizer=tokenizer) 389 | data_to_token_ids(data_er, target_er, vocab_er, tokenizer=tokenizer) 390 | 391 | 392 | if __name__ == '__main__': 393 | main() 394 | -------------------------------------------------------------------------------- /word2vec/.gitignore: -------------------------------------------------------------------------------- 1 | GoogleNews-vectors-negative300.bin 2 | --------------------------------------------------------------------------------