├── LICENSE ├── README.md ├── data ├── readme.txt └── wordnet_lexnames.txt ├── pics ├── l2_maxnorm.svg ├── link.svg ├── maxnorm.svg ├── mf1.svg ├── mf_dot.svg ├── rescal.svg ├── tf1.svg ├── tf1_small.png ├── transe.svg ├── transe2.png └── tsne_small.png ├── tf_rl_tutorial.ipynb └── tf_rl_tutorial ├── __init__.py ├── models.py ├── util.py └── wordnet_eval.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | tf_rl_tutorial
2 | ===================
3 | 
4 | | ![tf](pics/tf1_small.png) | ![tsne](pics/tsne_small.png) | ![transe](pics/transe2.png) |
5 | | ------------------------- | ---------------------------- | --------------------------- |
6 | 
7 | Accompanying code for the "Relational Learning with TensorFlow" tutorial
8 | 
9 | The tutorial can be viewed here: http://nbviewer.jupyter.org/github/fireeye/tf_rl_tutorial/blob/master/tf_rl_tutorial.ipynb
10 | 
11 | Please use the nbviewer link above instead of viewing the notebook directly from this GitHub repo. The GitHub renderer won't correctly display the images, equations, or the scatter plot.
12 | 
13 | ### Dependencies
14 | * Python 3: http://www.python.org
15 | * TensorFlow: http://www.tensorflow.org
16 | * Numpy: http://www.numpy.org
17 | * Pandas: http://pandas.pydata.org
18 | 
19 | ### Running the WordNet Example
20 | In addition to the tutorial notebook, there is also a script which demonstrates training and evaluating a model.
21 | 
22 | 1. Clone this repository and make sure all dependencies are installed
23 | 2. Download and unpack the WordNet dataset into the /data directory (link is in the tutorial)
24 | 3. Make sure that this project directory is in your PYTHONPATH
25 | 4. cd tf_rl_tutorial
26 | 5. python wordnet_eval.py
27 | 
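28 | For example, in a POSIX shell (a sketch of steps 3-5 above; assumes the dataset from step 2 is already in place):
29 | 
30 | ```sh
31 | # from the root of the cloned repository
32 | export PYTHONPATH="$PWD:$PYTHONPATH"
33 | cd tf_rl_tutorial
34 | python wordnet_eval.py
35 | ```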
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
1 | This directory only contains the WordNet lexnames file for the visualization example. Please follow the link in the tutorial to download the WordNet data for training and testing the models.
2 | 
--------------------------------------------------------------------------------
/pics/l2_maxnorm.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "C", "L2 Penalty", "Max-norm"]
--------------------------------------------------------------------------------
/pics/link.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text label: "?"]
--------------------------------------------------------------------------------
/pics/maxnorm.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text label: "C"]
--------------------------------------------------------------------------------
/pics/mf1.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "T", "H", "D", indices "i", "j"]
--------------------------------------------------------------------------------
/pics/mf_dot.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "retina", "rod_cell", "singapore", "𝛳"]
--------------------------------------------------------------------------------
/pics/rescal.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "R_k", "e_i", "e_j", indices "i", "j", "k"]
--------------------------------------------------------------------------------
/pics/tf1.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "h_i", "r_k", "t_j", indices "i", "j", "k"]
--------------------------------------------------------------------------------
/pics/tf1_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandiant/tf_rl_tutorial/c58d10c60cfd79b2e0661b4a49cccae8d4584c57/pics/tf1_small.png
--------------------------------------------------------------------------------
/pics/transe.svg:
--------------------------------------------------------------------------------
1 | [SVG figure (OmniGraffle diagram); text labels: "e_i", "e_j", "r_k", "boat", "hull", "bow", "rudder"]
--------------------------------------------------------------------------------
/pics/transe2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandiant/tf_rl_tutorial/c58d10c60cfd79b2e0661b4a49cccae8d4584c57/pics/transe2.png
--------------------------------------------------------------------------------
/pics/tsne_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandiant/tf_rl_tutorial/c58d10c60cfd79b2e0661b4a49cccae8d4584c57/pics/tsne_small.png
--------------------------------------------------------------------------------
/tf_rl_tutorial/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandiant/tf_rl_tutorial/c58d10c60cfd79b2e0661b4a49cccae8d4584c57/tf_rl_tutorial/__init__.py
--------------------------------------------------------------------------------
/tf_rl_tutorial/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Mandiant, A FireEye Company
2 | # Authors: Brian Jones
3 | # License: Apache 2.0
4 | 
5 | ''' Model classes for "Relational Learning with TensorFlow" tutorial '''
6 | 
7 | import numpy as np
8 | import tensorflow as tf
9 | 
10 | from .util import ContrastiveTrainingProvider
11 | 
12 | 
13 | def least_squares_objective(output, target, add_bias=True):
14 |     ''' Creates final model output and loss for least squares objective
15 | 
16 |     Args:
17 |         output: Model output
18 |         target: Training target placeholder
19 |         add_bias: If True, a bias Variable will be added to the output
20 | 
21 |     Returns:
22 |         tuple (final output, loss)
23 |     '''
24 |     y = output
25 |     if add_bias:
26 |         bias = tf.Variable([0.0])
27 |         y = output + bias
28 |     loss = tf.reduce_sum(tf.square(y - target))
29 |     return y, loss
30 | 
31 | 
32 | def logistic_objective(output, target, add_bias=True):
33 |     ''' Creates final model output and loss for logistic objective
34 | 
35 |     Args:
36 |         output: Model output
37 |         target: Training target placeholder
38 |         add_bias: If True, a bias Variable will be added to the output
39 | 
40 |     Returns:
41 |         tuple (final output, loss)
42 |     '''
43 |     y = output
44 |     if add_bias:
45 |         bias = tf.Variable([0.0])
46 |         y = output + bias
47 |     sig_y = tf.clip_by_value(tf.sigmoid(y), 0.001, 0.999)  # avoid NaNs
48 |     loss = -tf.reduce_sum(target*tf.log(sig_y) + (1-target)*tf.log(1-sig_y))
49 |     return sig_y, loss
50 | 
51 | 
52 | def ranking_margin_objective(output, margin=1.0):
53 |     ''' Creates final model output and loss for pairwise ranking margin objective
54 | 
55 |     Loss for a single pair (f(p), f(n)) = [margin - f(p) + f(n)]+
56 |     This only works when given model output on alternating positive/negative
57 |     pairs: [pos,neg,pos,neg,...]. TODO: check target placeholder
58 |     at runtime to make sure this is the case?
59 | 
60 |     Args:
61 |         output: Model output
62 |         margin: The margin value for the pairwise hinge loss
63 | 
64 |     Returns:
65 |         tuple (final output, loss)
66 |     '''
67 |     y_pairs = tf.reshape(output, [-1,2])  # fold: 1 x n -> [n/2 x 2]
68 |     pos_scores, neg_scores = tf.split(1, 2, y_pairs)  # separate pairs
69 |     hinge_losses = tf.nn.relu(margin - pos_scores + neg_scores)
70 |     total_hinge_loss = tf.reduce_sum(hinge_losses)
71 |     return output, total_hinge_loss
72 | 
73 | 
74 | def sparse_maxnorm_update(var_matrix, indices, maxnorm=1.0):
75 |     '''Sparse update operation that ensures selected rows in var_matrix
76 |     do not have a Euclidean norm greater than maxnorm. Rows that exceed
77 |     it are scaled back down to maxnorm.
78 | 
79 |     Args:
80 |         var_matrix: 2D mutable tensor (Variable) to operate on
81 |         indices: 1D tensor with the row indices to constrain
82 |         maxnorm: the maximum Euclidean norm
83 | 
84 |     Returns:
85 |         An operation that will update var_matrix when run in a Session
86 |     '''
87 |     selected_rows = tf.nn.embedding_lookup(var_matrix, indices)
88 |     row_norms = tf.sqrt(tf.reduce_sum(tf.square(selected_rows), 1))
89 |     scaling = maxnorm / tf.maximum(row_norms, maxnorm)
90 |     scaled = selected_rows * tf.expand_dims(scaling, 1)
91 |     return tf.scatter_update(var_matrix, indices, scaled)
92 | 
93 | 
94 | def dense_maxnorm_update(var_matrix, maxnorm=1.0):
95 |     '''Dense update operation that ensures all rows in var_matrix
96 |     do not have a Euclidean norm greater than maxnorm. Rows that exceed
97 |     it are scaled back down to maxnorm.
98 | 
99 |     Args:
100 |         var_matrix: 2D mutable tensor (Variable) to operate on
101 |         maxnorm: the maximum Euclidean norm
102 | 
103 |     Returns:
104 |         An operation that will update var_matrix when run in a Session
105 |     '''
106 |     row_norms = tf.sqrt(tf.reduce_sum(tf.square(var_matrix), 1))
107 |     scaling = maxnorm / tf.maximum(row_norms, maxnorm)
108 |     scaled = var_matrix * tf.expand_dims(scaling, 1)
109 |     return tf.assign(var_matrix, scaled)
110 | 
111 | 
112 | def dense_maxnorm(var_matrix, maxnorm=1.0):
113 |     '''Similar to dense_maxnorm_update(), except this returns a new Tensor
114 |     instead of an operation that modifies var_matrix.
115 | 
116 |     Args:
117 |         var_matrix: 2D tensor (Variable)
118 |         maxnorm: the maximum Euclidean norm
119 | 
120 |     Returns:
121 |         A new tensor where all rows have been scaled as necessary
122 |     '''
123 |     axis_norms = tf.sqrt(tf.reduce_sum(tf.square(var_matrix), 1))
124 |     scaling = maxnorm / tf.maximum(axis_norms, maxnorm)
125 |     return var_matrix * tf.expand_dims(scaling, 1)
126 | 
127 | 
128 | class BaseModel(object):
129 |     ''' Base class for embedding-based relational learning models that use
130 |     maxnorm regularization. Subclasses must implement _create_model() and
131 |     populate self.train_step, and can optionally populate self.post_step for
132 |     post-processing.
133 | 
134 |     Note: When model_type is 'ranking_margin', the mini-batch provider returned
135 |     by _create_batch_provider() must provide instances in alternating
136 |     pos/neg pairs: [pos, neg, pos, neg, ...]. This is satisfied when using
137 |     ContrastiveTrainingProvider; be careful if you use a different one.
138 | 
139 |     Args:
140 |         embedding_size: Embedding vector length
141 |         maxnorm: Maximum Euclidean norm for embedding vectors
142 |         batch_pos_cnt: Number of positive examples to use in each mini-batch
143 |         max_iter: Maximum number of optimization iterations to perform
144 |         model_type: Possible values:
145 |             'least_squares': squared loss on 0/1 targets
146 |             'logistic': sigmoid link function, crossent loss on 0/1 targets
147 |             'ranking_margin': ranking margin on pos/neg pairs
148 |         add_bias: If True, a bias Variable will be added to the output for
149 |             least_squares and logistic models.
150 |         opt: An optimizer object to use. If None, the default optimizer is
151 |             tf.train.AdagradOptimizer(1.0)
152 | 
153 |     TODO: add support for other regularizers like L2
154 |     '''
155 | 
156 |     def __init__(self, embedding_size, maxnorm=1.0,
157 |                  batch_pos_cnt=100, max_iter=1000,
158 |                  model_type='least_squares', add_bias=True,
159 |                  opt=None):
160 |         self.embedding_size = embedding_size
161 |         self.maxnorm = maxnorm
162 |         self.batch_pos_cnt = batch_pos_cnt
163 |         self.max_iter = max_iter
164 |         self.model_type = model_type
165 |         self.add_bias = add_bias
166 |         if opt is None:
167 |             opt = tf.train.AdagradOptimizer(1.0)
168 |         self.opt = opt
169 |         self.sess = None
170 |         self.train_step = None
171 |         self.post_step = None
172 |         self.graph = tf.Graph()
173 |         with self.graph.as_default():
174 |             self.head_input = tf.placeholder(tf.int32, shape=[None])
175 |             self.rel_input = tf.placeholder(tf.int32, shape=[None])
176 |             self.tail_input = tf.placeholder(tf.int32, shape=[None])
177 |             self.target = tf.placeholder(tf.float32, shape=[None])
178 | 
179 |     def _create_model(self, train_triples):
180 |         ''' Subclasses must build Graph and set self.train_step '''
181 |         raise NotImplementedError('subclass must implement')
182 | 
183 |     def _create_batch_provider(self, train_triples):
184 |         ''' Default implementation '''
185 |         return ContrastiveTrainingProvider(train_triples, self.batch_pos_cnt)
186 | 
187 |     def _create_output_and_loss(self, raw_output):
188 |         if self.model_type == 'least_squares':
189 |             return least_squares_objective(raw_output, self.target, self.add_bias)
190 |         elif self.model_type == 'logistic':
191 |             return logistic_objective(raw_output, self.target, self.add_bias)
192 |         elif self.model_type == 'ranking_margin':
193 |             return ranking_margin_objective(raw_output, 1.0)
194 |         else:
195 |             raise Exception('Unknown model_type')
196 | 
197 |     def _norm_constraint_op(self, var_matrix, row_indices, maxnorm):
198 |         '''
199 |         Args:
200 |             var_matrix: A 2D Tensor holding the vectors to constrain (in rows)
201 |             row_indices: The rows in var_tensor that are being considered for
202 |                 constraint application (typically embedding vectors for
203 |                 entities observed for a minibatch of training data). These
204 |                 will be used for a sparse variable update operation if the
205 |                 chosen optimizer only modifies these entries. Otherwise
206 |                 a dense operation is used and row_indices are ignored.
207 |             maxnorm: The maximum Euclidean norm for the rows in var_tensor
208 | 
209 |         Returns:
210 |             An operation which will apply the constraints when run in a Session
211 |         '''
212 |         # Most TF optimizers leave rows with zero gradient untouched, so the
213 |         # sparse update suffices; AdamOptimizer updates all rows and needs the dense op
214 |         if isinstance(self.opt, tf.train.AdamOptimizer):
215 |             return dense_maxnorm_update(var_matrix, maxnorm)
216 |         else:
217 |             return sparse_maxnorm_update(var_matrix, row_indices, maxnorm)
218 | 
219 |     def embeddings(self):
220 |         ''' Subclass should override this if it uses different embedding
221 |         variables
222 | 
223 |         Returns:
224 |             A list of pairs: [(embedding name, embedding 2D Tensor)]
225 |         '''
226 |         return [('entity', self.entity_embedding_vars),
227 |                 ('rel', self.rel_embedding_vars)]
228 | 
229 |     def create_feed_dict(self, triples, labels=None, training=False):
230 |         ''' Create a TensorFlow feed dict for relationship triples
231 | 
232 |         Args:
233 |             triples: A numpy integer array of relationship triples, where each
234 |                 row contains [head idx, relationship idx, tail idx]
235 |             labels: (optional) A label array for triples
236 |             training: (optional) A flag indicating whether the feed dict is
237 |                 for training or test purposes. Useful for things like
238 |                 dropout where a dropout_probability variable is set differently
239 |                 in the two contexts.
240 |         '''
241 |         feed_dict = {self.head_input: triples[:, 0],
242 |                      self.rel_input: triples[:, 1],
243 |                      self.tail_input: triples[:, 2]}
244 |         if labels is not None:
245 |             feed_dict[self.target] = labels
246 |         return feed_dict
247 | 
248 |     def close(self):
249 |         ''' Closes the TensorFlow Session object '''
250 |         self.sess.close()
251 | 
252 |     def fit(self, train_triples, step_callback=None):
253 |         ''' Trains the model on relationship triples
254 | 
255 |         Args:
256 |             train_triples: A numpy integer array of relationship triples, where
257 |                 each row contains [head idx, relationship idx, tail idx]
258 |             step_callback: (optional) A function that will be called before each
259 |                 optimization step, step_callback(iteration, feed_dict)
260 |         '''
261 |         if self.sess is not None:
262 |             self.sess.close()
263 |         self.sess = tf.Session(graph=self.graph)
264 |         with self.graph.as_default():
265 |             self._create_model(train_triples)
266 |             self.sess.run(tf.initialize_all_variables())
267 |             batch_provider = self._create_batch_provider(train_triples)
268 |             for i in range(self.max_iter):
269 |                 batch_triples, batch_labels = batch_provider.next_batch()
270 |                 feed_dict = self.create_feed_dict(batch_triples, batch_labels, training=True)
271 |                 if step_callback:
272 |                     keep_going = step_callback(i, feed_dict)
273 |                     if not keep_going:
274 |                         break
275 |                 self.sess.run(self.train_step, feed_dict)
276 |                 if self.post_step is not None:
277 |                     self.sess.run(self.post_step, feed_dict)
278 | 
279 |     def predict(self, triples):
280 |         ''' Runs a trained model on the supplied relationship triples. fit()
281 |         must be called before calling this function.
282 | 
283 |         Args:
284 |             triples: A numpy integer array of relationship triples, where each
285 |                 row contains [head idx, relationship idx, tail idx]
286 |         '''
287 |         feed_dict = self.create_feed_dict(triples, training=False)
288 |         return self.sess.run(self.output, feed_dict=feed_dict)
289 | 
290 | 
291 | class Contrastive_CP(BaseModel):
292 |     ''' Model with a scoring function based on CANDECOMP/PARAFAC tensor
293 |     decomposition. Optimization differs, however, in the use of maxnorm
294 |     regularization and contrastive negative sampling.
295 | 
296 |     Score for (head i, rel k, tail j) triple is: h_i^T * diag(r_k) * t_j,
297 |     where h_i and t_j are embedding vectors for the head and tail entities,
298 |     and r_k is an embedding vector for the relationship type.
299 | 
300 |     Args:
301 |         embedding_size: Embedding vector length
302 |         maxnorm: Maximum Euclidean norm for embedding vectors
303 |         batch_pos_cnt: Number of positive examples to use in each mini-batch
304 |         max_iter: Maximum number of optimization iterations to perform
305 |         model_type: Possible values:
306 |             'least_squares': squared loss on 0/1 targets
307 |             'logistic': sigmoid link function, crossent loss on 0/1 targets
308 |             'ranking_margin': ranking margin on pos/neg pairs
309 |         add_bias: If True, a bias Variable will be added to the output for
310 |             least_squares and logistic models.
311 |         opt: An optimizer object to use. If None, the default optimizer is
312 |             tf.train.AdagradOptimizer(1.0)
313 | 
314 |     References:
315 |         Kolda, Tamara G., and Brett W. Bader. "Tensor decompositions and
316 |         applications." SIAM review 51.3 (2009): 455-500.
317 |     '''
318 | 
319 |     def _create_model(self, train_triples):
320 |         # Count unique items to determine embedding matrix sizes
321 |         head_cnt = len(set(train_triples[:,0]))
322 |         rel_cnt = len(set(train_triples[:,1]))
323 |         tail_cnt = len(set(train_triples[:,2]))
324 |         init_sd = 1.0 / np.sqrt(self.embedding_size)
325 |         # Embedding matrices for entities and relationship types
326 |         head_init = tf.truncated_normal([head_cnt, self.embedding_size], stddev=init_sd)
327 |         rel_init = tf.truncated_normal([rel_cnt, self.embedding_size], stddev=init_sd)
328 |         tail_init = tf.truncated_normal([tail_cnt, self.embedding_size], stddev=init_sd)
329 |         if self.maxnorm is not None:
330 |             # Ensure maxnorm constraints are initially satisfied
331 |             head_init = dense_maxnorm(head_init, self.maxnorm)
332 |             rel_init = dense_maxnorm(rel_init, self.maxnorm)
333 |             tail_init = dense_maxnorm(tail_init, self.maxnorm)
334 |         self.head_embedding_vars = tf.Variable(head_init)
335 |         self.rel_embedding_vars = tf.Variable(rel_init)
336 |         self.tail_embedding_vars = tf.Variable(tail_init)
337 |         # Embedding layer for each (head, rel, tail) triple being fed in as input
338 |         head_embed = tf.nn.embedding_lookup(self.head_embedding_vars, self.head_input)
339 |         rel_embed = tf.nn.embedding_lookup(self.rel_embedding_vars, self.rel_input)
340 |         tail_embed = tf.nn.embedding_lookup(self.tail_embedding_vars, self.tail_input)
341 |         # Model output
342 |         raw_output = tf.reduce_sum(tf.mul(tf.mul(head_embed, rel_embed), tail_embed), 1)
343 |         self.output, self.loss = self._create_output_and_loss(raw_output)
344 |         # Optimization
345 |         self.train_step = self.opt.minimize(self.loss)
346 |         if self.maxnorm is not None:
347 |             # Post-processing to limit embedding vars to L2 ball
348 |             head_constraint = self._norm_constraint_op(self.head_embedding_vars,
349 |                                                        tf.unique(self.head_input)[0],
350 |                                                        self.maxnorm)
351 |             rel_constraint = self._norm_constraint_op(self.rel_embedding_vars,
352 |                                                       tf.unique(self.rel_input)[0],
353 |                                                       self.maxnorm)
354 |             tail_constraint = self._norm_constraint_op(self.tail_embedding_vars,
355 |                                                        tf.unique(self.tail_input)[0],
356 |                                                        self.maxnorm)
357 |             self.post_step = [head_constraint, rel_constraint, tail_constraint]
358 | 
359 |     def _create_batch_provider(self, train):
360 |         # CP treats head and tail entities separately
361 |         return ContrastiveTrainingProvider(train,
362 |                                            self.batch_pos_cnt,
363 |                                            separate_head_tail=True)
364 | 
365 |     def embeddings(self):
366 |         '''
367 |         Returns:
368 |             A list of pairs: [(embedding name, embedding 2D Tensor)]
369 |         '''
370 |         return [('head', self.head_embedding_vars),
371 |                 ('tail', self.tail_embedding_vars),
372 |                 ('rel', self.rel_embedding_vars)]
373 | 
374 | 
375 | class Bilinear(BaseModel):
376 |     ''' Model with a scoring function based on the bilinear formulation of
377 |     RESCAL. Optimization differs, however, in the use of maxnorm
378 |     regularization and contrastive negative sampling.
379 | 
380 |     Score for (head i, rel k, tail j) triple is: e_i^T * R_k * e_j
381 |     where e_i and e_j are D-dimensional embedding vectors for the head and tail
382 |     entities, and R_k is a (D x D) matrix for the relationship type
383 |     acting as a bilinear operator.
384 | 
385 |     Args:
386 |         embedding_size: Embedding vector length
387 |         maxnorm: Maximum Euclidean norm for embedding vectors
388 |         rel_maxnorm_mult: Multiplier for the maxnorm threshold used for
389 |             relationship embeddings. Example: If maxnorm=2.0 and
390 |             rel_maxnorm_mult=4.0, then the maxnorm constraint for relationships
391 |             will be 2.0 * 4.0 = 8.0.
392 |         batch_pos_cnt: Number of positive examples to use in each mini-batch
393 |         max_iter: Maximum number of optimization iterations to perform
394 |         model_type: Possible values:
395 |             'least_squares': squared loss on 0/1 targets
396 |             'logistic': sigmoid link function, crossent loss on 0/1 targets
397 |             'ranking_margin': ranking margin on pos/neg pairs
398 |         add_bias: If True, a bias Variable will be added to the output for
399 |             least_squares and logistic models.
400 |         opt: An optimizer object to use. If None, the default optimizer is
401 |             tf.train.AdagradOptimizer(1.0)
402 | 
403 |     References:
404 |         Nickel, Maximilian, Volker Tresp, and Hans-Peter Kriegel. "A three-way
405 |         model for collective learning on multi-relational data." Proceedings of
406 |         the 28th international conference on machine learning (ICML-11). 2011.
407 |     '''
408 | 
409 |     def __init__(self, embedding_size, maxnorm=1.0, rel_maxnorm_mult=3.0,
410 |                  batch_pos_cnt=100, max_iter=1000,
411 |                  model_type='least_squares', add_bias=True, opt=None):
412 |         super(Bilinear, self).__init__(
413 |             embedding_size=embedding_size,
414 |             maxnorm=maxnorm,
415 |             batch_pos_cnt=batch_pos_cnt,
416 |             max_iter=max_iter,
417 |             model_type=model_type, add_bias=add_bias,
418 |             opt=opt)
419 |         self.rel_maxnorm_mult = rel_maxnorm_mult
420 | 
421 |     def _create_model(self, train_triples):
422 |         # Count unique items to determine embedding matrix sizes
423 |         entity_cnt = len(set(train_triples[:,0]).union(train_triples[:,2]))
424 |         rel_cnt = len(set(train_triples[:,1]))
425 |         init_sd = 1.0 / np.sqrt(self.embedding_size)
426 |         # Embedding variables for all entities and relationship types
427 |         entity_embedding_shape = [entity_cnt, self.embedding_size]
428 |         # Relationship embeddings will be stored in flattened format to make
429 |         # applying maxnorm constraints easier
430 |         rel_embedding_shape = [rel_cnt, self.embedding_size * self.embedding_size]
431 |         entity_init = tf.truncated_normal(entity_embedding_shape, stddev=init_sd)
432 |         rel_init = tf.truncated_normal(rel_embedding_shape, stddev=init_sd)
433 |         if self.maxnorm is not None:
434 |             # Ensure maxnorm constraints are initially satisfied
435 |             entity_init = dense_maxnorm(entity_init, self.maxnorm)
436 |             rel_init = dense_maxnorm(rel_init, self.maxnorm)
437 |         self.entity_embedding_vars = tf.Variable(entity_init)
438 |         self.rel_embedding_vars = tf.Variable(rel_init)
439 |         # Embedding layer for each (head, rel, tail) triple being fed in as input
440 |         head_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.head_input)
441 |         tail_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.tail_input)
442 |         rel_embed = tf.nn.embedding_lookup(self.rel_embedding_vars, self.rel_input)
443 |         # Reshape rel_embed into square D x D matrices
444 |         rel_embed_square = tf.reshape(rel_embed, (-1, self.embedding_size, self.embedding_size))
445 |         # Reshape head_embed and tail_embed to be suitable for the matrix multiplication
446 |         head_embed_row = tf.expand_dims(head_embed, 1)  # embeddings as row vectors
447 |         tail_embed_col = tf.expand_dims(tail_embed, 2)  # embeddings as column vectors
448 |         head_rel_mult = tf.batch_matmul(head_embed_row, rel_embed_square)
449 |         # Output needs a squeeze into a 1d vector
450 |         raw_output = tf.squeeze(tf.batch_matmul(head_rel_mult, tail_embed_col))
451 |         self.output, self.loss = self._create_output_and_loss(raw_output)
452 |         # Optimization
453 |         self.train_step = self.opt.minimize(self.loss)
454 |         if self.maxnorm is not None:
455 |             # Post-processing to limit embedding vars to L2 ball
456 |             rel_maxnorm = self.maxnorm * self.rel_maxnorm_mult
457 |             unique_ent_indices = tf.unique(tf.concat(0, [self.head_input, self.tail_input]))[0]
458 |             unique_rel_indices = tf.unique(self.rel_input)[0]
459 |             entity_constraint = self._norm_constraint_op(self.entity_embedding_vars,
460 |                                                          unique_ent_indices,
461 |                                                          self.maxnorm)
462 |             rel_constraint = self._norm_constraint_op(self.rel_embedding_vars,
463 |                                                       unique_rel_indices,
464 |                                                       rel_maxnorm)
465 |             self.post_step = [entity_constraint, rel_constraint]
466 | 
467 | 
468 | class TransE(BaseModel):
469 |     ''' TransE: Translational Embeddings Model
470 | 
471 |     Score for (head i, rel k, tail j) triple is: -d(e_i + t_k, e_j),
472 |     where e_i and e_j are D-dimensional embedding vectors for the head and
473 |     tail entities, t_k is another D-dimensional vector acting as a
474 |     translation, and d() is a dissimilarity function like Euclidean distance.
475 | 
476 |     Optimization is performed using SGD on a ranking margin loss between
477 |     contrastive training pairs. Entity embeddings are constrained to lie within
478 |     the unit L2 ball, while relationship vectors are left unconstrained.
479 | 
480 |     Args:
481 |         embedding_size: Embedding vector length
482 |         batch_pos_cnt: Number of positive examples to use in each mini-batch
483 |         max_iter: Maximum number of optimization iterations to perform
484 |         dist: Distance function used in loss:
485 |             'euclidean': sqrt(sum((x - y)^2))
486 |             'sqeuclidean': squared Euclidean, sum((x - y)^2)
487 |             'manhattan': sum of absolute differences, sum(|x - y|)
488 |         margin: Margin parameter for pairwise ranking hinge loss
489 |         opt: An optimizer object to use. If None, the default optimizer is
490 |             tf.train.AdagradOptimizer(1.0)
491 | 
492 |     References:
493 |         Bordes, Antoine, et al. "Translating embeddings for modeling multi-relational
494 |         data." Advances in Neural Information Processing Systems. 2013.
495 |     '''
496 |     def __init__(self, embedding_size, batch_pos_cnt=100,
497 |                  max_iter=1000, dist='euclidean',
498 |                  margin=1.0, opt=None):
499 |         super(TransE, self).__init__(embedding_size=embedding_size,
500 |                                      maxnorm=1.0,
501 |                                      batch_pos_cnt=batch_pos_cnt,
502 |                                      max_iter=max_iter,
503 |                                      model_type='ranking_margin',
504 |                                      opt=opt)
505 |         self.dist = dist
506 |         self.margin = margin
507 |         self.EPS = 1e-3  # for sqrt gradient when dist='euclidean'
508 | 
509 |     def _create_model(self, train_triples):
510 |         # Count unique items to determine embedding matrix sizes
511 |         entity_cnt = len(set(train_triples[:,0]).union(train_triples[:,2]))
512 |         rel_cnt = len(set(train_triples[:,1]))
513 |         init_sd = 1.0 / np.sqrt(self.embedding_size)
514 |         # Embedding variables
515 |         entity_var_shape = [entity_cnt, self.embedding_size]
516 |         rel_var_shape = [rel_cnt, self.embedding_size]
517 |         entity_init = tf.truncated_normal(entity_var_shape, stddev=init_sd)
518 |         rel_init = tf.truncated_normal(rel_var_shape, stddev=init_sd)
519 |         # Ensure maxnorm constraints are initially satisfied
520 |         entity_init = dense_maxnorm(entity_init, self.maxnorm)
521 |         self.entity_embedding_vars = tf.Variable(entity_init)
522 |         self.rel_embedding_vars = tf.Variable(rel_init)
523 |         # Embedding layer for each (head, rel, tail) triple being fed in as input
524 |         head_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.head_input)
525 |         tail_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.tail_input)
526 |         rel_embed = tf.nn.embedding_lookup(self.rel_embedding_vars, self.rel_input)
527 |         # Relationship vector acts as a translation in entity embedding space
528 |         diff_vec = tail_embed - (head_embed + rel_embed)
529 |         # negative dist so higher scores are better (important for pairwise loss)
530 |         if self.dist == 'manhattan':
531 |             raw_output = -tf.reduce_sum(tf.abs(diff_vec), 1)
532 |         elif self.dist == 'euclidean':
533 |             # +eps because gradients can misbehave for small values in sqrt
534 |             raw_output = -tf.sqrt(tf.reduce_sum(tf.square(diff_vec), 1) + self.EPS)
535 |         elif self.dist == 'sqeuclidean':
536 |             raw_output = -tf.reduce_sum(tf.square(diff_vec), 1)
537 |         else:
538 |             raise Exception('Unknown distance type')
539 |         # Model output
540 |         self.output, self.loss = ranking_margin_objective(raw_output, self.margin)
541 |         # Optimization with postprocessing to limit embedding vars to L2 ball
542 |         self.train_step = self.opt.minimize(self.loss)
543 |         unique_ent_indices = tf.unique(tf.concat(0, [self.head_input, self.tail_input]))[0]
544 |         self.post_step = self._norm_constraint_op(self.entity_embedding_vars,
545 |                                                   unique_ent_indices,
546 |                                                   self.maxnorm)
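547 | 
548 | # Usage sketch for the classes above (illustrative; assumes `train_triples`
549 | # is an (N x 3) integer array of [head idx, rel idx, tail idx] rows, e.g.
550 | # produced by util.df_to_idx_array()):
551 | #
552 | #     model = TransE(embedding_size=20, batch_pos_cnt=100,
553 | #                    max_iter=1000, dist='euclidean', margin=1.0)
554 | #     model.fit(train_triples)
555 | #     scores = model.predict(train_triples)  # higher score = more plausible triple
556 | #     model.close()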
--------------------------------------------------------------------------------
/tf_rl_tutorial/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Mandiant, A FireEye Company
2 | # Authors: Brian Jones
3 | # License: Apache 2.0
4 | 
5 | ''' Utility functions for "Relational Learning with TensorFlow" tutorial '''
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | 
10 | 
11 | def df_to_idx_array(df):
12 |     '''Converts a Pandas DataFrame containing relationship triples
13 |     into a numpy index array.
14 | 
15 |     Args:
16 |         df: Pandas DataFrame with columns 'head', 'rel', and 'tail'. These
17 |             columns must be Categorical. See make_categorical().
18 | 
19 |     Returns:
20 |         A (N x 3) numpy integer index array built from the column Categorical
21 |         codes.
22 |     '''
23 |     idx_array = np.zeros((len(df),3), dtype=np.int)
24 |     idx_array[:,0] = df['head'].cat.codes
25 |     idx_array[:,1] = df['rel'].cat.codes
26 |     idx_array[:,2] = df['tail'].cat.codes
27 |     return idx_array
28 | 
29 | 
30 | def make_categorical(df, field_sets):
31 |     '''Make DataFrame columns Categorical so that they can be converted to
32 |     index arrays for feeding into TensorFlow models.
33 | 
34 |     Args:
35 |         df: Pandas DataFrame with columns 'head', 'rel', and 'tail'
36 |         field_sets: A tuple containing the item category sets: (head_set,
37 |             rel_set, tail_set). Note that head_set and tail_set can
38 |             be the same if the model embeds all entities into a common
39 |             space.
40 | 
41 |     Returns:
42 |         A tuple (df, idx_array): a new DataFrame whose 'head', 'rel', and 'tail'
43 |         columns are Categorical using field_sets, plus its df_to_idx_array() result.
44 |     '''
45 |     head_set, rel_set, tail_set = field_sets
46 |     result = pd.DataFrame()
47 |     result['head'] = pd.Categorical(df['head'].values, categories=head_set)
48 |     result['rel'] = pd.Categorical(df['rel'].values, categories=rel_set)
49 |     result['tail'] = pd.Categorical(df['tail'].values, categories=tail_set)
50 |     if 'truth_flag' in df:
51 |         result['truth_flag'] = df['truth_flag']
52 |     return result, df_to_idx_array(result)
53 | 
54 | 
55 | def corrupt(triple, field_replacements, forbidden_set,
56 |             rng, fields=[0,2], max_tries=1000):
57 |     ''' Produces a corrupted negative triple for the supplied positive triple
58 |     using rejection sampling. Only a single field (one of those listed in the
59 |     fields argument) is changed.
60 | 
61 |     Args:
62 |         triple: A tuple or list with 3 entries: (head, rel, tail)
63 | 
64 |         field_replacements: A tuple of array-like: (head entities, relationships,
65 |             tail entities), each containing the (unique) items to use as
66 |             replacements for the corruption
67 | 
68 |         forbidden_set: A set of triples (typically all known true triples)
69 |             that we should not accidentally create when generating corrupted
70 |             negatives.
71 | 
72 |         rng: Numpy RandomState object
73 | 
74 |         fields: The fields that can be replaced in the triple. Default is
75 |             [0,2] which corresponds to the head and tail entries. [0,1,2]
76 |             would randomly replace any of the three entries.
77 | 
78 |         max_tries: The maximum number of random corruption attempts before
79 |             giving up and throwing an exception. A corruption attempt can fail
80 |             if the sampled negative is a triple found in forbidden_set.
81 | 
82 |     Returns:
83 |         A corrupted tuple (head, rel, tail) where one entry is different
84 |         than the triple passed in.
85 |     '''
86 |     collision = False
87 |     for _ in range(max_tries):
88 |         field = rng.choice(fields)
89 |         replacements = field_replacements[field]
90 |         corrupted = list(triple)
91 |         corrupted[field] = replacements[rng.randint(len(replacements))]
92 |         collision = (tuple(corrupted) in forbidden_set)
93 |         if not collision:
94 |             break
95 |     if collision:
96 |         raise Exception('Failed to sample a corruption for {} after {} tries'.format(triple, max_tries))
97 |     return corrupted
98 | 
99 | 
100 | def create_tf_pairs(true_df, all_true_df, rng):
101 |     '''Creates a DataFrame with contrastive positive/negative pairs given
102 |     true triples to contrast and a set of "all known" true triples in order
103 |     to avoid accidentally sampling a negative from this set.
104 | 
105 |     Args:
106 |         true_df: Pandas DataFrame containing true triples to contrast.
107 |             It must contain columns 'head', 'rel', and 'tail'. One
108 |             random negative will be created for each.
109 |         all_true_df: Pandas DataFrame containing "all known" true triples.
110 |             This will be used to avoid randomly generating negatives
111 |             that happen to be true but were not in true_df.
112 |         rng: A Numpy RandomState object
113 | 
114 |     Returns:
115 |         A new Pandas DataFrame with alternating pos/neg pairs. If true_df
116 |         contains rows [p1, p2, ..., pN], then this will contain 2N rows in the
117 |         form [p1, n1, p2, n2, ..., pN, nN].
118 |     '''
119 |     all_true_tuples = set(all_true_df.itertuples(index=False))
120 |     replacements = (list(set(true_df['head'])), [], list(set(true_df['tail'])))
121 |     result = []
122 |     for triple in true_df.itertuples(index=False):
123 |         corruption = corrupt(triple, replacements, all_true_tuples, rng)
124 |         result.append(triple)
125 |         result.append(corruption)
126 |     result = pd.DataFrame(result, columns=['head', 'rel', 'tail'])
127 |     result['truth_flag'] = np.tile([True, False], len(true_df))
128 |     return result
129 | 
130 | 
131 | def threshold_and_eval(test_df, test_scores, val_df, val_scores):
132 |     ''' Test set evaluation protocol from:
133 |         Socher, Richard, et al. "Reasoning with neural tensor networks for
134 |         knowledge base completion." Advances in Neural Information Processing
135 |         Systems. 2013.
136 | 
137 |     Finds model output thresholds using val_df to create a binary
138 |     classifier, and then measures classification accuracy on the test
139 |     set scores using these thresholds. A different threshold is found
140 |     for each relationship type. All DataFrames must have a 'rel' column.
141 | 
142 |     Args:
143 |         test_df: Pandas DataFrame containing the test triples
144 |         test_scores: A numpy array of test set scores, one for each triple
145 |             in test_df
146 |         val_df: A Pandas DataFrame containing the validation triples
147 |         val_scores: A numpy array of validation set scores, one for each triple
148 |             in val_df
149 | 
150 |     Returns:
151 |         A tuple containing (accuracy, test_predictions, test_scores, threshold_map)
152 |         accuracy: the overall classification accuracy on the test set
153 |         test_predictions: True/False output for test set
154 |         test_scores: Test set scores
155 |         threshold_map: A dict containing the per-relationship thresholds
156 |             found on the validation set, e.g. {'_has_part': 0.562}
157 |     '''
158 |     def find_thresh(df, scores):
159 |         ''' find threshold that maximizes accuracy on validation set '''
160 | 
161 |         sorted_scores = sorted(scores)
162 |         best_score, best_thresh = -np.inf, -np.inf
163 |         for i in range(len(sorted_scores)-1):
164 |             thresh = (sorted_scores[i] + sorted_scores[i+1]) / 2.0
165 |             predictions = (scores > thresh)
166 |             correct = np.sum(predictions == df['truth_flag'])
167 |             if correct >= best_score:
168 |                 best_score, best_thresh = correct, thresh
169 |         return best_thresh
170 |     threshold_map = {}
171 |     for relationship in set(val_df['rel']):
172 |         mask = np.array(val_df['rel'] == relationship)
173 |         threshold_map[relationship] = find_thresh(val_df.loc[mask], val_scores[mask])
174 |     test_entry_thresholds = np.array([threshold_map[r] for r in test_df['rel']])
175 |     test_predictions = (test_scores > test_entry_thresholds)
176 |     accuracy = np.sum(test_predictions == test_df['truth_flag']) / len(test_predictions)
177 |     return accuracy, test_predictions, test_scores, threshold_map
178 | 
179 | 
180 | def model_threshold_and_eval(model, test_df, val_df):
181 |     ''' See threshold_and_eval(). This is the same except that the supplied
182 |     model will be used to generate the test_scores and val_scores.
183 | 
184 |     Args:
185 |         model: A trained relational learning model whose predict() will be
186 |             called on index arrays generated from test_df and val_df
187 |         test_df: Pandas DataFrame containing the test triples
188 |         val_df: A Pandas DataFrame containing the validation triples
189 | 
190 |     Returns:
191 |         A tuple containing (accuracy, test_predictions, test_scores, threshold_map)
192 |         accuracy: the overall classification accuracy on the test set
193 |         test_predictions: True/False output for test set
194 |         test_scores: Test set scores
195 |         threshold_map: A dict containing the per-relationship thresholds
196 |             found on the validation set, e.g. {'_has_part': 0.562}
197 |     '''
198 |     val_scores = model.predict(df_to_idx_array(val_df))
199 |     test_scores = model.predict(df_to_idx_array(test_df))
200 |     return threshold_and_eval(test_df, test_scores, val_df, val_scores)
201 | 
202 | 
203 | def pair_ranking_accuracy(model_output):
204 |     ''' Pair ranking accuracy. This only works when model_output comes from
205 |     alternating positive/negative pairs: [pos,neg,pos,neg,...,pos,neg]
206 | 
207 |     Returns:
208 |         The fraction of pairs for which the positive example is scored higher
209 |         than the negative example
210 |     '''
211 |     output_pairs = np.reshape(model_output, [-1,2])
212 |     correct = np.sum(output_pairs[:,0] > output_pairs[:,1])
213 |     return float(correct) / len(output_pairs)
214 | 
215 | 
216 | def model_pair_ranking_accuracy(model, data):
217 |     ''' See pair_ranking_accuracy(), this simply calls model.predict(data) to
218 |     generate model_output
219 | 
220 |     Returns:
221 |         The fraction of pairs for which the positive example is scored higher
222 |         than the negative example
223 |     '''
224 |     return pair_ranking_accuracy(model.predict(data))
225 | 
226 | 
227 | class ContrastiveTrainingProvider(object):
228 |     ''' Provides mini-batches for stochastic gradient descent by augmenting
229 |     a set of positive training triples with random contrastive negative samples.
230 | 
231 |     Args:
232 |         train: A 2D numpy array with positive training triples in its rows
233 |         batch_pos_cnt: Number of positive examples to use in each mini-batch
234 |         separate_head_tail: If True, head and tail corruptions are sampled
235 |             from entity sets limited to those found in the respective location.
236 |             If False, head and tail replacements are sampled from the set of
237 |             all entities, regardless of location.
238 |         rng: (optional) A NumPy RandomState object
239 | 
240 |     TODO: Allow a variable number of negative examples per positive. Right
241 |     now this class always provides a single negative per positive, generating
242 |     pairs: [pos, neg, pos, neg, ...]
243 |     '''
244 | 
245 |     def __init__(self, train, batch_pos_cnt=50,
246 |                  separate_head_tail=False, rng=None):
247 |         self.train = train
248 |         self.batch_pos_cnt = batch_pos_cnt
249 |         self.separate_head_tail = separate_head_tail
250 |         if rng is None:
251 |             rng = np.random.RandomState()
252 |         self.rng = rng
253 |         self.num_examples = len(train)
254 |         self.epochs_completed = 0
255 |         self.index_in_epoch = 0
256 |         # store set of training tuples for quickly checking negatives
257 |         self.triples_set = set(tuple(t) for t in train)
258 |         # replacement entities
259 |         if separate_head_tail:
260 |             head_replacements = list(set(train[:,0]))
261 |             tail_replacements = list(set(train[:,2]))
262 |         else:
263 |             all_entities = set(train[:,0]).union(train[:,2])
264 |             head_replacements = tail_replacements = list(all_entities)
265 |         self.field_replacements = [head_replacements,
266 |                                    list(set(train[:,1])),
267 |                                    tail_replacements]
268 |         self._shuffle_data()
269 | 
270 |     def _shuffle_data(self):
271 |         self.rng.shuffle(self.train)
272 | 
273 |     def next_batch(self):
274 |         '''
275 |         Returns:
276 |             A tuple (batch_triples, batch_labels):
277 |             batch_triples: Bx3 numpy array of triples, where B=2*batch_pos_cnt
278 |             batch_labels: numpy array with 0/1 labels for each row in
279 |                 batch_triples
280 |             Each positive is followed by a contrasting negative, so batch_labels
281 |             will alternate: [1, 0, 1, 0, ..., 1, 0]
282 |         '''
283 |         start = self.index_in_epoch
284 |         self.index_in_epoch += self.batch_pos_cnt
285 |         if self.index_in_epoch > self.num_examples:
286 |             # Finished epoch, shuffle data
287 |             self.epochs_completed += 1
288 |             self.index_in_epoch = self.batch_pos_cnt
289 |             start = 0
290 |             self._shuffle_data()
291 |         end = self.index_in_epoch
292 |         batch_triples = []
293 |         batch_labels = []
294 |         for positive in self.train[start:end]:
295 |             batch_triples.append(positive)
296 |             batch_labels.append(1.0)
297 |             negative = corrupt(positive, self.field_replacements, self.triples_set, self.rng)
298 |             batch_triples.append(negative)
299 |             batch_labels.append(0.0)
300 |         batch_triples = np.vstack(batch_triples)
301 |         batch_labels = np.array(batch_labels)
302 |         return batch_triples, batch_labels
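303 | 
304 | # Usage sketch (illustrative; `train_df` and `val_df` are hypothetical
305 | # DataFrames with 'head', 'rel', and 'tail' columns -- see wordnet_eval.py
306 | # for the full end-to-end flow):
307 | #
308 | #     rng = np.random.RandomState(0)
309 | #     all_true = pd.concat((train_df, val_df))
310 | #     val_pairs = create_tf_pairs(val_df, all_true, rng)  # alternating pos/neg rows
311 | #     fields = (set(train_df['head']), set(train_df['rel']), set(train_df['tail']))
312 | #     val_pairs, val_idx = make_categorical(val_pairs, fields)
313 | #     # val_idx is now ready to feed to a model via BaseModel.create_feed_dict()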
--------------------------------------------------------------------------------
/tf_rl_tutorial/wordnet_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Mandiant, A FireEye Company
2 | # Authors: Brian Jones
3 | # License: Apache 2.0
4 | 
5 | ''' Example run script for "Relational Learning with TensorFlow" tutorial '''
6 | 
7 | import os
8 | from pprint import pprint
9 | import time
10 | from collections import defaultdict
11 | 
12 | import numpy as np
13 | import pandas as pd
14 | import tensorflow as tf
15 | 
16 | import tf_rl_tutorial.models as models
17 | import tf_rl_tutorial.util as util
18 | 
19 | 
20 | #######################################
21 | # Data preprocessing
22 | 
23 | def read_wordnet(fpath, def_df):
24 |     df = pd.read_table(fpath, names=['head', 'rel', 'tail'])
25 |     df['head'] = def_df.loc[df['head']]['word'].values
26 |     df['tail'] = def_df.loc[df['tail']]['word'].values
27 |     return df
28 | 
29 | 
30 | def wordnet_preprocess(train, val, test):
31 |     mask = np.zeros(len(train)).astype(bool)
32 |     lookup = defaultdict(list)
33 |     for idx,h,_,t in train.itertuples():
34 |         lookup[(h,t)].append(idx)
35 |     for h,_,t in pd.concat((val,test)).itertuples(index=False):
36 |         mask[lookup[(h,t)]] = True
37 |         mask[lookup[(t,h)]] = True
38 |     train = train.loc[~mask]
39 |     heads, tails = set(train['head']), set(train['tail'])
40 |     val = val.loc[val['head'].isin(heads) & val['tail'].isin(tails)]
41 |     test = test.loc[test['head'].isin(heads) & test['tail'].isin(tails)]
42 |     return train, val, test
43 | 
44 | 
45 | #######################################
46 | # Models used in tutorial
47 | 
48 | def cp():
49 |     opt = tf.train.AdagradOptimizer(1.0)
50 |     return models.Contrastive_CP(embedding_size=20,
51 |                                  maxnorm=1.5,
52 |                                  batch_pos_cnt=100,
53 |                                  max_iter=30000,
54 |                                  model_type='least_squares',
55 |                                  add_bias=False,
56 |                                  opt=opt)
57 | 
58 | 
59 | def bilinear():
60 |     opt = tf.train.AdagradOptimizer(1.0)
61 |     return models.Bilinear(embedding_size=20,
62 |                            maxnorm=1.0,
63 |                            rel_maxnorm_mult=6.0,
64 |                            batch_pos_cnt=100,
65 |                            max_iter=30000,
66 |                            model_type='logistic',
67 |                            add_bias=True,
68 |                            opt=opt)
69 | 
70 | 
71 | def transe():
72 |     opt = tf.train.AdagradOptimizer(1.0)
73 |     return models.TransE(embedding_size=20,
74 |                          batch_pos_cnt=100,
75 |                          max_iter=30000,
76 |                          dist='euclidean',
77 |                          margin=1.0,
78 |                          opt=opt)
79 | 
80 | 
81 | if __name__ == '__main__':
82 | 
83 |     ###################################
84 |     # MODEL
85 | 
86 |     rng = np.random.RandomState(123)
87 |     model = transe()  # or bilinear(), cp()
88 | 
89 |     print(model.__class__)
90 |     pprint(model.__dict__)
91 | 
92 |     ###################################
93 |     # DATA
94 | 
95 |     data_dir = '../data/wordnet-mlj12'
96 |     definitions = pd.read_table(os.path.join(data_dir, 'wordnet-mlj12-definitions.txt'),
97 |                                 index_col=0, names=['word', 'definition'])
98 |     train = read_wordnet(os.path.join(data_dir, 'wordnet-mlj12-train.txt'), definitions)
99 |     val = read_wordnet(os.path.join(data_dir, 'wordnet-mlj12-valid.txt'), definitions)
100 |     test = read_wordnet(os.path.join(data_dir, 'wordnet-mlj12-test.txt'), definitions)
101 |     combined_df = pd.concat((train, val, test))
102 |     all_train_entities = set(train['head']).union(train['tail'])
103 |     all_train_relationships = set(train['rel'])
104 | 
105 |     print()
106 |     print('Train shape:', train.shape)
107 |     print('Validation shape:', val.shape)
108 |     print('Test shape:', test.shape)
109 |     print('Training entity count: {}'.format(len(all_train_entities)))
110 |     print('Training relationship type count: {}'.format(len(all_train_relationships)))
111 | 
112 |     print()
113 |     print('Preprocessing to remove instances from train that have a similar counterpart in val/test...')
114 |     train,val,test = wordnet_preprocess(train, val, test)
115 |     all_train_entities = set(train['head']).union(train['tail'])
116 |     all_train_relationships = set(train['rel'])
117 | 
118 |     print('Adding negative examples to val and test...')
119 |     combined_df = pd.concat((train, val, test))
120 |     val = util.create_tf_pairs(val, combined_df, rng)
121 |     test = util.create_tf_pairs(test, combined_df, rng)
122 |     print('Train shape:', train.shape)
123 |     print('Validation shape:', val.shape)
124 |     print('Test shape:', test.shape)
125 |     print()
126 | 
127 |     if isinstance(model, models.Contrastive_CP):
128 |         print('Using separate encoding for head and tail entities')
129 |         field_categories = (set(train['head']),
130 |                             all_train_relationships,
131 |                             set(train['tail']))
132 |     else:
133 |         print('Using the same encoding for head and tail entities')
134 |         field_categories = (all_train_entities,
135 |                             all_train_relationships,
136 |                             all_train_entities)
137 | 
138 |     train, train_idx_array = util.make_categorical(train, field_categories)
139 |     val, val_idx_array = util.make_categorical(val, field_categories)
140 |     test, test_idx_array = util.make_categorical(test, field_categories)
141 |     print('Train check:', train.shape, not train.isnull().values.any())
142 |     print('Val check:', val.shape, not val.isnull().values.any())
143 |     print('Test check:', test.shape, not test.isnull().values.any())
144 | 
145 |     ###################################
146 |     # TRAIN
147 | 
148 |     # Monitor progress on current training batch and validation set
149 |     start = time.time()
150 |     val_labels = np.array(val['truth_flag'], dtype=np.float)
151 |     val_feed_dict = model.create_feed_dict(val_idx_array, val_labels)
152 | 
153 |     def train_step_callback(itr, batch_feed_dict):
154 |         if (itr % 2000) == 0 or (itr == (model.max_iter-1)):
155 |             elapsed = int(time.time() - start)
156 |             avg_batch_loss = model.sess.run(model.loss, batch_feed_dict) / len(batch_feed_dict[model.target])
157 |             avg_val_loss = model.sess.run(model.loss, val_feed_dict) / len(val_labels)
158 |             val_acc = util.model_pair_ranking_accuracy(model, val_idx_array)
159 |             msg = 'Itr {}, train loss: {:.3}, val loss: {:.3}, val rank_acc: {:.2}, elapsed: {}'
160 |             print(msg.format(itr, avg_batch_loss, avg_val_loss, val_acc, elapsed))
161 |             # Check embedding norms
162 |             names,model_vars = zip(*model.embeddings())
163 |             var_vals = model.sess.run(model_vars)
164 |             for name,var in zip(names, var_vals):
165 |                 norms = np.linalg.norm(var, axis=1)
166 |                 print('{} min/max norm: {:.2} {:.2}'.format(name, np.min(norms), np.max(norms)))
167 |         return True
168 | 
169 |     print('Training...')
170 |     model.fit(train_idx_array, train_step_callback)
171 | 
172 |     ###################################
173 |     # TEST
174 | 
175 |     print()
176 |     print('Done training, evaluating on test set.')
177 |     test_labels = np.array(test['truth_flag'], dtype=np.float)
178 |     test_feed_dict = model.create_feed_dict(test_idx_array, test_labels, training=False)
179 |     acc, pred, scores, thresh_map = util.model_threshold_and_eval(model, test, val)
180 | 
181 |     print('Test set accuracy: {:.2}'.format(acc))
182 |     print()
183 |     print('Relationship breakdown:')
184 |     results_df = test.copy()
185 |     results_df['score'] = scores
186 |     results_df['prediction'] = pred
187 |     results_df['is_correct'] = pred == test['truth_flag']
188 |     for rel in set(results_df['rel']):
189 |         rows = results_df[results_df['rel'] == rel]
190 |         n = len(rows)
191 |         correct = rows['is_correct'].sum()
192 | 
193 |         print('acc:{:.2} rel:{}, {} / {}'.format(float(correct)/n, rel, correct, n))
194 | 
--------------------------------------------------------------------------------